From 02d3b8680e65358f6af11a04839d1601354953a6 Mon Sep 17 00:00:00 2001 From: trangevi Date: Tue, 24 Mar 2026 15:59:32 -0700 Subject: [PATCH 01/14] Maybe working changes for tools Signed-off-by: trangevi --- .../azure.ai.agents/internal/cmd/init.go | 63 +++++ .../cmd/init_foundry_resources_helpers.go | 20 +- .../init_foundry_resources_helpers_test.go | 43 ++++ .../internal/exterrors/codes.go | 10 + .../internal/pkg/agents/agent_yaml/parse.go | 6 + .../pkg/agents/agent_yaml/parse_test.go | 116 +++++++++ .../internal/pkg/agents/agent_yaml/yaml.go | 13 +- .../pkg/azure/foundry_toolsets_client.go | 243 ++++++++++++++++++ .../internal/project/config.go | 8 + .../internal/project/config_test.go | 205 +++++++++++++++ .../internal/project/service_target_agent.go | 112 ++++++++ .../schemas/azure.ai.agent.json | 20 ++ 12 files changed, 854 insertions(+), 5 deletions(-) create mode 100644 cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_toolsets_client.go create mode 100644 cli/azd/extensions/azure.ai.agents/internal/project/config_test.go diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/init.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/init.go index e4fbf460ae2..0334c657103 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/init.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/init.go @@ -1163,6 +1163,13 @@ func (a *InitAction) addToProject(ctx context.Context, targetDir string, agentMa agentConfig.Deployments = a.deploymentDetails agentConfig.Resources = resourceDetails + // Process toolbox resources from the manifest + toolboxes, err := extractToolboxConfigs(agentManifest) + if err != nil { + return err + } + agentConfig.Toolboxes = toolboxes + // Detect startup command from the project source directory startupCmd, err := resolveStartupCommandForInit(ctx, a.azdClient, a.projectConfig.Path, targetDir, a.flags.NoPrompt) if err != nil { @@ -1566,3 +1573,59 @@ func downloadDirectoryContentsWithoutGhCli( return nil } + +// extractToolboxConfigs extracts toolbox resource definitions from the agent manifest +// and converts them into project.Toolbox config entries. +// Each toolbox resource's options must contain a "tools" array with tool definitions. +func extractToolboxConfigs(manifest *agent_yaml.AgentManifest) ([]project.Toolbox, error) { + if manifest == nil || manifest.Resources == nil { + return nil, nil + } + + var toolboxes []project.Toolbox + + for _, resource := range manifest.Resources { + tbResource, ok := resource.(agent_yaml.ToolboxResource) + if !ok { + continue + } + + description, _ := tbResource.Options["description"].(string) + + rawTools, ok := tbResource.Options["tools"] + if !ok { + return nil, fmt.Errorf( + "toolbox resource '%s' is missing required 'tools' in options", + tbResource.Name, + ) + } + + toolsList, ok := rawTools.([]any) + if !ok { + return nil, fmt.Errorf( + "toolbox resource '%s' has invalid 'tools' format: expected array", + tbResource.Name, + ) + } + + tools := make([]map[string]any, 0, len(toolsList)) + for _, rawTool := range toolsList { + toolMap, ok := rawTool.(map[string]any) + if !ok { + return nil, fmt.Errorf( + "toolbox resource '%s' has invalid tool entry: expected object", + tbResource.Name, + ) + } + tools = append(tools, toolMap) + } + + toolboxes = append(toolboxes, project.Toolbox{ + Name: tbResource.Name, + Description: description, + Tools: tools, + }) + } + + return toolboxes, nil +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_foundry_resources_helpers.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_foundry_resources_helpers.go index 00110c1ce23..bc5496a8dea 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_foundry_resources_helpers.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_foundry_resources_helpers.go @@ -246,6 +246,16 @@ func listProjectDeployments( return results, nil } +// normalizeLoginServer strips any URL scheme prefix (e.g. "https://") from a +// container registry login server so that callers always work with a plain +// hostname like "myregistry.azurecr.io". +func normalizeLoginServer(loginServer string) string { + loginServer = strings.TrimPrefix(loginServer, "https://") + loginServer = strings.TrimPrefix(loginServer, "http://") + loginServer = strings.TrimSuffix(loginServer, "/") + return loginServer +} + // lookupAcrResourceId finds the ARM resource ID for an ACR given its login server endpoint. func lookupAcrResourceId( ctx context.Context, @@ -253,6 +263,7 @@ func lookupAcrResourceId( subscriptionId string, loginServer string, ) (string, error) { + loginServer = normalizeLoginServer(loginServer) parts := strings.Split(loginServer, ".") if len(parts) < 2 || parts[0] == "" { return "", fmt.Errorf("invalid login server format: %q, expected e.g. %q", loginServer, "registry.azurecr.io") @@ -395,12 +406,13 @@ func configureAcrConnection( } if resp.Value != "" { - resourceId, err := lookupAcrResourceId(ctx, credential, subscriptionId, resp.Value) + loginServer := normalizeLoginServer(resp.Value) + resourceId, err := lookupAcrResourceId(ctx, credential, subscriptionId, loginServer) if err != nil { return fmt.Errorf("failed to lookup ACR resource ID: %w", err) } - if err := setEnvValue(ctx, azdClient, envName, "AZURE_CONTAINER_REGISTRY_ENDPOINT", resp.Value); err != nil { + if err := setEnvValue(ctx, azdClient, envName, "AZURE_CONTAINER_REGISTRY_ENDPOINT", loginServer); err != nil { return err } if err := setEnvValue(ctx, azdClient, envName, "AZURE_CONTAINER_REGISTRY_RESOURCE_ID", resourceId); err != nil { @@ -443,7 +455,9 @@ func configureAcrConnection( if err := setEnvValue(ctx, azdClient, envName, "AZURE_AI_PROJECT_ACR_CONNECTION_NAME", selectedConnection.Name); err != nil { return err } - if err := setEnvValue(ctx, azdClient, envName, "AZURE_CONTAINER_REGISTRY_ENDPOINT", selectedConnection.Target); err != nil { + if err := setEnvValue( + ctx, azdClient, envName, "AZURE_CONTAINER_REGISTRY_ENDPOINT", normalizeLoginServer(selectedConnection.Target), + ); err != nil { return err } diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_foundry_resources_helpers_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_foundry_resources_helpers_test.go index 3f10f875eeb..4d0e2b525f1 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_foundry_resources_helpers_test.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_foundry_resources_helpers_test.go @@ -89,6 +89,49 @@ func TestExtractProjectDetails(t *testing.T) { } } +func TestNormalizeLoginServer(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want string + }{ + { + name: "plain hostname", + input: "myregistry.azurecr.io", + want: "myregistry.azurecr.io", + }, + { + name: "https prefix", + input: "https://myregistry.azurecr.io", + want: "myregistry.azurecr.io", + }, + { + name: "http prefix", + input: "http://myregistry.azurecr.io", + want: "myregistry.azurecr.io", + }, + { + name: "https with trailing slash", + input: "https://myregistry.azurecr.io/", + want: "myregistry.azurecr.io", + }, + { + name: "empty string", + input: "", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.want, normalizeLoginServer(tt.input)) + }) + } +} + func TestFoundryProjectInfoResourceIdConstruction(t *testing.T) { t.Parallel() diff --git a/cli/azd/extensions/azure.ai.agents/internal/exterrors/codes.go b/cli/azd/extensions/azure.ai.agents/internal/exterrors/codes.go index 045e6ee226a..1abeff33a38 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/exterrors/codes.go +++ b/cli/azd/extensions/azure.ai.agents/internal/exterrors/codes.go @@ -86,6 +86,13 @@ const ( CodeInvalidFilePath = "invalid_file_path" ) +// Error codes for toolbox/toolset operations. +const ( + CodeInvalidToolbox = "invalid_toolbox" + CodeCreateToolsetFailed = "create_toolset_failed" + CodeUpdateToolsetFailed = "update_toolset_failed" +) + // Error codes commonly used for internal errors. // // These are usually paired with [Internal] for unexpected failures @@ -107,4 +114,7 @@ const ( OpCreateAgent = "create_agent" OpStartContainer = "start_container" OpGetContainerOperation = "get_container_operation" + OpCreateToolset = "create_toolset" + OpUpdateToolset = "update_toolset" + OpGetToolset = "get_toolset" ) diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse.go index f859ff8a31a..318bed100d8 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse.go +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse.go @@ -170,6 +170,12 @@ func ExtractResourceDefinitions(manifestYamlContent []byte) ([]any, error) { return nil, fmt.Errorf("failed to unmarshal to ToolResource: %w", err) } resourceDefs = append(resourceDefs, toolDef) + case ResourceKindToolbox: + var toolboxDef ToolboxResource + if err := yaml.Unmarshal(resourceBytes, &toolboxDef); err != nil { + return nil, fmt.Errorf("failed to unmarshal to ToolboxResource: %w", err) + } + resourceDefs = append(resourceDefs, toolboxDef) default: return nil, fmt.Errorf("unrecognized resource kind: %s", resourceDef.Kind) } diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse_test.go index 3679d528713..0240481903e 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse_test.go +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse_test.go @@ -218,3 +218,119 @@ environment_variables: t.Errorf("Expected error message to contain '%s', got '%s'", expectedMsg, err.Error()) } } + +// TestExtractResourceDefinitions_ToolboxResource tests parsing toolbox resources from manifest +func TestExtractResourceDefinitions_ToolboxResource(t *testing.T) { + yamlContent := []byte(` +name: test-manifest +template: + kind: prompt + name: test-agent + model: + id: gpt-4.1-mini +resources: + - kind: toolbox + name: echo-toolbox + id: echo-toolbox + options: + description: A sample toolbox + tools: + - type: mcp + server_label: github + server_url: https://api.example.com/mcp + project_connection_id: TestKey +`) + + resources, err := ExtractResourceDefinitions(yamlContent) + if err != nil { + t.Fatalf("ExtractResourceDefinitions failed: %v", err) + } + + if len(resources) != 1 { + t.Fatalf("Expected 1 resource, got %d", len(resources)) + } + + toolboxRes, ok := resources[0].(ToolboxResource) + if !ok { + t.Fatalf("Expected ToolboxResource, got %T", resources[0]) + } + + if toolboxRes.Name != "echo-toolbox" { + t.Errorf("Expected name 'echo-toolbox', got '%s'", toolboxRes.Name) + } + + if toolboxRes.Kind != ResourceKindToolbox { + t.Errorf("Expected kind 'toolbox', got '%s'", toolboxRes.Kind) + } + + if toolboxRes.Id != "echo-toolbox" { + t.Errorf("Expected id 'echo-toolbox', got '%s'", toolboxRes.Id) + } + + desc, ok := toolboxRes.Options["description"] + if !ok { + t.Fatal("Expected 'description' in options") + } + if desc != "A sample toolbox" { + t.Errorf("Expected description 'A sample toolbox', got '%v'", desc) + } + + toolsVal, ok := toolboxRes.Options["tools"] + if !ok { + t.Fatal("Expected 'tools' in options") + } + toolsSlice, ok := toolsVal.([]any) + if !ok { + t.Fatalf("Expected tools to be a slice, got %T", toolsVal) + } + if len(toolsSlice) != 1 { + t.Fatalf("Expected 1 tool in options, got %d", len(toolsSlice)) + } +} + +// TestExtractResourceDefinitions_MixedResources tests parsing both model and toolbox resources +func TestExtractResourceDefinitions_MixedResources(t *testing.T) { + yamlContent := []byte(` +name: test-manifest +template: + kind: prompt + name: test-agent + model: + id: gpt-4.1-mini +resources: + - kind: model + name: primary-model + id: gpt-4.1-mini + - kind: toolbox + name: my-toolbox + id: my-toolbox + options: + description: My toolbox +`) + + resources, err := ExtractResourceDefinitions(yamlContent) + if err != nil { + t.Fatalf("ExtractResourceDefinitions failed: %v", err) + } + + if len(resources) != 2 { + t.Fatalf("Expected 2 resources, got %d", len(resources)) + } + + if _, ok := resources[0].(ModelResource); !ok { + t.Errorf("Expected first resource to be ModelResource, got %T", resources[0]) + } + + toolboxRes, ok := resources[1].(ToolboxResource) + if !ok { + t.Fatalf("Expected second resource to be ToolboxResource, got %T", resources[1]) + } + + if toolboxRes.Name != "my-toolbox" { + t.Errorf("Expected name 'my-toolbox', got '%s'", toolboxRes.Name) + } + + if toolboxRes.Id != "my-toolbox" { + t.Errorf("Expected id 'my-toolbox', got '%s'", toolboxRes.Id) + } +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml.go index bb9231c7651..5b15ef4647f 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml.go +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml.go @@ -31,8 +31,9 @@ func ValidAgentKinds() []AgentKind { type ResourceKind string const ( - ResourceKindModel ResourceKind = "model" - ResourceKindTool ResourceKind = "tool" + ResourceKindModel ResourceKind = "model" + ResourceKindTool ResourceKind = "tool" + ResourceKindToolbox ResourceKind = "toolbox" ) type ToolKind string @@ -305,6 +306,14 @@ type ToolResource struct { Options map[string]any `json:"options" yaml:"options"` } +// ToolboxResource Represents a toolbox resource required by the agent. +// A toolbox is a reusable collection of tools that can be deployed as a Foundry Toolset. +type ToolboxResource struct { + Resource `json:",inline" yaml:",inline"` + Id string `json:"id" yaml:"id"` + Options map[string]any `json:"options" yaml:"options"` +} + // Template Template model for defining prompt templates. // This model specifies the rendering engine used for slot filling prompts, // the parser used to process the rendered template into API-compatible format, diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_toolsets_client.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_toolsets_client.go new file mode 100644 index 00000000000..2d6b2d6e6b2 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_toolsets_client.go @@ -0,0 +1,243 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azure + +import ( + "azureaiagent/internal/version" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming" + "github.com/azure/azure-dev/cli/azd/pkg/azsdk" +) + +const ( + toolsetsApiVersion = "v1" + toolsetsFeatureHeader = "Toolsets=V1Preview" +) + +// FoundryToolsetsClient provides methods for interacting with the Foundry Toolsets API +type FoundryToolsetsClient struct { + endpoint string + pipeline runtime.Pipeline +} + +// NewFoundryToolsetsClient creates a new FoundryToolsetsClient +func NewFoundryToolsetsClient( + endpoint string, + cred azcore.TokenCredential, +) *FoundryToolsetsClient { + userAgent := fmt.Sprintf("azd-ext-azure-ai-agents/%s", version.Version) + + clientOptions := &policy.ClientOptions{ + Logging: policy.LogOptions{ + AllowedHeaders: []string{azsdk.MsCorrelationIdHeader}, + IncludeBody: true, + }, + PerCallPolicies: []policy.Policy{ + runtime.NewBearerTokenPolicy(cred, []string{"https://ai.azure.com/.default"}, nil), + azsdk.NewMsCorrelationPolicy(), + azsdk.NewUserAgentPolicy(userAgent), + }, + } + + pipeline := runtime.NewPipeline( + "azure-ai-agents", + "v1.0.0", + runtime.PipelineOptions{}, + clientOptions, + ) + + return &FoundryToolsetsClient{ + endpoint: endpoint, + pipeline: pipeline, + } +} + +// CreateToolsetRequest is the request body for creating a toolset +type CreateToolsetRequest struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` + Tools []map[string]any `json:"tools"` +} + +// UpdateToolsetRequest is the request body for updating a toolset +type UpdateToolsetRequest struct { + Description string `json:"description,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` + Tools []map[string]any `json:"tools"` +} + +// ToolsetObject is the response object for a toolset +type ToolsetObject struct { + Object string `json:"object"` + Id string `json:"id"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` + Tools []map[string]any `json:"tools"` +} + +// DeleteToolsetResponse is the response for deleting a toolset +type DeleteToolsetResponse struct { + Object string `json:"object"` + Name string `json:"name"` + Deleted bool `json:"deleted"` +} + +// CreateToolset creates a new toolset +func (c *FoundryToolsetsClient) CreateToolset( + ctx context.Context, + request *CreateToolsetRequest, +) (*ToolsetObject, error) { + targetUrl := fmt.Sprintf( + "%s/toolsets?api-version=%s", + c.endpoint, toolsetsApiVersion, + ) + + payload, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := runtime.NewRequest(ctx, http.MethodPost, targetUrl) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Raw().Header.Set("Foundry-Features", toolsetsFeatureHeader) + + if err := req.SetBody( + streaming.NopCloser(bytes.NewReader(payload)), + "application/json", + ); err != nil { + return nil, fmt.Errorf("failed to set request body: %w", err) + } + + resp, err := c.pipeline.Do(req) + if err != nil { + return nil, fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + if !runtime.HasStatusCode(resp, http.StatusOK) { + return nil, runtime.NewResponseError(resp) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + var toolset ToolsetObject + if err := json.Unmarshal(body, &toolset); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + return &toolset, nil +} + +// UpdateToolset updates an existing toolset +func (c *FoundryToolsetsClient) UpdateToolset( + ctx context.Context, + toolsetName string, + request *UpdateToolsetRequest, +) (*ToolsetObject, error) { + targetUrl := fmt.Sprintf( + "%s/toolsets/%s?api-version=%s", + c.endpoint, url.PathEscape(toolsetName), toolsetsApiVersion, + ) + + payload, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := runtime.NewRequest(ctx, http.MethodPost, targetUrl) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Raw().Header.Set("Foundry-Features", toolsetsFeatureHeader) + + if err := req.SetBody( + streaming.NopCloser(bytes.NewReader(payload)), + "application/json", + ); err != nil { + return nil, fmt.Errorf("failed to set request body: %w", err) + } + + resp, err := c.pipeline.Do(req) + if err != nil { + return nil, fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + if !runtime.HasStatusCode(resp, http.StatusOK) { + return nil, runtime.NewResponseError(resp) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + var toolset ToolsetObject + if err := json.Unmarshal(body, &toolset); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + return &toolset, nil +} + +// GetToolset retrieves a toolset by name +func (c *FoundryToolsetsClient) GetToolset( + ctx context.Context, + toolsetName string, +) (*ToolsetObject, error) { + targetUrl := fmt.Sprintf( + "%s/toolsets/%s?api-version=%s", + c.endpoint, url.PathEscape(toolsetName), toolsetsApiVersion, + ) + + req, err := runtime.NewRequest(ctx, http.MethodGet, targetUrl) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Raw().Header.Set("Foundry-Features", toolsetsFeatureHeader) + + resp, err := c.pipeline.Do(req) + if err != nil { + return nil, fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + if !runtime.HasStatusCode(resp, http.StatusOK) { + return nil, runtime.NewResponseError(resp) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + var toolset ToolsetObject + if err := json.Unmarshal(body, &toolset); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + return &toolset, nil +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/project/config.go b/cli/azd/extensions/azure.ai.agents/internal/project/config.go index 4694d649466..81a9ed1f195 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/project/config.go +++ b/cli/azd/extensions/azure.ai.agents/internal/project/config.go @@ -48,6 +48,7 @@ type ServiceTargetAgentConfig struct { Container *ContainerSettings `json:"container,omitempty"` Deployments []Deployment `json:"deployments,omitempty"` Resources []Resource `json:"resources,omitempty"` + Toolboxes []Toolbox `json:"toolboxes,omitempty"` StartupCommand string `json:"startupCommand,omitempty"` } @@ -108,6 +109,13 @@ type Resource struct { ConnectionName string `json:"connectionName"` } +// Toolbox represents a reusable collection of tools deployed as a Foundry Toolset +type Toolbox struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Tools []map[string]any `json:"tools"` +} + // UnmarshalStruct converts a structpb.Struct to a Go struct of type T func UnmarshalStruct[T any](s *structpb.Struct, out *T) error { structBytes, err := protojson.Marshal(s) diff --git a/cli/azd/extensions/azure.ai.agents/internal/project/config_test.go b/cli/azd/extensions/azure.ai.agents/internal/project/config_test.go new file mode 100644 index 00000000000..e6012afe8df --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/project/config_test.go @@ -0,0 +1,205 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package project + +import ( + "testing" +) + +// TestServiceTargetAgentConfig_WithToolboxes tests MarshalStruct/UnmarshalStruct round-trip +// with the Toolboxes field populated. +func TestServiceTargetAgentConfig_WithToolboxes(t *testing.T) { + original := ServiceTargetAgentConfig{ + Toolboxes: []Toolbox{ + { + Name: "echo-toolbox", + Description: "A sample toolbox", + Tools: []map[string]any{ + { + "type": "mcp", + "server_label": "github", + "server_url": "https://api.example.com/mcp", + "project_connection_id": "TestKey", + }, + }, + }, + }, + } + + s, err := MarshalStruct(&original) + if err != nil { + t.Fatalf("MarshalStruct failed: %v", err) + } + + var roundTripped ServiceTargetAgentConfig + if err := UnmarshalStruct(s, &roundTripped); err != nil { + t.Fatalf("UnmarshalStruct failed: %v", err) + } + + if len(roundTripped.Toolboxes) != 1 { + t.Fatalf("Expected 1 toolbox, got %d", len(roundTripped.Toolboxes)) + } + + tb := roundTripped.Toolboxes[0] + if tb.Name != "echo-toolbox" { + t.Errorf("Expected toolbox name 'echo-toolbox', got '%s'", tb.Name) + } + if tb.Description != "A sample toolbox" { + t.Errorf("Expected description 'A sample toolbox', got '%s'", tb.Description) + } + if len(tb.Tools) != 1 { + t.Fatalf("Expected 1 tool in toolbox, got %d", len(tb.Tools)) + } + + tool := tb.Tools[0] + if tool["type"] != "mcp" { + t.Errorf("Expected tool type 'mcp', got '%v'", tool["type"]) + } + if tool["server_label"] != "github" { + t.Errorf("Expected server_label 'github', got '%v'", tool["server_label"]) + } +} + +// TestServiceTargetAgentConfig_EmptyToolboxes tests that an empty Toolboxes slice round-trips correctly. +func TestServiceTargetAgentConfig_EmptyToolboxes(t *testing.T) { + original := ServiceTargetAgentConfig{ + Toolboxes: []Toolbox{}, + } + + s, err := MarshalStruct(&original) + if err != nil { + t.Fatalf("MarshalStruct failed: %v", err) + } + + var roundTripped ServiceTargetAgentConfig + if err := UnmarshalStruct(s, &roundTripped); err != nil { + t.Fatalf("UnmarshalStruct failed: %v", err) + } + + if len(roundTripped.Toolboxes) != 0 { + t.Errorf("Expected 0 toolboxes, got %d", len(roundTripped.Toolboxes)) + } +} + +// TestServiceTargetAgentConfig_MultipleToolboxes tests round-tripping multiple toolboxes. +func TestServiceTargetAgentConfig_MultipleToolboxes(t *testing.T) { + original := ServiceTargetAgentConfig{ + Toolboxes: []Toolbox{ + { + Name: "toolbox-one", + Description: "First toolbox", + Tools: []map[string]any{ + { + "type": "mcp", + "server_label": "server-a", + "project_connection_id": "KeyA", + }, + }, + }, + { + Name: "toolbox-two", + Description: "Second toolbox", + Tools: []map[string]any{ + { + "type": "mcp", + "server_label": "server-b", + "project_connection_id": "KeyB", + }, + { + "type": "mcp", + "server_label": "server-c", + "project_connection_id": "KeyC", + }, + }, + }, + }, + } + + s, err := MarshalStruct(&original) + if err != nil { + t.Fatalf("MarshalStruct failed: %v", err) + } + + var roundTripped ServiceTargetAgentConfig + if err := UnmarshalStruct(s, &roundTripped); err != nil { + t.Fatalf("UnmarshalStruct failed: %v", err) + } + + if len(roundTripped.Toolboxes) != 2 { + t.Fatalf("Expected 2 toolboxes, got %d", len(roundTripped.Toolboxes)) + } + + if roundTripped.Toolboxes[0].Name != "toolbox-one" { + t.Errorf("Expected first toolbox name 'toolbox-one', got '%s'", roundTripped.Toolboxes[0].Name) + } + + if roundTripped.Toolboxes[1].Name != "toolbox-two" { + t.Errorf("Expected second toolbox name 'toolbox-two', got '%s'", roundTripped.Toolboxes[1].Name) + } + + if len(roundTripped.Toolboxes[1].Tools) != 2 { + t.Errorf("Expected 2 tools in second toolbox, got %d", len(roundTripped.Toolboxes[1].Tools)) + } +} + +// TestServiceTargetAgentConfig_WithOtherFields tests that Toolboxes coexists correctly +// alongside other ServiceTargetAgentConfig fields. +func TestServiceTargetAgentConfig_WithOtherFields(t *testing.T) { + original := ServiceTargetAgentConfig{ + Environment: map[string]string{"KEY": "VALUE"}, + Deployments: []Deployment{ + { + Name: "test-deployment", + Model: DeploymentModel{ + Name: "gpt-4.1-mini", + Format: "OpenAI", + Version: "2025-04-14", + }, + Sku: DeploymentSku{ + Name: "Standard", + Capacity: 10, + }, + }, + }, + Toolboxes: []Toolbox{ + { + Name: "my-toolbox", + Description: "Coexisting toolbox", + Tools: []map[string]any{ + { + "type": "mcp", + "server_label": "test-server", + "project_connection_id": "TestConn", + }, + }, + }, + }, + } + + s, err := MarshalStruct(&original) + if err != nil { + t.Fatalf("MarshalStruct failed: %v", err) + } + + var roundTripped ServiceTargetAgentConfig + if err := UnmarshalStruct(s, &roundTripped); err != nil { + t.Fatalf("UnmarshalStruct failed: %v", err) + } + + if roundTripped.Environment["KEY"] != "VALUE" { + t.Errorf("Expected env KEY=VALUE, got '%s'", roundTripped.Environment["KEY"]) + } + + if len(roundTripped.Deployments) != 1 { + t.Fatalf("Expected 1 deployment, got %d", len(roundTripped.Deployments)) + } + + if len(roundTripped.Toolboxes) != 1 { + t.Fatalf("Expected 1 toolbox, got %d", len(roundTripped.Toolboxes)) + } + + if roundTripped.Toolboxes[0].Name != "my-toolbox" { + t.Errorf("Expected toolbox name 'my-toolbox', got '%s'", roundTripped.Toolboxes[0].Name) + } +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/project/service_target_agent.go b/cli/azd/extensions/azure.ai.agents/internal/project/service_target_agent.go index 367e7fd003d..2e9adadb189 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/project/service_target_agent.go +++ b/cli/azd/extensions/azure.ai.agents/internal/project/service_target_agent.go @@ -6,8 +6,10 @@ package project import ( "context" "encoding/base64" + "errors" "fmt" "math" + "net/http" "os" "path/filepath" "strconv" @@ -19,6 +21,7 @@ import ( "azureaiagent/internal/pkg/agents/agent_yaml" "azureaiagent/internal/pkg/azure" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/cognitiveservices/armcognitiveservices/v2" @@ -433,6 +436,11 @@ func (p *AgentServiceTargetProvider) Deploy( fmt.Println("Loaded custom service target configuration") } + // Deploy toolboxes before agent creation + if err := p.deployToolboxes(ctx, serviceTargetConfig, azdEnv); err != nil { + return nil, err + } + // Load and validate the agent manifest data, err := os.ReadFile(p.agentDefinitionPath) if err != nil { @@ -1075,6 +1083,110 @@ func (p *AgentServiceTargetProvider) resolveEnvironmentVariables(value string, a return resolved } +// deployToolboxes creates or updates Foundry Toolsets for each toolbox in the config. +// For each toolbox, it calls the Foundry Toolsets API to upsert the toolset, then +// sets environment variables with the MCP endpoints from the toolset's MCP tools. +func (p *AgentServiceTargetProvider) deployToolboxes( + ctx context.Context, + serviceTargetConfig *ServiceTargetAgentConfig, + azdEnv map[string]string, +) error { + if serviceTargetConfig == nil || len(serviceTargetConfig.Toolboxes) == 0 { + return nil + } + + projectEndpoint := azdEnv["AZURE_AI_PROJECT_ENDPOINT"] + if projectEndpoint == "" { + return exterrors.Dependency( + exterrors.CodeMissingAiProjectEndpoint, + "AZURE_AI_PROJECT_ENDPOINT is required for toolbox deployment", + "run 'azd provision' or connect to an existing project", + ) + } + + toolsetsClient := azure.NewFoundryToolsetsClient(projectEndpoint, p.credential) + + for _, toolbox := range serviceTargetConfig.Toolboxes { + fmt.Fprintf(os.Stderr, "Deploying toolbox: %s\n", toolbox.Name) + + _, err := p.upsertToolset(ctx, toolsetsClient, toolbox) + if err != nil { + return err + } + + // Set the MCP endpoint env var now that the toolbox is confirmed to exist + if err := p.registerToolboxEnvironmentVariables( + ctx, projectEndpoint, toolbox.Name, + ); err != nil { + return err + } + + fmt.Fprintf(os.Stderr, "Toolbox '%s' deployed successfully\n", toolbox.Name) + } + + return nil +} + +// upsertToolset creates a toolset, or updates it if it already exists. +// A 409 Conflict on create means the toolset already exists, which is treated as success. +func (p *AgentServiceTargetProvider) upsertToolset( + ctx context.Context, + client *azure.FoundryToolsetsClient, + toolbox Toolbox, +) (*azure.ToolsetObject, error) { + createReq := &azure.CreateToolsetRequest{ + Name: toolbox.Name, + Description: toolbox.Description, + Tools: toolbox.Tools, + } + + toolset, err := client.CreateToolset(ctx, createReq) + if err == nil { + return toolset, nil + } + + // 409 Conflict means the toolset already exists — treat as success + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusConflict { + fmt.Fprintf(os.Stderr, " Toolset '%s' already exists, skipping update\n", toolbox.Name) + return nil, nil + } + + return nil, exterrors.Internal( + exterrors.CodeCreateToolsetFailed, + fmt.Sprintf("failed to create toolset '%s': %s", toolbox.Name, err), + ) +} + +// registerToolboxEnvironmentVariables sets the FOUNDRY_TOOLBOX_{NAME}_MCP_ENDPOINT env var +// with the constructed MCP endpoint URL: {projectEndpoint}/toolsets/{toolboxName}/mcp +func (p *AgentServiceTargetProvider) registerToolboxEnvironmentVariables( + ctx context.Context, + projectEndpoint string, + toolboxName string, +) error { + toolboxKey := p.getServiceKey(toolboxName) + envKey := fmt.Sprintf("FOUNDRY_TOOLBOX_%s_MCP_ENDPOINT", toolboxKey) + + endpoint := strings.TrimRight(projectEndpoint, "/") + mcpEndpoint := fmt.Sprintf("%s/toolsets/%s/mcp", endpoint, toolboxName) + + _, err := p.azdClient.Environment().SetValue(ctx, &azdext.SetEnvRequest{ + EnvName: p.env.Name, + Key: envKey, + Value: mcpEndpoint, + }) + if err != nil { + return fmt.Errorf( + "failed to set environment variable %s: %w", + envKey, err, + ) + } + + fmt.Fprintf(os.Stderr, " Set %s=%s\n", envKey, mcpEndpoint) + return nil +} + // ensureFoundryProject ensures the Foundry project resource ID is parsed and stored. // Checks for AZURE_AI_PROJECT_ID environment variable. func (p *AgentServiceTargetProvider) ensureFoundryProject(ctx context.Context) error { diff --git a/cli/azd/extensions/azure.ai.agents/schemas/azure.ai.agent.json b/cli/azd/extensions/azure.ai.agents/schemas/azure.ai.agent.json index 22d77ce8341..722a3b49359 100644 --- a/cli/azd/extensions/azure.ai.agents/schemas/azure.ai.agent.json +++ b/cli/azd/extensions/azure.ai.agents/schemas/azure.ai.agent.json @@ -24,6 +24,11 @@ "description": "List of external resources for agent execution.", "items": { "$ref": "#/definitions/Resource" } }, + "toolboxes": { + "type": "array", + "description": "List of toolboxes (Foundry Toolsets) to deploy.", + "items": { "$ref": "#/definitions/Toolbox" } + }, "startupCommand": { "type": "string", "description": "Command to start the agent server (e.g., 'python main.py'). Used by 'azd ai agent run' for local development." @@ -119,6 +124,21 @@ }, "required": ["resource", "connectionName"], "additionalProperties": false + }, + "Toolbox": { + "type": "object", + "description": "A reusable collection of tools deployed as a Foundry Toolset.", + "properties": { + "name": { "type": "string", "description": "Name of the toolbox." }, + "description": { "type": "string", "description": "Description of the toolbox." }, + "tools": { + "type": "array", + "description": "List of tools in the toolbox. Each tool is an object with properties passed to the Foundry Toolsets API.", + "items": { "type": "object" } + } + }, + "required": ["name", "tools"], + "additionalProperties": false } } } \ No newline at end of file From ce540de2ae4b297ac5a5bc8286d5e90817011cb1 Mon Sep 17 00:00:00 2001 From: trangevi Date: Tue, 7 Apr 2026 17:26:11 -0700 Subject: [PATCH 02/14] Attempting to handle new yaml Signed-off-by: trangevi --- .../azure.ai.agents/internal/cmd/init.go | 66 ++- .../azure.ai.agents/internal/cmd/listen.go | 215 +++++++++ .../internal/exterrors/codes.go | 8 + .../internal/pkg/agents/agent_yaml/parse.go | 18 + .../pkg/agents/agent_yaml/parse_test.go | 416 ++++++++++++++++-- .../internal/pkg/agents/agent_yaml/yaml.go | 185 +++++++- .../pkg/agents/agent_yaml/yaml_test.go | 232 ++++++++++ .../internal/project/config.go | 18 + .../internal/project/service_target_agent.go | 5 - .../schemas/azure.ai.agent.json | 41 ++ 10 files changed, 1142 insertions(+), 62 deletions(-) diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/init.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/init.go index 0334c657103..5f3c72ebf7f 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/init.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/init.go @@ -1170,6 +1170,13 @@ func (a *InitAction) addToProject(ctx context.Context, targetDir string, agentMa } agentConfig.Toolboxes = toolboxes + // Process connection resources from the manifest + connections, err := extractConnectionConfigs(agentManifest) + if err != nil { + return err + } + agentConfig.Connections = connections + // Detect startup command from the project source directory startupCmd, err := resolveStartupCommandForInit(ctx, a.azdClient, a.projectConfig.Path, targetDir, a.flags.NoPrompt) if err != nil { @@ -1576,7 +1583,6 @@ func downloadDirectoryContentsWithoutGhCli( // extractToolboxConfigs extracts toolbox resource definitions from the agent manifest // and converts them into project.Toolbox config entries. -// Each toolbox resource's options must contain a "tools" array with tool definitions. func extractToolboxConfigs(manifest *agent_yaml.AgentManifest) ([]project.Toolbox, error) { if manifest == nil || manifest.Resources == nil { return nil, nil @@ -1590,26 +1596,15 @@ func extractToolboxConfigs(manifest *agent_yaml.AgentManifest) ([]project.Toolbo continue } - description, _ := tbResource.Options["description"].(string) - - rawTools, ok := tbResource.Options["tools"] - if !ok { - return nil, fmt.Errorf( - "toolbox resource '%s' is missing required 'tools' in options", - tbResource.Name, - ) - } - - toolsList, ok := rawTools.([]any) - if !ok { + if len(tbResource.Tools) == 0 { return nil, fmt.Errorf( - "toolbox resource '%s' has invalid 'tools' format: expected array", + "toolbox resource '%s' is missing required 'tools'", tbResource.Name, ) } - tools := make([]map[string]any, 0, len(toolsList)) - for _, rawTool := range toolsList { + tools := make([]map[string]any, 0, len(tbResource.Tools)) + for _, rawTool := range tbResource.Tools { toolMap, ok := rawTool.(map[string]any) if !ok { return nil, fmt.Errorf( @@ -1622,10 +1617,47 @@ func extractToolboxConfigs(manifest *agent_yaml.AgentManifest) ([]project.Toolbo toolboxes = append(toolboxes, project.Toolbox{ Name: tbResource.Name, - Description: description, + Description: tbResource.Description, Tools: tools, }) } return toolboxes, nil } + +// extractConnectionConfigs extracts connection resource definitions from the agent manifest +// and converts them into project.Connection config entries. +func extractConnectionConfigs(manifest *agent_yaml.AgentManifest) ([]project.Connection, error) { + if manifest == nil || manifest.Resources == nil { + return nil, nil + } + + var connections []project.Connection + + for _, resource := range manifest.Resources { + connResource, ok := resource.(agent_yaml.ConnectionResource) + if !ok { + continue + } + + conn := project.Connection{ + Name: connResource.Name, + Category: string(connResource.Category), + Target: connResource.Target, + AuthType: string(connResource.AuthType), + Credentials: connResource.Credentials, + Metadata: connResource.Metadata, + ExpiryTime: connResource.ExpiryTime, + IsSharedToAll: connResource.IsSharedToAll, + SharedUserList: connResource.SharedUserList, + PeRequirement: connResource.PeRequirement, + PeStatus: connResource.PeStatus, + UseWorkspaceManagedIdentity: connResource.UseWorkspaceManagedIdentity, + Error: connResource.Error, + } + + connections = append(connections, conn) + } + + return connections, nil +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go index a4f8b5ecdee..e2d3b473b3d 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go @@ -6,14 +6,20 @@ package cmd import ( "context" "encoding/json" + "errors" "fmt" + "net/http" "os" "path/filepath" "strings" + "azureaiagent/internal/exterrors" "azureaiagent/internal/pkg/agents/agent_yaml" + "azureaiagent/internal/pkg/azure" "azureaiagent/internal/project" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/azure/azure-dev/cli/azd/pkg/azdext" "github.com/braydonk/yaml" "github.com/spf13/cobra" @@ -47,6 +53,9 @@ func newListenCommand() *cobra.Command { WithProjectEventHandler("preprovision", func(ctx context.Context, args *azdext.ProjectEventArgs) error { return preprovisionHandler(ctx, azdClient, projectParser, args) }). + WithProjectEventHandler("postprovision", func(ctx context.Context, args *azdext.ProjectEventArgs) error { + return postprovisionHandler(ctx, azdClient, args) + }). WithProjectEventHandler("predeploy", func(ctx context.Context, args *azdext.ProjectEventArgs) error { return predeployHandler(ctx, azdClient, projectParser, args) }). @@ -89,6 +98,27 @@ func preprovisionHandler(ctx context.Context, azdClient *azdext.AzdClient, proje return nil } +func postprovisionHandler( + ctx context.Context, + azdClient *azdext.AzdClient, + args *azdext.ProjectEventArgs, +) error { + for _, svc := range args.Project.Services { + if svc.Host != AiAgentHost { + continue + } + + if err := provisionToolboxes(ctx, azdClient, svc); err != nil { + return fmt.Errorf( + "failed to provision toolboxes for service %q: %w", + svc.Name, err, + ) + } + } + + return nil +} + func predeployHandler(ctx context.Context, azdClient *azdext.AzdClient, projectParser *project.FoundryParser, args *azdext.ProjectEventArgs) error { if err := projectParser.SetIdentity(ctx, args); err != nil { return fmt.Errorf("failed to set identity: %w", err) @@ -152,6 +182,15 @@ func envUpdate(ctx context.Context, azdClient *azdext.AzdClient, azdProject *azd } } + if len(foundryAgentConfig.Connections) > 0 { + if err := connectionsEnvUpdate( + ctx, foundryAgentConfig.Connections, + azdClient, currentEnvResponse.Environment.Name, + ); err != nil { + return err + } + } + return nil } @@ -219,6 +258,25 @@ func resourcesEnvUpdate(ctx context.Context, resources []project.Resource, azdCl return setEnvVar(ctx, azdClient, envName, "AI_PROJECT_DEPENDENT_RESOURCES", escapedJsonString) } +func connectionsEnvUpdate( + ctx context.Context, + connections []project.Connection, + azdClient *azdext.AzdClient, + envName string, +) error { + connectionsJson, err := json.Marshal(connections) + if err != nil { + return fmt.Errorf("failed to marshal connection details to JSON: %w", err) + } + + // Escape backslashes and double quotes for environment variable + jsonString := string(connectionsJson) + escapedJsonString := strings.ReplaceAll(jsonString, "\\", "\\\\") + escapedJsonString = strings.ReplaceAll(escapedJsonString, "\"", "\\\"") + + return setEnvVar(ctx, azdClient, envName, "AI_PROJECT_CONNECTIONS", escapedJsonString) +} + func containerAgentHandling(ctx context.Context, azdClient *azdext.AzdClient, project *azdext.ProjectConfig, svc *azdext.ServiceConfig) error { servicePath := svc.RelativePath fullPath := filepath.Join(project.Path, servicePath) @@ -335,3 +393,160 @@ func populateContainerSettings(ctx context.Context, azdClient *azdext.AzdClient, return nil } + +// provisionToolboxes creates or updates Foundry Toolsets for each toolbox +// in the service config. Called during post-provision after the project +// endpoint has been created by Bicep. +func provisionToolboxes( + ctx context.Context, + azdClient *azdext.AzdClient, + svc *azdext.ServiceConfig, +) error { + var config *project.ServiceTargetAgentConfig + if err := project.UnmarshalStruct(svc.Config, &config); err != nil { + return fmt.Errorf("failed to parse service config: %w", err) + } + + if config == nil || len(config.Toolboxes) == 0 { + return nil + } + + currentEnv, err := azdClient.Environment().GetCurrent( + ctx, &azdext.EmptyRequest{}, + ) + if err != nil { + return fmt.Errorf("failed to get current environment: %w", err) + } + + envValue, err := azdClient.Environment().GetValue(ctx, &azdext.GetEnvRequest{ + EnvName: currentEnv.Environment.Name, + Key: "AZURE_AI_PROJECT_ENDPOINT", + }) + if err != nil || envValue.Value == "" { + return exterrors.Dependency( + exterrors.CodeMissingAiProjectEndpoint, + "AZURE_AI_PROJECT_ENDPOINT is required for toolbox provisioning", + "run 'azd provision' to create the AI project first", + ) + } + projectEndpoint := envValue.Value + + envValue, err = azdClient.Environment().GetValue(ctx, &azdext.GetEnvRequest{ + EnvName: currentEnv.Environment.Name, + Key: "AZURE_TENANT_ID", + }) + if err != nil || envValue.Value == "" { + return exterrors.Dependency( + exterrors.CodeMissingAzureTenantId, + "AZURE_TENANT_ID is required for toolbox provisioning", + "run 'azd auth login' to authenticate", + ) + } + tenantId := envValue.Value + + cred, err := azidentity.NewAzureDeveloperCLICredential( + &azidentity.AzureDeveloperCLICredentialOptions{ + TenantID: tenantId, + AdditionallyAllowedTenants: []string{"*"}, + }, + ) + if err != nil { + return exterrors.Auth( + exterrors.CodeCredentialCreationFailed, + fmt.Sprintf("failed to create credential: %s", err), + "run 'azd auth login' to authenticate", + ) + } + + toolsetsClient := azure.NewFoundryToolsetsClient( + projectEndpoint, cred, + ) + + for _, toolbox := range config.Toolboxes { + fmt.Fprintf( + os.Stderr, "Provisioning toolbox: %s\n", toolbox.Name, + ) + + if err := upsertToolset( + ctx, toolsetsClient, toolbox, + ); err != nil { + return err + } + + if err := registerToolboxEnvVars( + ctx, azdClient, + currentEnv.Environment.Name, + projectEndpoint, toolbox.Name, + ); err != nil { + return err + } + + fmt.Fprintf( + os.Stderr, "Toolbox '%s' provisioned\n", toolbox.Name, + ) + } + + return nil +} + +// upsertToolset creates a toolset, or skips if it already exists. +func upsertToolset( + ctx context.Context, + client *azure.FoundryToolsetsClient, + toolbox project.Toolbox, +) error { + createReq := &azure.CreateToolsetRequest{ + Name: toolbox.Name, + Description: toolbox.Description, + Tools: toolbox.Tools, + } + + _, err := client.CreateToolset(ctx, createReq) + if err == nil { + return nil + } + + // 409 Conflict means the toolset already exists + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusConflict { + fmt.Fprintf( + os.Stderr, + " Toolset '%s' already exists, skipping\n", + toolbox.Name, + ) + return nil + } + + return exterrors.Internal( + exterrors.CodeCreateToolsetFailed, + fmt.Sprintf( + "failed to create toolset '%s': %s", + toolbox.Name, err, + ), + ) +} + +// registerToolboxEnvVars sets TOOLBOX_{NAME}_MCP_ENDPOINT. +func registerToolboxEnvVars( + ctx context.Context, + azdClient *azdext.AzdClient, + envName string, + projectEndpoint string, + toolboxName string, +) error { + key := strings.ToUpper( + strings.ReplaceAll(toolboxName, "-", "_"), + ) + envKey := fmt.Sprintf( + "TOOLBOX_%s_MCP_ENDPOINT", key, + ) + + endpoint := strings.TrimRight(projectEndpoint, "/") + mcpEndpoint := fmt.Sprintf( + "%s/toolsets/%s/mcp", endpoint, toolboxName, + ) + + return setEnvVar( + ctx, azdClient, envName, envKey, mcpEndpoint, + ) +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/exterrors/codes.go b/cli/azd/extensions/azure.ai.agents/internal/exterrors/codes.go index 1abeff33a38..1a594229172 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/exterrors/codes.go +++ b/cli/azd/extensions/azure.ai.agents/internal/exterrors/codes.go @@ -43,6 +43,7 @@ const ( CodeEnvironmentCreationFailed = "environment_creation_failed" CodeEnvironmentValuesFailed = "environment_values_failed" CodeMissingAiProjectEndpoint = "missing_ai_project_endpoint" + CodeMissingAzureTenantId = "missing_azure_tenant_id" CodeMissingAiProjectId = "missing_ai_project_id" CodeMissingAzureSubscription = "missing_azure_subscription_id" CodeMissingAgentEnvVars = "missing_agent_env_vars" @@ -93,6 +94,13 @@ const ( CodeUpdateToolsetFailed = "update_toolset_failed" ) +// Error codes for connection operations. +const ( + CodeInvalidConnection = "invalid_connection" + CodeConnectionCreationFail = "connection_creation_failed" + CodeMissingConnectionField = "missing_connection_field" +) + // Error codes commonly used for internal errors. // // These are usually paired with [Internal] for unexpected failures diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse.go index 318bed100d8..3048bed1f6e 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse.go +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse.go @@ -176,6 +176,12 @@ func ExtractResourceDefinitions(manifestYamlContent []byte) ([]any, error) { return nil, fmt.Errorf("failed to unmarshal to ToolboxResource: %w", err) } resourceDefs = append(resourceDefs, toolboxDef) + case ResourceKindConnection: + var connDef ConnectionResource + if err := yaml.Unmarshal(resourceBytes, &connDef); err != nil { + return nil, fmt.Errorf("failed to unmarshal to ConnectionResource: %w", err) + } + resourceDefs = append(resourceDefs, connDef) default: return nil, fmt.Errorf("unrecognized resource kind: %s", resourceDef.Kind) } @@ -302,6 +308,18 @@ func ExtractToolsDefinitions(template map[string]any) ([]any, error) { return nil, fmt.Errorf("failed to unmarshal to CodeInterpreterTool: %w", err) } tools = append(tools, codeInterpreterTool) + case ToolKindAzureAiSearch: + var azureAiSearchTool AzureAISearchTool + if err := yaml.Unmarshal(toolBytes, &azureAiSearchTool); err != nil { + return nil, fmt.Errorf("failed to unmarshal to AzureAISearchTool: %w", err) + } + tools = append(tools, azureAiSearchTool) + case ToolKindA2APreview: + var a2aTool A2APreviewTool + if err := yaml.Unmarshal(toolBytes, &a2aTool); err != nil { + return nil, fmt.Errorf("failed to unmarshal to A2APreviewTool: %w", err) + } + tools = append(tools, a2aTool) default: return nil, fmt.Errorf("unrecognized tool kind: %s", toolDef.Kind) } diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse_test.go index 0240481903e..c0c4a97fe52 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse_test.go +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse_test.go @@ -231,14 +231,12 @@ template: resources: - kind: toolbox name: echo-toolbox - id: echo-toolbox - options: - description: A sample toolbox - tools: - - type: mcp - server_label: github - server_url: https://api.example.com/mcp - project_connection_id: TestKey + description: A sample toolbox + tools: + - type: mcp + server_label: github + server_url: https://api.example.com/mcp + project_connection_id: TestKey `) resources, err := ExtractResourceDefinitions(yamlContent) @@ -263,28 +261,12 @@ resources: t.Errorf("Expected kind 'toolbox', got '%s'", toolboxRes.Kind) } - if toolboxRes.Id != "echo-toolbox" { - t.Errorf("Expected id 'echo-toolbox', got '%s'", toolboxRes.Id) + if toolboxRes.Description != "A sample toolbox" { + t.Errorf("Expected description 'A sample toolbox', got '%s'", toolboxRes.Description) } - desc, ok := toolboxRes.Options["description"] - if !ok { - t.Fatal("Expected 'description' in options") - } - if desc != "A sample toolbox" { - t.Errorf("Expected description 'A sample toolbox', got '%v'", desc) - } - - toolsVal, ok := toolboxRes.Options["tools"] - if !ok { - t.Fatal("Expected 'tools' in options") - } - toolsSlice, ok := toolsVal.([]any) - if !ok { - t.Fatalf("Expected tools to be a slice, got %T", toolsVal) - } - if len(toolsSlice) != 1 { - t.Fatalf("Expected 1 tool in options, got %d", len(toolsSlice)) + if len(toolboxRes.Tools) != 1 { + t.Fatalf("Expected 1 tool, got %d", len(toolboxRes.Tools)) } } @@ -303,9 +285,9 @@ resources: id: gpt-4.1-mini - kind: toolbox name: my-toolbox - id: my-toolbox - options: - description: My toolbox + description: My toolbox + tools: + - type: web_search `) resources, err := ExtractResourceDefinitions(yamlContent) @@ -330,7 +312,375 @@ resources: t.Errorf("Expected name 'my-toolbox', got '%s'", toolboxRes.Name) } - if toolboxRes.Id != "my-toolbox" { - t.Errorf("Expected id 'my-toolbox', got '%s'", toolboxRes.Id) + if toolboxRes.Description != "My toolbox" { + t.Errorf("Expected description 'My toolbox', got '%s'", toolboxRes.Description) + } +} + +// TestExtractResourceDefinitions_ConnectionResource tests parsing connection resources +func TestExtractResourceDefinitions_ConnectionResource(t *testing.T) { + yamlContent := []byte(` +name: test-manifest +template: + kind: prompt + name: test-agent + model: + id: gpt-4.1-mini +resources: + - kind: connection + name: context7 + category: CustomKeys + target: https://mcp.context7.com/mcp + authType: CustomKeys + credentials: + key: my-api-key + metadata: + source: context7 +`) + + resources, err := ExtractResourceDefinitions(yamlContent) + if err != nil { + t.Fatalf("ExtractResourceDefinitions failed: %v", err) + } + + if len(resources) != 1 { + t.Fatalf("Expected 1 resource, got %d", len(resources)) + } + + connRes, ok := resources[0].(ConnectionResource) + if !ok { + t.Fatalf("Expected ConnectionResource, got %T", resources[0]) + } + + if connRes.Name != "context7" { + t.Errorf("Expected name 'context7', got '%s'", connRes.Name) + } + if connRes.Kind != ResourceKindConnection { + t.Errorf("Expected kind 'connection', got '%s'", connRes.Kind) + } + if connRes.Category != CategoryCustomKeys { + t.Errorf("Expected category 'CustomKeys', got '%s'", connRes.Category) + } + if connRes.Target != "https://mcp.context7.com/mcp" { + t.Errorf("Expected target 'https://mcp.context7.com/mcp', got '%s'", connRes.Target) + } + if connRes.AuthType != AuthTypeCustomKeys { + t.Errorf("Expected authType 'CustomKeys', got '%s'", connRes.AuthType) + } + if connRes.Credentials["key"] != "my-api-key" { + t.Errorf("Expected credentials.key 'my-api-key', got '%v'", connRes.Credentials["key"]) + } + if connRes.Metadata["source"] != "context7" { + t.Errorf("Expected metadata.source 'context7', got '%s'", connRes.Metadata["source"]) + } +} + +// TestExtractResourceDefinitions_ConnectionMissingRequired tests that missing required fields fail validation +func TestExtractResourceDefinitions_ConnectionMissingRequired(t *testing.T) { + testCases := []struct { + name string + yaml string + errMatch string + }{ + { + name: "missing category", + yaml: ` +name: test-manifest +template: + kind: prompt + name: test-agent + model: + id: gpt-4.1-mini +resources: + - kind: connection + name: my-conn + target: https://example.com + authType: None +`, + errMatch: "category is required", + }, + { + name: "missing target", + yaml: ` +name: test-manifest +template: + kind: prompt + name: test-agent + model: + id: gpt-4.1-mini +resources: + - kind: connection + name: my-conn + category: CustomKeys + authType: None +`, + errMatch: "target is required", + }, + { + name: "missing authType", + yaml: ` +name: test-manifest +template: + kind: prompt + name: test-agent + model: + id: gpt-4.1-mini +resources: + - kind: connection + name: my-conn + category: CustomKeys + target: https://example.com +`, + errMatch: "authType is required", + }, + { + name: "invalid authType", + yaml: ` +name: test-manifest +template: + kind: prompt + name: test-agent + model: + id: gpt-4.1-mini +resources: + - kind: connection + name: my-conn + category: CustomKeys + target: https://example.com + authType: InvalidAuth +`, + errMatch: "authType must be one of", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, err := ExtractResourceDefinitions([]byte(tc.yaml)) + if err == nil { + t.Fatal("Expected validation error, got nil") + } + if !strings.Contains(err.Error(), tc.errMatch) { + t.Errorf("Expected error to contain '%s', got '%s'", tc.errMatch, err.Error()) + } + }) + } +} + +// TestExtractResourceDefinitions_AllResourceKinds tests model + toolbox + connection together +func TestExtractResourceDefinitions_AllResourceKinds(t *testing.T) { + yamlContent := []byte(` +name: test-manifest +template: + kind: prompt + name: test-agent + model: + id: gpt-4.1-mini +resources: + - kind: model + name: chat + id: gpt-4.1-mini + - kind: connection + name: context7 + category: CustomKeys + target: https://mcp.context7.com/mcp + authType: CustomKeys + credentials: + key: test-key + - kind: toolbox + name: agent-tools + description: MCP tools for documentation search + tools: + - id: web_search + - id: mcp + project_connection_id: context7 +`) + + resources, err := ExtractResourceDefinitions(yamlContent) + if err != nil { + t.Fatalf("ExtractResourceDefinitions failed: %v", err) + } + + if len(resources) != 3 { + t.Fatalf("Expected 3 resources, got %d", len(resources)) + } + + if _, ok := resources[0].(ModelResource); !ok { + t.Errorf("Expected first resource to be ModelResource, got %T", resources[0]) + } + if _, ok := resources[1].(ConnectionResource); !ok { + t.Errorf("Expected second resource to be ConnectionResource, got %T", resources[1]) + } + if _, ok := resources[2].(ToolboxResource); !ok { + t.Errorf("Expected third resource to be ToolboxResource, got %T", resources[2]) + } +} + +// TestExtractResourceDefinitions_ConnectionAllAuthTypes tests all supported auth types +func TestExtractResourceDefinitions_ConnectionAllAuthTypes(t *testing.T) { + authTypes := []AuthType{ + AuthTypeAAD, + AuthTypeApiKey, + AuthTypeCustomKeys, + AuthTypeNone, + AuthTypeOAuth2, + AuthTypePAT, + } + + for _, authType := range authTypes { + t.Run(string(authType), func(t *testing.T) { + yamlContent := []byte(` +name: test-manifest +template: + kind: prompt + name: test-agent + model: + id: gpt-4.1-mini +resources: + - kind: connection + name: test-conn + category: CustomKeys + target: https://example.com + authType: ` + string(authType) + ` +`) + resources, err := ExtractResourceDefinitions(yamlContent) + if err != nil { + t.Fatalf("Failed for authType %s: %v", authType, err) + } + + connRes := resources[0].(ConnectionResource) + if connRes.AuthType != authType { + t.Errorf("Expected authType '%s', got '%s'", authType, connRes.AuthType) + } + }) + } +} + +// TestExtractResourceDefinitions_ConnectionOptionalFields tests optional fields are preserved +func TestExtractResourceDefinitions_ConnectionOptionalFields(t *testing.T) { + yamlContent := []byte(` +name: test-manifest +template: + kind: prompt + name: test-agent + model: + id: gpt-4.1-mini +resources: + - kind: connection + name: full-conn + category: AzureOpenAI + target: https://myendpoint.openai.azure.com + authType: AAD + expiryTime: "2025-12-31T00:00:00Z" + isSharedToAll: true + sharedUserList: + - user1@contoso.com + - user2@contoso.com + metadata: + env: production + useWorkspaceManagedIdentity: false +`) + + resources, err := ExtractResourceDefinitions(yamlContent) + if err != nil { + t.Fatalf("ExtractResourceDefinitions failed: %v", err) + } + + connRes := resources[0].(ConnectionResource) + + if connRes.ExpiryTime != "2025-12-31T00:00:00Z" { + t.Errorf("Expected expiryTime, got '%s'", connRes.ExpiryTime) + } + if connRes.IsSharedToAll == nil || *connRes.IsSharedToAll != true { + t.Error("Expected isSharedToAll to be true") + } + if len(connRes.SharedUserList) != 2 { + t.Errorf("Expected 2 shared users, got %d", len(connRes.SharedUserList)) + } + if connRes.Metadata["env"] != "production" { + t.Errorf("Expected metadata.env 'production', got '%s'", connRes.Metadata["env"]) + } + if connRes.UseWorkspaceManagedIdentity == nil || *connRes.UseWorkspaceManagedIdentity != false { + t.Error("Expected useWorkspaceManagedIdentity to be false") + } +} + +// TestExtractToolsDefinitions_AzureAiSearch tests parsing an azureAiSearch tool +func TestExtractToolsDefinitions_AzureAiSearch(t *testing.T) { + template := map[string]any{ + "tools": []any{ + map[string]any{ + "name": "my-search", + "kind": "azureAiSearch", + "indexes": []any{ + map[string]any{ + "project_connection_id": "search-conn", + "index_name": "docs-index", + "query_type": "semantic", + "top_k": 10, + }, + }, + }, + }, + } + + tools, err := ExtractToolsDefinitions(template) + if err != nil { + t.Fatalf("ExtractToolsDefinitions failed: %v", err) + } + + if len(tools) != 1 { + t.Fatalf("Expected 1 tool, got %d", len(tools)) + } + + searchTool, ok := tools[0].(AzureAISearchTool) + if !ok { + t.Fatalf("Expected AzureAISearchTool, got %T", tools[0]) + } + if searchTool.Kind != ToolKindAzureAiSearch { + t.Errorf("Expected kind 'azureAiSearch', got '%s'", searchTool.Kind) + } + if len(searchTool.Indexes) != 1 { + t.Fatalf("Expected 1 index, got %d", len(searchTool.Indexes)) + } + if searchTool.Indexes[0].IndexName != "docs-index" { + t.Errorf("Expected index_name 'docs-index', got '%s'", searchTool.Indexes[0].IndexName) + } +} + +// TestExtractToolsDefinitions_A2APreview tests parsing an a2aPreview tool +func TestExtractToolsDefinitions_A2APreview(t *testing.T) { + template := map[string]any{ + "tools": []any{ + map[string]any{ + "name": "a2a-delegate", + "kind": "a2aPreview", + "baseUrl": "https://remote-agent.example.com", + "agentCardPath": "/.well-known/agent.json", + "projectConnectionId": "remote-conn", + }, + }, + } + + tools, err := ExtractToolsDefinitions(template) + if err != nil { + t.Fatalf("ExtractToolsDefinitions failed: %v", err) + } + + if len(tools) != 1 { + t.Fatalf("Expected 1 tool, got %d", len(tools)) + } + + a2aTool, ok := tools[0].(A2APreviewTool) + if !ok { + t.Fatalf("Expected A2APreviewTool, got %T", tools[0]) + } + if a2aTool.Kind != ToolKindA2APreview { + t.Errorf("Expected kind 'a2aPreview', got '%s'", a2aTool.Kind) + } + if a2aTool.BaseUrl != "https://remote-agent.example.com" { + t.Errorf("Expected baseUrl, got '%s'", a2aTool.BaseUrl) + } + if a2aTool.ProjectConnectionId != "remote-conn" { + t.Errorf("Expected projectConnectionId 'remote-conn', got '%s'", a2aTool.ProjectConnectionId) } } diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml.go index 5b15ef4647f..b98afe0f862 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml.go +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml.go @@ -3,7 +3,9 @@ package agent_yaml -import "slices" +import ( + "slices" +) // AgentKind represents the type of agent type AgentKind string @@ -31,9 +33,10 @@ func ValidAgentKinds() []AgentKind { type ResourceKind string const ( - ResourceKindModel ResourceKind = "model" - ResourceKindTool ResourceKind = "tool" - ResourceKindToolbox ResourceKind = "toolbox" + ResourceKindModel ResourceKind = "model" + ResourceKindTool ResourceKind = "tool" + ResourceKindToolbox ResourceKind = "toolbox" + ResourceKindConnection ResourceKind = "connection" ) type ToolKind string @@ -47,6 +50,100 @@ const ( ToolKindMcp ToolKind = "mcp" ToolKindOpenApi ToolKind = "openApi" ToolKindCodeInterpreter ToolKind = "codeInterpreter" + ToolKindAzureAiSearch ToolKind = "azureAiSearch" + ToolKindA2APreview ToolKind = "a2aPreview" +) + +// toolKindAPITypeMap maps camelCase ToolKind values to snake_case API type values. +var toolKindAPITypeMap = map[ToolKind]string{ + ToolKindFunction: "function", + ToolKindCustom: "custom", + ToolKindWebSearch: "web_search", + ToolKindBingGrounding: "bing_grounding", + ToolKindFileSearch: "file_search", + ToolKindMcp: "mcp", + ToolKindOpenApi: "openapi", + ToolKindCodeInterpreter: "code_interpreter", + ToolKindAzureAiSearch: "azure_ai_search", + ToolKindA2APreview: "a2a_preview", +} + +// ToolKindToAPIType converts a camelCase ToolKind to its snake_case API type. +// Returns the input as-is if no mapping is found. +func ToolKindToAPIType(kind ToolKind) string { + if apiType, ok := toolKindAPITypeMap[kind]; ok { + return apiType + } + return string(kind) +} + +// APITypeToToolKind converts a snake_case API type to its camelCase ToolKind. +// Returns the input as a ToolKind if no mapping is found. +func APITypeToToolKind(apiType string) ToolKind { + for kind, mapped := range toolKindAPITypeMap { + if mapped == apiType { + return kind + } + } + return ToolKind(apiType) +} + +// AuthType represents the authentication type for a connection. +type AuthType string + +const ( + AuthTypeAAD AuthType = "AAD" + AuthTypeApiKey AuthType = "ApiKey" + AuthTypeCustomKeys AuthType = "CustomKeys" + AuthTypeNone AuthType = "None" + AuthTypeOAuth2 AuthType = "OAuth2" + AuthTypePAT AuthType = "PAT" +) + +// ValidAuthTypes returns a slice of all supported AuthType values. +func ValidAuthTypes() []AuthType { + return []AuthType{ + AuthTypeAAD, + AuthTypeApiKey, + AuthTypeCustomKeys, + AuthTypeNone, + AuthTypeOAuth2, + AuthTypePAT, + } +} + +// IsValidAuthType checks if the provided AuthType is valid. +func IsValidAuthType(authType AuthType) bool { + return slices.Contains(ValidAuthTypes(), authType) +} + +// CategoryKind represents the category of a connection resource. +type CategoryKind string + +const ( + CategoryAzureOpenAI CategoryKind = "AzureOpenAI" + CategoryCognitiveSearch CategoryKind = "CognitiveSearch" + CategoryCognitiveService CategoryKind = "CognitiveService" + CategoryCustomKeys CategoryKind = "CustomKeys" + CategoryServerlessEndpoint CategoryKind = "Serverless" + CategoryContainerRegistry CategoryKind = "ContainerRegistry" + CategoryApiKey CategoryKind = "ApiKey" + CategoryAzureBlob CategoryKind = "AzureBlob" + CategoryGit CategoryKind = "Git" + CategoryRedis CategoryKind = "Redis" + CategoryS3 CategoryKind = "S3" + CategorySnowflake CategoryKind = "Snowflake" + CategoryAzureSqlDb CategoryKind = "AzureSqlDb" + CategoryAzureSynapseAnalytics CategoryKind = "AzureSynapseAnalytics" + CategoryAzureMySqlDb CategoryKind = "AzureMySqlDb" + CategoryAzurePostgresDb CategoryKind = "AzurePostgresDb" + CategoryADLSGen2 CategoryKind = "ADLSGen2" + CategoryAzureDataExplorer CategoryKind = "AzureDataExplorer" + CategoryBingLLMSearch CategoryKind = "BingLLMSearch" + CategoryMicrosoftOneLake CategoryKind = "MicrosoftOneLake" + CategoryElasticSearch CategoryKind = "Elasticsearch" + CategoryPinecone CategoryKind = "Pinecone" + CategoryQdrant CategoryKind = "Qdrant" ) type ConnectionKind string @@ -309,9 +406,29 @@ type ToolResource struct { // ToolboxResource Represents a toolbox resource required by the agent. // A toolbox is a reusable collection of tools that can be deployed as a Foundry Toolset. type ToolboxResource struct { - Resource `json:",inline" yaml:",inline"` - Id string `json:"id" yaml:"id"` - Options map[string]any `json:"options" yaml:"options"` + Resource `json:",inline" yaml:",inline"` + Description string `json:"description,omitempty" yaml:"description,omitempty"` + Tools []any `json:"tools" yaml:"tools"` +} + +// ConnectionResource Represents a connection resource required by the agent. +// Maps to the Bicep ConnectionPropertiesV2 spec for creating project connections. +type ConnectionResource struct { + Resource `json:",inline" yaml:",inline"` + Category CategoryKind `json:"category" yaml:"category"` + Target string `json:"target" yaml:"target"` + AuthType AuthType `json:"authType" yaml:"authType"` + Credentials map[string]any `json:"credentials,omitempty" yaml:"credentials,omitempty"` + Metadata map[string]string `json:"metadata,omitempty" yaml:"metadata,omitempty"` + ExpiryTime string `json:"expiryTime,omitempty" yaml:"expiryTime,omitempty"` + IsSharedToAll *bool `json:"isSharedToAll,omitempty" yaml:"isSharedToAll,omitempty"` + SharedUserList []string `json:"sharedUserList,omitempty" yaml:"sharedUserList,omitempty"` + PeRequirement string `json:"peRequirement,omitempty" yaml:"peRequirement,omitempty"` + PeStatus string `json:"peStatus,omitempty" yaml:"peStatus,omitempty"` + Error string `json:"error,omitempty" yaml:"error,omitempty"` + + // UseWorkspaceManagedIdentity indicates whether to use workspace managed identity. + UseWorkspaceManagedIdentity *bool `json:"useWorkspaceManagedIdentity,omitempty" yaml:"useWorkspaceManagedIdentity,omitempty"` //nolint:lll } // Template Template model for defining prompt templates. @@ -405,3 +522,57 @@ type CodeInterpreterTool struct { FileIds []string `json:"fileIds" yaml:"fileIds"` Options map[string]any `json:"options" yaml:"options"` } + +// AzureAISearchIndex represents a single index configuration within an AzureAISearchTool. +type AzureAISearchIndex struct { + ProjectConnectionId string `json:"project_connection_id" yaml:"project_connection_id"` + IndexName string `json:"index_name" yaml:"index_name"` + QueryType *string `json:"query_type,omitempty" yaml:"query_type,omitempty"` + TopK *int `json:"top_k,omitempty" yaml:"top_k,omitempty"` + Filter *string `json:"filter,omitempty" yaml:"filter,omitempty"` +} + +// AzureAISearchTool The Azure AI Search tool for grounding agent responses with search index data. +type AzureAISearchTool struct { + Tool `json:",inline" yaml:",inline"` + Indexes []AzureAISearchIndex `json:"indexes" yaml:"indexes"` +} + +// A2APreviewTool The A2A (Agent-to-Agent) preview tool for delegating tasks to other agents. +type A2APreviewTool struct { + Tool `json:",inline" yaml:",inline"` + BaseUrl string `json:"baseUrl" yaml:"baseUrl"` + AgentCardPath *string `json:"agentCardPath,omitempty" yaml:"agentCardPath,omitempty"` + ProjectConnectionId string `json:"projectConnectionId" yaml:"projectConnectionId"` +} + +// Credential type structs for typed access to connection credentials. +// The ConnectionResource.Credentials field is map[string]any for flexibility, +// but these structs can be used when code needs structured access. + +// ApiKeyCredentials holds credentials for ApiKey auth type. +type ApiKeyCredentials struct { + Key string `json:"key" yaml:"key"` +} + +// CustomKeysCredentials holds credentials for CustomKeys auth type. +type CustomKeysCredentials struct { + Keys map[string]string `json:"keys" yaml:"keys"` +} + +// OAuth2Credentials holds credentials for OAuth2 auth type. +type OAuth2Credentials struct { + AuthUrl string `json:"authUrl,omitempty" yaml:"authUrl,omitempty"` + ClientId string `json:"clientId" yaml:"clientId"` + ClientSecret string `json:"clientSecret,omitempty" yaml:"clientSecret,omitempty"` + DeveloperToken string `json:"developerToken,omitempty" yaml:"developerToken,omitempty"` + Password string `json:"password,omitempty" yaml:"password,omitempty"` + RefreshToken string `json:"refreshToken,omitempty" yaml:"refreshToken,omitempty"` + TenantId string `json:"tenantId,omitempty" yaml:"tenantId,omitempty"` + Username string `json:"username,omitempty" yaml:"username,omitempty"` +} + +// PATCredentials holds credentials for PAT (Personal Access Token) auth type. +type PATCredentials struct { + Pat string `json:"pat" yaml:"pat"` +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml_test.go index fba8c971244..25ac5cd6f9d 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml_test.go +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml_test.go @@ -23,3 +23,235 @@ func TestArrayProperty_BasicSerialization(t *testing.T) { t.Fatalf("Failed to unmarshal ArrayProperty: %v", err) } } + +// TestToolKindToAPIType tests camelCase to snake_case tool kind conversion +func TestToolKindToAPIType(t *testing.T) { + tests := []struct { + input ToolKind + expected string + }{ + {ToolKindFunction, "function"}, + {ToolKindCustom, "custom"}, + {ToolKindWebSearch, "web_search"}, + {ToolKindBingGrounding, "bing_grounding"}, + {ToolKindFileSearch, "file_search"}, + {ToolKindMcp, "mcp"}, + {ToolKindOpenApi, "openapi"}, + {ToolKindCodeInterpreter, "code_interpreter"}, + {ToolKindAzureAiSearch, "azure_ai_search"}, + {ToolKindA2APreview, "a2a_preview"}, + {ToolKind("unknown"), "unknown"}, + } + + for _, tc := range tests { + t.Run(string(tc.input), func(t *testing.T) { + result := ToolKindToAPIType(tc.input) + if result != tc.expected { + t.Errorf("ToolKindToAPIType(%s) = %s, want %s", tc.input, result, tc.expected) + } + }) + } +} + +// TestAPITypeToToolKind tests snake_case to camelCase tool kind conversion +func TestAPITypeToToolKind(t *testing.T) { + tests := []struct { + input string + expected ToolKind + }{ + {"web_search", ToolKindWebSearch}, + {"bing_grounding", ToolKindBingGrounding}, + {"file_search", ToolKindFileSearch}, + {"mcp", ToolKindMcp}, + {"openapi", ToolKindOpenApi}, + {"code_interpreter", ToolKindCodeInterpreter}, + {"azure_ai_search", ToolKindAzureAiSearch}, + {"a2a_preview", ToolKindA2APreview}, + {"unknown", ToolKind("unknown")}, + } + + for _, tc := range tests { + t.Run(tc.input, func(t *testing.T) { + result := APITypeToToolKind(tc.input) + if result != tc.expected { + t.Errorf("APITypeToToolKind(%s) = %s, want %s", tc.input, result, tc.expected) + } + }) + } +} + +// TestIsValidAuthType tests auth type validation +func TestIsValidAuthType(t *testing.T) { + validTypes := []AuthType{AuthTypeAAD, AuthTypeApiKey, AuthTypeCustomKeys, AuthTypeNone, AuthTypeOAuth2, AuthTypePAT} + for _, at := range validTypes { + if !IsValidAuthType(at) { + t.Errorf("IsValidAuthType(%s) should be true", at) + } + } + + if IsValidAuthType("InvalidType") { + t.Error("IsValidAuthType(InvalidType) should be false") + } +} + +// TestValidateConnectionResource tests connection resource validation +func TestValidateConnectionResource(t *testing.T) { + tests := []struct { + name string + conn ConnectionResource + wantErr bool + }{ + { + name: "valid connection", + conn: ConnectionResource{ + Resource: Resource{Name: "test", Kind: ResourceKindConnection}, + Category: CategoryCustomKeys, + Target: "https://example.com", + AuthType: AuthTypeCustomKeys, + }, + wantErr: false, + }, + { + name: "missing name", + conn: ConnectionResource{ + Category: CategoryCustomKeys, + Target: "https://example.com", + AuthType: AuthTypeCustomKeys, + }, + wantErr: true, + }, + { + name: "missing category", + conn: ConnectionResource{ + Resource: Resource{Name: "test", Kind: ResourceKindConnection}, + Target: "https://example.com", + AuthType: AuthTypeCustomKeys, + }, + wantErr: true, + }, + { + name: "invalid auth type", + conn: ConnectionResource{ + Resource: Resource{Name: "test", Kind: ResourceKindConnection}, + Category: CategoryCustomKeys, + Target: "https://example.com", + AuthType: "BadAuth", + }, + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := ValidateConnectionResource(&tc.conn) + if (err != nil) != tc.wantErr { + t.Errorf("ValidateConnectionResource() error = %v, wantErr %v", err, tc.wantErr) + } + }) + } +} + +// TestConnectionResourceSerialization tests JSON round-trip for ConnectionResource +func TestConnectionResourceSerialization(t *testing.T) { + conn := ConnectionResource{ + Resource: Resource{Name: "test-conn", Kind: ResourceKindConnection}, + Category: CategoryCustomKeys, + Target: "https://example.com", + AuthType: AuthTypeCustomKeys, + Credentials: map[string]any{"key": "secret"}, + Metadata: map[string]string{"env": "test"}, + ExpiryTime: "2025-12-31", + IsSharedToAll: new(true), + } + + data, err := json.Marshal(conn) + if err != nil { + t.Fatalf("Failed to marshal ConnectionResource: %v", err) + } + + var conn2 ConnectionResource + if err := json.Unmarshal(data, &conn2); err != nil { + t.Fatalf("Failed to unmarshal ConnectionResource: %v", err) + } + + if conn2.Name != "test-conn" { + t.Errorf("Expected name 'test-conn', got '%s'", conn2.Name) + } + if conn2.AuthType != AuthTypeCustomKeys { + t.Errorf("Expected authType 'CustomKeys', got '%s'", conn2.AuthType) + } + if conn2.IsSharedToAll == nil || !*conn2.IsSharedToAll { + t.Error("Expected isSharedToAll to be true") + } +} + +// TestAzureAISearchToolSerialization tests JSON round-trip for AzureAISearchTool +func TestAzureAISearchToolSerialization(t *testing.T) { + tool := AzureAISearchTool{ + Tool: Tool{ + Name: "search-tool", + Kind: ToolKindAzureAiSearch, + }, + Indexes: []AzureAISearchIndex{ + { + ProjectConnectionId: "my-conn", + IndexName: "my-index", + TopK: new(5), + }, + }, + } + + data, err := json.Marshal(tool) + if err != nil { + t.Fatalf("Failed to marshal AzureAISearchTool: %v", err) + } + + var tool2 AzureAISearchTool + if err := json.Unmarshal(data, &tool2); err != nil { + t.Fatalf("Failed to unmarshal AzureAISearchTool: %v", err) + } + + if tool2.Kind != ToolKindAzureAiSearch { + t.Errorf("Expected kind 'azureAiSearch', got '%s'", tool2.Kind) + } + if len(tool2.Indexes) != 1 { + t.Fatalf("Expected 1 index, got %d", len(tool2.Indexes)) + } + if tool2.Indexes[0].IndexName != "my-index" { + t.Errorf("Expected index_name 'my-index', got '%s'", tool2.Indexes[0].IndexName) + } +} + +// TestA2APreviewToolSerialization tests JSON round-trip for A2APreviewTool +func TestA2APreviewToolSerialization(t *testing.T) { + agentCardPath := "/.well-known/agent-card.json" + tool := A2APreviewTool{ + Tool: Tool{ + Name: "a2a-tool", + Kind: ToolKindA2APreview, + }, + BaseUrl: "https://agent.example.com", + AgentCardPath: &agentCardPath, + ProjectConnectionId: "my-conn", + } + + data, err := json.Marshal(tool) + if err != nil { + t.Fatalf("Failed to marshal A2APreviewTool: %v", err) + } + + var tool2 A2APreviewTool + if err := json.Unmarshal(data, &tool2); err != nil { + t.Fatalf("Failed to unmarshal A2APreviewTool: %v", err) + } + + if tool2.Kind != ToolKindA2APreview { + t.Errorf("Expected kind 'a2aPreview', got '%s'", tool2.Kind) + } + if tool2.BaseUrl != "https://agent.example.com" { + t.Errorf("Expected baseUrl 'https://agent.example.com', got '%s'", tool2.BaseUrl) + } + if tool2.AgentCardPath == nil || *tool2.AgentCardPath != agentCardPath { + t.Errorf("Expected agentCardPath '%s'", agentCardPath) + } +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/project/config.go b/cli/azd/extensions/azure.ai.agents/internal/project/config.go index 81a9ed1f195..41da94f5222 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/project/config.go +++ b/cli/azd/extensions/azure.ai.agents/internal/project/config.go @@ -49,6 +49,7 @@ type ServiceTargetAgentConfig struct { Deployments []Deployment `json:"deployments,omitempty"` Resources []Resource `json:"resources,omitempty"` Toolboxes []Toolbox `json:"toolboxes,omitempty"` + Connections []Connection `json:"connections,omitempty"` StartupCommand string `json:"startupCommand,omitempty"` } @@ -116,6 +117,23 @@ type Toolbox struct { Tools []map[string]any `json:"tools"` } +// Connection represents a project connection matching the Bicep ConnectionPropertiesV2 spec. +type Connection struct { + Name string `json:"name"` + Category string `json:"category"` + Target string `json:"target"` + AuthType string `json:"authType"` + Credentials map[string]any `json:"credentials,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` + ExpiryTime string `json:"expiryTime,omitempty"` + IsSharedToAll *bool `json:"isSharedToAll,omitempty"` + SharedUserList []string `json:"sharedUserList,omitempty"` + PeRequirement string `json:"peRequirement,omitempty"` + PeStatus string `json:"peStatus,omitempty"` + UseWorkspaceManagedIdentity *bool `json:"useWorkspaceManagedIdentity,omitempty"` + Error string `json:"error,omitempty"` +} + // UnmarshalStruct converts a structpb.Struct to a Go struct of type T func UnmarshalStruct[T any](s *structpb.Struct, out *T) error { structBytes, err := protojson.Marshal(s) diff --git a/cli/azd/extensions/azure.ai.agents/internal/project/service_target_agent.go b/cli/azd/extensions/azure.ai.agents/internal/project/service_target_agent.go index 2e9adadb189..4b6f88a78fd 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/project/service_target_agent.go +++ b/cli/azd/extensions/azure.ai.agents/internal/project/service_target_agent.go @@ -436,11 +436,6 @@ func (p *AgentServiceTargetProvider) Deploy( fmt.Println("Loaded custom service target configuration") } - // Deploy toolboxes before agent creation - if err := p.deployToolboxes(ctx, serviceTargetConfig, azdEnv); err != nil { - return nil, err - } - // Load and validate the agent manifest data, err := os.ReadFile(p.agentDefinitionPath) if err != nil { diff --git a/cli/azd/extensions/azure.ai.agents/schemas/azure.ai.agent.json b/cli/azd/extensions/azure.ai.agents/schemas/azure.ai.agent.json index 722a3b49359..7365e9f8f1a 100644 --- a/cli/azd/extensions/azure.ai.agents/schemas/azure.ai.agent.json +++ b/cli/azd/extensions/azure.ai.agents/schemas/azure.ai.agent.json @@ -29,6 +29,11 @@ "description": "List of toolboxes (Foundry Toolsets) to deploy.", "items": { "$ref": "#/definitions/Toolbox" } }, + "connections": { + "type": "array", + "description": "List of project connections to create via Bicep provisioning.", + "items": { "$ref": "#/definitions/Connection" } + }, "startupCommand": { "type": "string", "description": "Command to start the agent server (e.g., 'python main.py'). Used by 'azd ai agent run' for local development." @@ -139,6 +144,42 @@ }, "required": ["name", "tools"], "additionalProperties": false + }, + "Connection": { + "type": "object", + "description": "A project connection matching the Bicep ConnectionPropertiesV2 spec.", + "properties": { + "name": { "type": "string", "description": "Connection name.", "pattern": "^[a-zA-Z0-9][a-zA-Z0-9_-]{2,32}$" }, + "category": { "type": "string", "description": "Connection category (e.g., 'CustomKeys', 'AzureOpenAI', 'CognitiveSearch')." }, + "target": { "type": "string", "description": "Target endpoint URL for the connection." }, + "authType": { + "type": "string", + "description": "Authentication type.", + "enum": ["AAD", "ApiKey", "CustomKeys", "None", "OAuth2", "PAT"] + }, + "credentials": { + "type": "object", + "description": "Authentication credentials. Structure depends on authType." + }, + "metadata": { + "type": "object", + "description": "Additional metadata as key-value pairs.", + "additionalProperties": { "type": "string" } + }, + "expiryTime": { "type": "string", "description": "Connection expiry time." }, + "isSharedToAll": { "type": "boolean", "description": "Whether the connection is shared to all users." }, + "sharedUserList": { + "type": "array", + "description": "List of users the connection is shared with.", + "items": { "type": "string" } + }, + "peRequirement": { "type": "string", "description": "Private endpoint requirement." }, + "peStatus": { "type": "string", "description": "Private endpoint status." }, + "useWorkspaceManagedIdentity": { "type": "boolean", "description": "Whether to use workspace managed identity." }, + "error": { "type": "string", "description": "Error information." } + }, + "required": ["name", "category", "target", "authType"], + "additionalProperties": false } } } \ No newline at end of file From 1c7f43371129537353e37247e528f46092fb1667 Mon Sep 17 00:00:00 2001 From: trangevi Date: Wed, 8 Apr 2026 08:23:51 -0700 Subject: [PATCH 03/14] Clean up dead/duplicated code after toolbox merge - Remove ToolboxToolDefinition struct (replaced by []any tools) - Remove deriveConnectionName (orphaned after init refactor) - Migrate ToolKind constants to snake_case, remove converter functions - Remove 7 dead functions from service_target_agent.go (deployToolboxes, enrichToolboxFromConnections, resolveToolboxEnvironmentVariables, resolveMapValues, resolveAnyValue, upsertToolset, registerToolboxEnvironmentVariables) and their tests - Fix provision path: correct FOUNDRY_TOOLBOX_* env var prefix, add env var resolution and connection enrichment, update-on-conflict for upsertToolset (409 -> update instead of skip) - Extract marshalAndSetEnvVar shared helper to reduce duplication - Consolidate toolboxMCPEndpointEnvKey (remove duplicate from init.go) - Fix GetValues API call to use correct GetEnvironmentRequest type Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../azure.ai.agents/internal/cmd/init.go | 182 ++++++- .../azure.ai.agents/internal/cmd/init_test.go | 456 ++++++++++++++++++ .../azure.ai.agents/internal/cmd/listen.go | 212 +++++++- .../internal/pkg/agents/agent_yaml/parse.go | 3 + .../pkg/agents/agent_yaml/parse_test.go | 309 +++++++++++- .../internal/pkg/agents/agent_yaml/yaml.go | 316 ++++++++++-- .../pkg/agents/agent_yaml/yaml_test.go | 60 +-- .../pkg/agents/registry_api/helpers.go | 26 +- .../internal/project/config.go | 26 +- .../internal/project/config_test.go | 99 ++++ .../internal/project/service_target_agent.go | 122 +---- .../project/service_target_agent_test.go | 27 ++ .../schemas/azure.ai.agent.json | 30 ++ 13 files changed, 1612 insertions(+), 256 deletions(-) diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/init.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/init.go index 5f3c72ebf7f..7f18812079f 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/init.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/init.go @@ -418,6 +418,18 @@ func (a *InitAction) Run(ctx context.Context) error { return fmt.Errorf("configuring model choice: %w", err) } + // Prompt for manifest parameters (e.g. tool credentials) after project selection + agentManifest, err = registry_api.ProcessManifestParameters( + ctx, agentManifest, a.azdClient, a.flags.NoPrompt, + ) + if err != nil { + return fmt.Errorf("failed to process manifest parameters: %w", err) + } + + // Inject toolbox MCP endpoint env vars into hosted agent definitions + // so agent.yaml is self-documenting about what env vars will be set. + injectToolboxEnvVarsIntoDefinition(agentManifest) + // Write the final agent.yaml to disk (after deployment names have been injected) if err := writeAgentDefinitionFile(targetDir, agentManifest); err != nil { return fmt.Errorf("writing agent definition: %w", err) @@ -995,11 +1007,6 @@ func (a *InitAction) downloadAgentYaml( } } - agentManifest, err = registry_api.ProcessManifestParameters(ctx, agentManifest, a.azdClient, a.flags.NoPrompt) - if err != nil { - return nil, "", fmt.Errorf("failed to process manifest parameters: %w", err) - } - _, isPromptAgent := agentManifest.Template.(agent_yaml.PromptAgent) if isPromptAgent { agentManifest, err = agent_yaml.ProcessPromptAgentToolsConnections(ctx, agentManifest, a.azdClient) @@ -1164,11 +1171,24 @@ func (a *InitAction) addToProject(ctx context.Context, targetDir string, agentMa agentConfig.Resources = resourceDetails // Process toolbox resources from the manifest - toolboxes, err := extractToolboxConfigs(agentManifest) + toolboxes, toolConnections, credEnvVars, err := extractToolboxAndConnectionConfigs(agentManifest) if err != nil { return err } agentConfig.Toolboxes = toolboxes + agentConfig.ToolConnections = toolConnections + + // Persist credential values as azd environment variables so they are + // resolved at provision/deploy time instead of stored in azure.yaml. + for envKey, envVal := range credEnvVars { + if _, setErr := a.azdClient.Environment().SetValue(ctx, &azdext.SetEnvRequest{ + EnvName: a.environment.Name, + Key: envKey, + Value: envVal, + }); setErr != nil { + return fmt.Errorf("storing credential env var %s: %w", envKey, setErr) + } + } // Process connection resources from the manifest connections, err := extractConnectionConfigs(agentManifest) @@ -1581,14 +1601,22 @@ func downloadDirectoryContentsWithoutGhCli( return nil } -// extractToolboxConfigs extracts toolbox resource definitions from the agent manifest -// and converts them into project.Toolbox config entries. -func extractToolboxConfigs(manifest *agent_yaml.AgentManifest) ([]project.Toolbox, error) { +// extractToolboxAndConnectionConfigs extracts toolbox resource definitions from the agent manifest +// and converts them into project.Toolbox config entries and project.ToolConnection entries. +// Tools with a target/authType also produce connection entries for Bicep provisioning. +// Built-in tools (bing_grounding, azure_ai_search, etc.) produce toolbox tools but no connections. +func extractToolboxAndConnectionConfigs( + manifest *agent_yaml.AgentManifest, +) ([]project.Toolbox, []project.ToolConnection, map[string]string, error) { if manifest == nil || manifest.Resources == nil { - return nil, nil + return nil, nil, nil, nil } var toolboxes []project.Toolbox + var connections []project.ToolConnection + // credentialEnvVars maps generated env var names to their raw values so + // the caller can persist them in the azd environment. + credentialEnvVars := map[string]string{} for _, resource := range manifest.Resources { tbResource, ok := resource.(agent_yaml.ToolboxResource) @@ -1596,33 +1624,155 @@ func extractToolboxConfigs(manifest *agent_yaml.AgentManifest) ([]project.Toolbo continue } + description := tbResource.Description + if len(tbResource.Tools) == 0 { - return nil, fmt.Errorf( + return nil, nil, nil, fmt.Errorf( "toolbox resource '%s' is missing required 'tools'", tbResource.Name, ) } - tools := make([]map[string]any, 0, len(tbResource.Tools)) + var tools []map[string]any for _, rawTool := range tbResource.Tools { toolMap, ok := rawTool.(map[string]any) if !ok { - return nil, fmt.Errorf( + return nil, nil, nil, fmt.Errorf( "toolbox resource '%s' has invalid tool entry: expected object", tbResource.Name, ) } - tools = append(tools, toolMap) + + // Manifest uses "id" for tool kind; API uses "type" + toolId, _ := toolMap["id"].(string) + + target, _ := toolMap["target"].(string) + if target == "" { + // No target — either a built-in tool or a pre-configured tool + // that already has project_connection_id. Convert "id" → "type" + // and pass through all existing fields. + result := make(map[string]any, len(toolMap)) + for k, v := range toolMap { + result[k] = v + } + if _, hasType := result["type"]; !hasType { + result["type"] = toolId + delete(result, "id") + } + tools = append(tools, result) + continue + } + + // External tools with target/authType need a connection + toolName, _ := toolMap["name"].(string) + authType, _ := toolMap["authType"].(string) + credentials, _ := toolMap["credentials"].(map[string]any) + + connName := toolName + if connName == "" { + connName = tbResource.Name + "-" + toolId + } + + conn := project.ToolConnection{ + Name: connName, + Category: "RemoteTool", + Target: target, + AuthType: authType, + } + + // Extract credentials, storing raw values as env vars and + // replacing them with ${VAR} references in the config. + if len(credentials) > 0 { + creds := make(map[string]any, len(credentials)) + for k, v := range credentials { + envVar := credentialEnvVarName(connName, k) + credentialEnvVars[envVar] = fmt.Sprintf("%v", v) + creds[k] = fmt.Sprintf("${%s}", envVar) + } + + // CustomKeys ARM type requires credentials nested under "keys" + if authType == "CustomKeys" { + conn.Credentials = map[string]any{"keys": creds} + } else { + conn.Credentials = creds + } + } + + connections = append(connections, conn) + + // Toolbox tool entry is minimal — deploy enriches from connection + tool := map[string]any{ + "type": toolId, + "project_connection_id": connName, + } + tools = append(tools, tool) } toolboxes = append(toolboxes, project.Toolbox{ Name: tbResource.Name, - Description: tbResource.Description, + Description: description, Tools: tools, }) } - return toolboxes, nil + return toolboxes, connections, credentialEnvVars, nil +} + +// credentialEnvVarName builds a deterministic env var name for a connection +// credential key, e.g. ("github-copilot", "clientSecret") → "FOUNDRY_TOOL_GITHUB_COPILOT_CLIENTSECRET". +func credentialEnvVarName(connName, key string) string { + s := "FOUNDRY_TOOL_" + strings.ToUpper(connName) + "_" + strings.ToUpper(key) + return strings.ReplaceAll(s, "-", "_") +} + +// injectToolboxEnvVarsIntoDefinition adds FOUNDRY_TOOLBOX_{NAME}_MCP_ENDPOINT entries +// to the environment_variables section of a hosted agent definition for each toolbox +// resource in the manifest. Entries already present in the definition are not overwritten. +func injectToolboxEnvVarsIntoDefinition(manifest *agent_yaml.AgentManifest) { + if manifest == nil || manifest.Resources == nil { + return + } + + containerAgent, ok := manifest.Template.(agent_yaml.ContainerAgent) + if !ok { + return + } + + // Collect toolbox resource names + var toolboxNames []string + for _, resource := range manifest.Resources { + if tbResource, ok := resource.(agent_yaml.ToolboxResource); ok { + toolboxNames = append(toolboxNames, tbResource.Name) + } + } + if len(toolboxNames) == 0 { + return + } + + if containerAgent.EnvironmentVariables == nil { + envVars := []agent_yaml.EnvironmentVariable{} + containerAgent.EnvironmentVariables = &envVars + } + + existingNames := make(map[string]bool, len(*containerAgent.EnvironmentVariables)) + for _, ev := range *containerAgent.EnvironmentVariables { + existingNames[ev.Name] = true + } + + for _, tbName := range toolboxNames { + envKey := toolboxMCPEndpointEnvKey(tbName) + if !existingNames[envKey] { + *containerAgent.EnvironmentVariables = append( + *containerAgent.EnvironmentVariables, + agent_yaml.EnvironmentVariable{ + Name: envKey, + Value: fmt.Sprintf("${%s}", envKey), + }, + ) + } + } + + manifest.Template = containerAgent } // extractConnectionConfigs extracts connection resource definitions from the agent manifest diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_test.go index 6fb935235a3..eacb211a10f 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_test.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_test.go @@ -9,6 +9,8 @@ import ( "path/filepath" "testing" + "azureaiagent/internal/pkg/agents/agent_yaml" + "github.com/azure/azure-dev/cli/azd/pkg/azdext" "google.golang.org/genproto/googleapis/rpc/errdetails" "google.golang.org/grpc/codes" @@ -394,3 +396,457 @@ func TestParseGitHubUrlNaive(t *testing.T) { }) } } + +func TestExtractToolboxAndConnectionConfigs_TypedTools(t *testing.T) { + t.Parallel() + + manifest := &agent_yaml.AgentManifest{ + Resources: []any{ + agent_yaml.ToolboxResource{ + Resource: agent_yaml.Resource{ + Name: "platform-tools", + Kind: agent_yaml.ResourceKindToolbox, + }, + Description: "Platform tools", + Tools: []any{ + map[string]any{ + // Built-in tool — no connection + "id": "bing_grounding", + }, + map[string]any{ + // External tool with name — connection name from Name field + "id": "mcp", + "name": "github-copilot", + "target": "https://api.githubcopilot.com/mcp", + "authType": "OAuth2", + "credentials": map[string]any{ + "clientId": "my-client-id", + "clientSecret": "my-secret", + }, + }, + }, + }, + }, + } + + toolboxes, connections, credEnvVars, err := extractToolboxAndConnectionConfigs(manifest) + if err != nil { + t.Fatalf("extractToolboxAndConnectionConfigs failed: %v", err) + } + + // Only the external tool creates a connection (not bing_grounding) + if len(connections) != 1 { + t.Fatalf("Expected 1 connection, got %d", len(connections)) + } + conn := connections[0] + if conn.Name != "github-copilot" { + t.Errorf("Expected connection name 'github-copilot', got '%s'", conn.Name) + } + if conn.Category != "RemoteTool" { + t.Errorf("Expected category 'RemoteTool', got '%s'", conn.Category) + } + if conn.Target != "https://api.githubcopilot.com/mcp" { + t.Errorf("Expected target, got '%s'", conn.Target) + } + if conn.AuthType != "OAuth2" { + t.Errorf("Expected authType 'OAuth2', got '%s'", conn.AuthType) + } + + // Credentials should be ${VAR} references, not raw values + if conn.Credentials["clientId"] != "${FOUNDRY_TOOL_GITHUB_COPILOT_CLIENTID}" { + t.Errorf("Expected env var ref for clientId, got '%v'", conn.Credentials["clientId"]) + } + if conn.Credentials["clientSecret"] != "${FOUNDRY_TOOL_GITHUB_COPILOT_CLIENTSECRET}" { + t.Errorf("Expected env var ref for clientSecret, got '%v'", conn.Credentials["clientSecret"]) + } + + // Raw values should be in the credEnvVars map + if credEnvVars["FOUNDRY_TOOL_GITHUB_COPILOT_CLIENTID"] != "my-client-id" { + t.Errorf("Expected env var value 'my-client-id', got '%s'", + credEnvVars["FOUNDRY_TOOL_GITHUB_COPILOT_CLIENTID"]) + } + if credEnvVars["FOUNDRY_TOOL_GITHUB_COPILOT_CLIENTSECRET"] != "my-secret" { + t.Errorf("Expected env var value 'my-secret', got '%s'", + credEnvVars["FOUNDRY_TOOL_GITHUB_COPILOT_CLIENTSECRET"]) + } + + // Verify toolbox has both tools + if len(toolboxes) != 1 { + t.Fatalf("Expected 1 toolbox, got %d", len(toolboxes)) + } + tb := toolboxes[0] + if tb.Name != "platform-tools" { + t.Errorf("Expected toolbox name 'platform-tools', got '%s'", tb.Name) + } + if tb.Description != "Platform tools" { + t.Errorf("Expected description 'Platform tools', got '%s'", tb.Description) + } + if len(tb.Tools) != 2 { + t.Fatalf("Expected 2 tools, got %d", len(tb.Tools)) + } + + // First tool: built-in (no project_connection_id) + if tb.Tools[0]["type"] != "bing_grounding" { + t.Errorf("Expected tool[0] type 'bing_grounding', got '%v'", tb.Tools[0]["type"]) + } + if _, hasConn := tb.Tools[0]["project_connection_id"]; hasConn { + t.Errorf("Built-in tool should not have project_connection_id") + } + + // Second tool: minimal (type + project_connection_id only) + if tb.Tools[1]["project_connection_id"] != "github-copilot" { + t.Errorf("Expected project_connection_id 'github-copilot', got '%v'", + tb.Tools[1]["project_connection_id"]) + } + if tb.Tools[1]["type"] != "mcp" { + t.Errorf("Expected tool type 'mcp', got '%v'", tb.Tools[1]["type"]) + } + // No server_url or server_label in init output — deploy enriches from connections + if _, has := tb.Tools[1]["server_url"]; has { + t.Errorf("Toolbox tool should not have server_url (deploy enriches it)") + } + if _, has := tb.Tools[1]["server_label"]; has { + t.Errorf("Toolbox tool should not have server_label (deploy enriches it)") + } +} + +func TestExtractToolboxAndConnectionConfigs_RawToolsFallback(t *testing.T) { + t.Parallel() + + manifest := &agent_yaml.AgentManifest{ + Resources: []any{ + agent_yaml.ToolboxResource{ + Resource: agent_yaml.Resource{ + Name: "raw-toolbox", + Kind: agent_yaml.ResourceKindToolbox, + }, + Description: "Raw tools", + Tools: []any{ + map[string]any{ + "id": "mcp", + "name": "existing", + "project_connection_id": "existing-conn", + }, + }, + }, + }, + } + + toolboxes, connections, credEnvVars, err := extractToolboxAndConnectionConfigs(manifest) + if err != nil { + t.Fatalf("extractToolboxAndConnectionConfigs failed: %v", err) + } + + // No connections or env vars extracted from raw tools + if len(connections) != 0 { + t.Errorf("Expected 0 connections, got %d", len(connections)) + } + if len(credEnvVars) != 0 { + t.Errorf("Expected 0 env vars, got %d", len(credEnvVars)) + } + + if len(toolboxes) != 1 { + t.Fatalf("Expected 1 toolbox, got %d", len(toolboxes)) + } + if toolboxes[0].Tools[0]["project_connection_id"] != "existing-conn" { + t.Errorf("Expected 'existing-conn', got '%v'", toolboxes[0].Tools[0]["project_connection_id"]) + } +} + +func TestExtractToolboxAndConnectionConfigs_NilManifest(t *testing.T) { + t.Parallel() + + toolboxes, connections, credEnvVars, err := extractToolboxAndConnectionConfigs(nil) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + if toolboxes != nil { + t.Errorf("Expected nil toolboxes, got %v", toolboxes) + } + if connections != nil { + t.Errorf("Expected nil connections, got %v", connections) + } + if credEnvVars != nil { + t.Errorf("Expected nil env vars, got %v", credEnvVars) + } +} + +func TestExtractToolboxAndConnectionConfigs_CustomKeysCredentials(t *testing.T) { + t.Parallel() + + manifest := &agent_yaml.AgentManifest{ + Resources: []any{ + agent_yaml.ToolboxResource{ + Resource: agent_yaml.Resource{ + Name: "my-tools", + Kind: agent_yaml.ResourceKindToolbox, + }, + Tools: []any{ + map[string]any{ + "id": "mcp", + "name": "custom-api", + "target": "https://example.com/mcp", + "authType": "CustomKeys", + "credentials": map[string]any{"key": "my-api-key"}, + }, + map[string]any{ + "id": "mcp", + "name": "oauth-tool", + "target": "https://example.com/oauth", + "authType": "OAuth2", + "credentials": map[string]any{"clientId": "id", "clientSecret": "secret"}, + }, + }, + }, + }, + } + + _, connections, _, err := extractToolboxAndConnectionConfigs(manifest) + if err != nil { + t.Fatalf("extractToolboxAndConnectionConfigs failed: %v", err) + } + + if len(connections) != 2 { + t.Fatalf("Expected 2 connections, got %d", len(connections)) + } + + // CustomKeys: credentials must be nested under "keys" + customConn := connections[0] + keysRaw, ok := customConn.Credentials["keys"] + if !ok { + t.Fatal("CustomKeys connection missing 'keys' wrapper in credentials") + } + keys, ok := keysRaw.(map[string]any) + if !ok { + t.Fatalf("Expected 'keys' to be map[string]any, got %T", keysRaw) + } + if keys["key"] != "${FOUNDRY_TOOL_CUSTOM_API_KEY}" { + t.Errorf("Expected env var ref for key, got '%v'", keys["key"]) + } + + // OAuth2: credentials should be flat (no "keys" wrapper) + oauthConn := connections[1] + if _, hasKeys := oauthConn.Credentials["keys"]; hasKeys { + t.Error("OAuth2 connection should not have 'keys' wrapper") + } + if oauthConn.Credentials["clientId"] != "${FOUNDRY_TOOL_OAUTH_TOOL_CLIENTID}" { + t.Errorf("Expected flat clientId ref, got '%v'", oauthConn.Credentials["clientId"]) + } +} + +func TestInjectToolboxEnvVarsIntoDefinition_AddsEnvVars(t *testing.T) { + t.Parallel() + + manifest := &agent_yaml.AgentManifest{ + Template: agent_yaml.ContainerAgent{ + AgentDefinition: agent_yaml.AgentDefinition{ + Kind: agent_yaml.AgentKindHosted, + Name: "my-agent", + }, + Protocols: []agent_yaml.ProtocolVersionRecord{ + {Protocol: "responses", Version: "v1"}, + }, + EnvironmentVariables: &[]agent_yaml.EnvironmentVariable{ + {Name: "AZURE_OPENAI_ENDPOINT", Value: "${AZURE_OPENAI_ENDPOINT}"}, + }, + }, + Resources: []any{ + agent_yaml.ToolboxResource{ + Resource: agent_yaml.Resource{ + Name: "agent-tools", + Kind: agent_yaml.ResourceKindToolbox, + }, + Tools: []any{ + map[string]any{"id": "bing_grounding"}, + }, + }, + }, + } + + injectToolboxEnvVarsIntoDefinition(manifest) + + containerAgent := manifest.Template.(agent_yaml.ContainerAgent) + envVars := *containerAgent.EnvironmentVariables + + if len(envVars) != 2 { + t.Fatalf("Expected 2 env vars, got %d", len(envVars)) + } + + // Original env var is preserved + if envVars[0].Name != "AZURE_OPENAI_ENDPOINT" { + t.Errorf("Expected first env var to be AZURE_OPENAI_ENDPOINT, got %s", envVars[0].Name) + } + + // Toolbox env var is injected + if envVars[1].Name != "FOUNDRY_TOOLBOX_AGENT_TOOLS_MCP_ENDPOINT" { + t.Errorf("Expected injected env var name, got %s", envVars[1].Name) + } + if envVars[1].Value != "${FOUNDRY_TOOLBOX_AGENT_TOOLS_MCP_ENDPOINT}" { + t.Errorf("Expected env var reference value, got %s", envVars[1].Value) + } +} + +func TestInjectToolboxEnvVarsIntoDefinition_SkipsExisting(t *testing.T) { + t.Parallel() + + manifest := &agent_yaml.AgentManifest{ + Template: agent_yaml.ContainerAgent{ + AgentDefinition: agent_yaml.AgentDefinition{ + Kind: agent_yaml.AgentKindHosted, + Name: "my-agent", + }, + Protocols: []agent_yaml.ProtocolVersionRecord{ + {Protocol: "responses", Version: "v1"}, + }, + EnvironmentVariables: &[]agent_yaml.EnvironmentVariable{ + {Name: "FOUNDRY_TOOLBOX_MY_TOOLS_MCP_ENDPOINT", Value: "custom-value"}, + }, + }, + Resources: []any{ + agent_yaml.ToolboxResource{ + Resource: agent_yaml.Resource{ + Name: "my-tools", + Kind: agent_yaml.ResourceKindToolbox, + }, + Tools: []any{ + map[string]any{"id": "bing_grounding"}, + }, + }, + }, + } + + injectToolboxEnvVarsIntoDefinition(manifest) + + containerAgent := manifest.Template.(agent_yaml.ContainerAgent) + envVars := *containerAgent.EnvironmentVariables + + // Should not add a duplicate — user's value is preserved + if len(envVars) != 1 { + t.Fatalf("Expected 1 env var (no duplicate), got %d", len(envVars)) + } + if envVars[0].Value != "custom-value" { + t.Errorf("Expected user value preserved, got %s", envVars[0].Value) + } +} + +func TestInjectToolboxEnvVarsIntoDefinition_MultipleToolboxes(t *testing.T) { + t.Parallel() + + manifest := &agent_yaml.AgentManifest{ + Template: agent_yaml.ContainerAgent{ + AgentDefinition: agent_yaml.AgentDefinition{ + Kind: agent_yaml.AgentKindHosted, + Name: "my-agent", + }, + Protocols: []agent_yaml.ProtocolVersionRecord{ + {Protocol: "responses", Version: "v1"}, + }, + }, + Resources: []any{ + agent_yaml.ToolboxResource{ + Resource: agent_yaml.Resource{Name: "search-tools", Kind: agent_yaml.ResourceKindToolbox}, + Tools: []any{map[string]any{"id": "bing_grounding"}}, + }, + agent_yaml.ToolboxResource{ + Resource: agent_yaml.Resource{Name: "github-tools", Kind: agent_yaml.ResourceKindToolbox}, + Tools: []any{map[string]any{"id": "mcp", "target": "https://example.com"}}, + }, + }, + } + + injectToolboxEnvVarsIntoDefinition(manifest) + + containerAgent := manifest.Template.(agent_yaml.ContainerAgent) + envVars := *containerAgent.EnvironmentVariables + + if len(envVars) != 2 { + t.Fatalf("Expected 2 env vars, got %d", len(envVars)) + } + if envVars[0].Name != "FOUNDRY_TOOLBOX_SEARCH_TOOLS_MCP_ENDPOINT" { + t.Errorf("Expected first toolbox env var, got %s", envVars[0].Name) + } + if envVars[1].Name != "FOUNDRY_TOOLBOX_GITHUB_TOOLS_MCP_ENDPOINT" { + t.Errorf("Expected second toolbox env var, got %s", envVars[1].Name) + } +} + +func TestInjectToolboxEnvVarsIntoDefinition_NoopForNilManifest(t *testing.T) { + t.Parallel() + + // Should not panic + injectToolboxEnvVarsIntoDefinition(nil) +} + +func TestInjectToolboxEnvVarsIntoDefinition_NoopForPromptAgent(t *testing.T) { + t.Parallel() + + manifest := &agent_yaml.AgentManifest{ + Template: agent_yaml.PromptAgent{ + AgentDefinition: agent_yaml.AgentDefinition{ + Kind: agent_yaml.AgentKindPrompt, + Name: "prompt-agent", + }, + }, + Resources: []any{ + agent_yaml.ToolboxResource{ + Resource: agent_yaml.Resource{Name: "tools", Kind: agent_yaml.ResourceKindToolbox}, + Tools: []any{map[string]any{"id": "bing_grounding"}}, + }, + }, + } + + injectToolboxEnvVarsIntoDefinition(manifest) + + // Template should be unchanged (still a PromptAgent, no EnvironmentVariables field) + if _, ok := manifest.Template.(agent_yaml.PromptAgent); !ok { + t.Error("Expected template to remain a PromptAgent") + } +} + +func TestInjectToolboxEnvVarsIntoDefinition_NoopWithoutToolboxes(t *testing.T) { + t.Parallel() + + manifest := &agent_yaml.AgentManifest{ + Template: agent_yaml.ContainerAgent{ + AgentDefinition: agent_yaml.AgentDefinition{ + Kind: agent_yaml.AgentKindHosted, + Name: "my-agent", + }, + EnvironmentVariables: &[]agent_yaml.EnvironmentVariable{ + {Name: "AZURE_OPENAI_ENDPOINT", Value: "${AZURE_OPENAI_ENDPOINT}"}, + }, + }, + Resources: []any{}, + } + + injectToolboxEnvVarsIntoDefinition(manifest) + + containerAgent := manifest.Template.(agent_yaml.ContainerAgent) + if len(*containerAgent.EnvironmentVariables) != 1 { + t.Errorf("Expected env vars unchanged, got %d", len(*containerAgent.EnvironmentVariables)) + } +} + +func TestToolboxMCPEndpointEnvKey(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expected string + }{ + {"simple", "my-tools", "FOUNDRY_TOOLBOX_MY_TOOLS_MCP_ENDPOINT"}, + {"spaces", "my tools", "FOUNDRY_TOOLBOX_MY_TOOLS_MCP_ENDPOINT"}, + {"mixed", "agent-tools v2", "FOUNDRY_TOOLBOX_AGENT_TOOLS_V2_MCP_ENDPOINT"}, + {"already upper", "TOOLS", "FOUNDRY_TOOLBOX_TOOLS_MCP_ENDPOINT"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := toolboxMCPEndpointEnvKey(tt.input) + if got != tt.expected { + t.Errorf("toolboxMCPEndpointEnvKey(%q) = %q, want %q", tt.input, got, tt.expected) + } + }) + } +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go index e2d3b473b3d..a4ab4e33022 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go @@ -22,6 +22,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/azure/azure-dev/cli/azd/pkg/azdext" "github.com/braydonk/yaml" + "github.com/drone/envsubst" "github.com/spf13/cobra" "google.golang.org/protobuf/types/known/structpb" ) @@ -191,6 +192,15 @@ func envUpdate(ctx context.Context, azdClient *azdext.AzdClient, azdProject *azd } } + if len(foundryAgentConfig.ToolConnections) > 0 { + if err := toolConnectionsEnvUpdate( + ctx, foundryAgentConfig.ToolConnections, + azdClient, currentEnvResponse.Environment.Name, + ); err != nil { + return err + } + } + return nil } @@ -264,17 +274,62 @@ func connectionsEnvUpdate( azdClient *azdext.AzdClient, envName string, ) error { - connectionsJson, err := json.Marshal(connections) + return marshalAndSetEnvVar(ctx, azdClient, envName, "AI_PROJECT_CONNECTIONS", connections) +} + +// toolConnectionsEnvUpdate serializes tool connections to AI_PROJECT_TOOL_CONNECTIONS env var. +func toolConnectionsEnvUpdate( + ctx context.Context, + connections []project.ToolConnection, + azdClient *azdext.AzdClient, + envName string, +) error { + // Normalize credentials before serializing: CustomKeys authType requires + // credentials nested under "keys" for the ARM API. + normalized := make([]project.ToolConnection, len(connections)) + copy(normalized, connections) + for i := range normalized { + normalized[i].Credentials = normalizeCredentials(normalized[i].AuthType, normalized[i].Credentials) + } + + return marshalAndSetEnvVar(ctx, azdClient, envName, "AI_PROJECT_TOOL_CONNECTIONS", normalized) +} + +// marshalAndSetEnvVar serializes a value to JSON, escapes it for safe storage +// in an azd environment variable, and persists it. +func marshalAndSetEnvVar( + ctx context.Context, + azdClient *azdext.AzdClient, + envName string, + key string, + value any, +) error { + data, err := json.Marshal(value) if err != nil { - return fmt.Errorf("failed to marshal connection details to JSON: %w", err) + return fmt.Errorf("failed to marshal %s to JSON: %w", key, err) } - // Escape backslashes and double quotes for environment variable - jsonString := string(connectionsJson) - escapedJsonString := strings.ReplaceAll(jsonString, "\\", "\\\\") - escapedJsonString = strings.ReplaceAll(escapedJsonString, "\"", "\\\"") + jsonString := string(data) + escaped := strings.ReplaceAll(jsonString, "\\", "\\\\") + escaped = strings.ReplaceAll(escaped, "\"", "\\\"") + + return setEnvVar(ctx, azdClient, envName, key, escaped) +} - return setEnvVar(ctx, azdClient, envName, "AI_PROJECT_CONNECTIONS", escapedJsonString) +// normalizeCredentials ensures credentials match the expected ARM format. +// CustomKeys requires credentials nested under "keys": { "keys": { "key": "val" } }. +// If already wrapped, returns as-is. Other auth types are returned unchanged. +func normalizeCredentials(authType string, creds map[string]any) map[string]any { + if authType != "CustomKeys" || len(creds) == 0 { + return creds + } + + // Already correctly wrapped + if _, hasKeys := creds["keys"]; hasKeys && len(creds) == 1 { + return creds + } + + return map[string]any{"keys": creds} } func containerAgentHandling(ctx context.Context, azdClient *azdext.AzdClient, project *azdext.ProjectConfig, svc *azdext.ServiceConfig) error { @@ -462,11 +517,29 @@ func provisionToolboxes( projectEndpoint, cred, ) + // Build azd env lookup for resolving ${VAR} references in tool entries + azdEnv, err := getAllEnvVars(ctx, azdClient, currentEnv.Environment.Name) + if err != nil { + return fmt.Errorf("failed to load environment variables: %w", err) + } + + // Build connection lookup for enriching tool entries with server_url/server_label + connByName := map[string]project.ToolConnection{} + for _, c := range config.ToolConnections { + connByName[c.Name] = c + } + for _, toolbox := range config.Toolboxes { fmt.Fprintf( os.Stderr, "Provisioning toolbox: %s\n", toolbox.Name, ) + // Resolve ${VAR} references in tool map values before sending to API + resolveToolboxEnvVars(&toolbox, azdEnv) + + // Fill in server_url/server_label from connection data + enrichToolboxFromConnections(&toolbox, connByName) + if err := upsertToolset( ctx, toolsetsClient, toolbox, ); err != nil { @@ -489,7 +562,7 @@ func provisionToolboxes( return nil } -// upsertToolset creates a toolset, or skips if it already exists. +// upsertToolset creates a toolset, or updates it if it already exists. func upsertToolset( ctx context.Context, client *azure.FoundryToolsetsClient, @@ -506,14 +579,24 @@ func upsertToolset( return nil } - // 409 Conflict means the toolset already exists + // 409 Conflict means the toolset already exists — update it var respErr *azcore.ResponseError if errors.As(err, &respErr) && respErr.StatusCode == http.StatusConflict { fmt.Fprintf( os.Stderr, - " Toolset '%s' already exists, skipping\n", + " Toolset '%s' already exists, updating...\n", toolbox.Name, ) + updateReq := &azure.UpdateToolsetRequest{ + Description: toolbox.Description, + Tools: toolbox.Tools, + } + if _, updateErr := client.UpdateToolset(ctx, toolbox.Name, updateReq); updateErr != nil { + return exterrors.Internal( + exterrors.CodeCreateToolsetFailed, + fmt.Sprintf("failed to update toolset '%s': %s", toolbox.Name, updateErr), + ) + } return nil } @@ -526,7 +609,7 @@ func upsertToolset( ) } -// registerToolboxEnvVars sets TOOLBOX_{NAME}_MCP_ENDPOINT. +// registerToolboxEnvVars sets FOUNDRY_TOOLBOX_{NAME}_MCP_ENDPOINT. func registerToolboxEnvVars( ctx context.Context, azdClient *azdext.AzdClient, @@ -534,12 +617,7 @@ func registerToolboxEnvVars( projectEndpoint string, toolboxName string, ) error { - key := strings.ToUpper( - strings.ReplaceAll(toolboxName, "-", "_"), - ) - envKey := fmt.Sprintf( - "TOOLBOX_%s_MCP_ENDPOINT", key, - ) + envKey := toolboxMCPEndpointEnvKey(toolboxName) endpoint := strings.TrimRight(projectEndpoint, "/") mcpEndpoint := fmt.Sprintf( @@ -550,3 +628,103 @@ func registerToolboxEnvVars( ctx, azdClient, envName, envKey, mcpEndpoint, ) } + +// toolboxMCPEndpointEnvKey builds the FOUNDRY_TOOLBOX_{NAME}_MCP_ENDPOINT env var key. +func toolboxMCPEndpointEnvKey(toolboxName string) string { + key := strings.ReplaceAll(toolboxName, " ", "_") + key = strings.ReplaceAll(key, "-", "_") + return fmt.Sprintf("FOUNDRY_TOOLBOX_%s_MCP_ENDPOINT", strings.ToUpper(key)) +} + +// resolveToolboxEnvVars resolves ${VAR} references in toolbox name, description, +// and all tool map values using the provided azd environment variables. +func resolveToolboxEnvVars(toolbox *project.Toolbox, azdEnv map[string]string) { + toolbox.Name = resolveEnvValue(toolbox.Name, azdEnv) + toolbox.Description = resolveEnvValue(toolbox.Description, azdEnv) + for i, tool := range toolbox.Tools { + toolbox.Tools[i] = resolveMapValues(tool, azdEnv) + } +} + +// enrichToolboxFromConnections fills in server_url and server_label on toolbox +// tools that reference a connection via project_connection_id. This keeps the +// azure.yaml toolbox entries minimal while sending complete data to the API. +func enrichToolboxFromConnections( + toolbox *project.Toolbox, + connByName map[string]project.ToolConnection, +) { + for i, tool := range toolbox.Tools { + connID, _ := tool["project_connection_id"].(string) + if connID == "" { + continue + } + conn, ok := connByName[connID] + if !ok { + continue + } + if _, has := tool["server_url"]; !has && conn.Target != "" { + toolbox.Tools[i]["server_url"] = conn.Target + } + if _, has := tool["server_label"]; !has { + toolbox.Tools[i]["server_label"] = conn.Name + } + } +} + +// resolveEnvValue resolves ${VAR} references in a string using envsubst. +func resolveEnvValue(value string, azdEnv map[string]string) string { + resolved, err := envsubst.Eval(value, func(varName string) string { + return azdEnv[varName] + }) + if err != nil { + return value + } + return resolved +} + +// resolveMapValues recursively resolves ${VAR} references in all string values of a map. +func resolveMapValues(m map[string]any, azdEnv map[string]string) map[string]any { + resolved := make(map[string]any, len(m)) + for k, v := range m { + resolved[k] = resolveAnyValue(v, azdEnv) + } + return resolved +} + +// resolveAnyValue resolves ${VAR} references in a value of any type. +func resolveAnyValue(v any, azdEnv map[string]string) any { + switch val := v.(type) { + case string: + return resolveEnvValue(val, azdEnv) + case map[string]any: + return resolveMapValues(val, azdEnv) + case []any: + resolved := make([]any, len(val)) + for i, item := range val { + resolved[i] = resolveAnyValue(item, azdEnv) + } + return resolved + default: + return v + } +} + +// getAllEnvVars loads all environment variables from the azd environment. +func getAllEnvVars( + ctx context.Context, + azdClient *azdext.AzdClient, + envName string, +) (map[string]string, error) { + resp, err := azdClient.Environment().GetValues(ctx, &azdext.GetEnvironmentRequest{ + Name: envName, + }) + if err != nil { + return nil, err + } + + envMap := make(map[string]string, len(resp.KeyValues)) + for _, kv := range resp.KeyValues { + envMap[kv.Key] = kv.Value + } + return envMap, nil +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse.go index 3048bed1f6e..0a56af702f4 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse.go +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse.go @@ -181,6 +181,9 @@ func ExtractResourceDefinitions(manifestYamlContent []byte) ([]any, error) { if err := yaml.Unmarshal(resourceBytes, &connDef); err != nil { return nil, fmt.Errorf("failed to unmarshal to ConnectionResource: %w", err) } + if err := ValidateConnectionResource(&connDef); err != nil { + return nil, err + } resourceDefs = append(resourceDefs, connDef) default: return nil, fmt.Errorf("unrecognized resource kind: %s", resourceDef.Kind) diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse_test.go index c0c4a97fe52..a07db730013 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse_test.go +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse_test.go @@ -604,13 +604,13 @@ resources: } } -// TestExtractToolsDefinitions_AzureAiSearch tests parsing an azureAiSearch tool +// TestExtractToolsDefinitions_AzureAiSearch tests parsing an azure_ai_search tool func TestExtractToolsDefinitions_AzureAiSearch(t *testing.T) { template := map[string]any{ "tools": []any{ map[string]any{ "name": "my-search", - "kind": "azureAiSearch", + "kind": "azure_ai_search", "indexes": []any{ map[string]any{ "project_connection_id": "search-conn", @@ -637,7 +637,7 @@ func TestExtractToolsDefinitions_AzureAiSearch(t *testing.T) { t.Fatalf("Expected AzureAISearchTool, got %T", tools[0]) } if searchTool.Kind != ToolKindAzureAiSearch { - t.Errorf("Expected kind 'azureAiSearch', got '%s'", searchTool.Kind) + t.Errorf("Expected kind 'azure_ai_search', got '%s'", searchTool.Kind) } if len(searchTool.Indexes) != 1 { t.Fatalf("Expected 1 index, got %d", len(searchTool.Indexes)) @@ -647,13 +647,13 @@ func TestExtractToolsDefinitions_AzureAiSearch(t *testing.T) { } } -// TestExtractToolsDefinitions_A2APreview tests parsing an a2aPreview tool +// TestExtractToolsDefinitions_A2APreview tests parsing an a2a_preview tool func TestExtractToolsDefinitions_A2APreview(t *testing.T) { template := map[string]any{ "tools": []any{ map[string]any{ "name": "a2a-delegate", - "kind": "a2aPreview", + "kind": "a2a_preview", "baseUrl": "https://remote-agent.example.com", "agentCardPath": "/.well-known/agent.json", "projectConnectionId": "remote-conn", @@ -675,7 +675,7 @@ func TestExtractToolsDefinitions_A2APreview(t *testing.T) { t.Fatalf("Expected A2APreviewTool, got %T", tools[0]) } if a2aTool.Kind != ToolKindA2APreview { - t.Errorf("Expected kind 'a2aPreview', got '%s'", a2aTool.Kind) + t.Errorf("Expected kind 'a2a_preview', got '%s'", a2aTool.Kind) } if a2aTool.BaseUrl != "https://remote-agent.example.com" { t.Errorf("Expected baseUrl, got '%s'", a2aTool.BaseUrl) @@ -684,3 +684,300 @@ func TestExtractToolsDefinitions_A2APreview(t *testing.T) { t.Errorf("Expected projectConnectionId 'remote-conn', got '%s'", a2aTool.ProjectConnectionId) } } + +// TestExtractResourceDefinitions_ToolboxResourceWithTypedTools tests parsing a toolbox +// resource that has tool entries in the Tools []any field, +// matching the AgentSchema ToolboxResource/ToolboxTool format. +func TestExtractResourceDefinitions_ToolboxResourceWithTypedTools(t *testing.T) { + yamlContent := []byte(` +name: test-manifest +template: + kind: prompt + name: test-agent + model: + id: gpt-4.1-mini +resources: + - kind: toolbox + name: platform-tools + description: Platform tools with typed definitions + tools: + - id: bing_grounding + - id: mcp + name: github-copilot + target: https://api.githubcopilot.com/mcp + authType: OAuth2 + credentials: + clientId: my-client-id + clientSecret: my-client-secret + - id: mcp + name: custom-api + target: https://my-api.example.com/sse + authType: CustomKeys + credentials: + key: my-api-key +`) + + resources, err := ExtractResourceDefinitions(yamlContent) + if err != nil { + t.Fatalf("ExtractResourceDefinitions failed: %v", err) + } + + if len(resources) != 1 { + t.Fatalf("Expected 1 resource, got %d", len(resources)) + } + + toolboxRes, ok := resources[0].(ToolboxResource) + if !ok { + t.Fatalf("Expected ToolboxResource, got %T", resources[0]) + } + + if toolboxRes.Name != "platform-tools" { + t.Errorf("Expected name 'platform-tools', got '%s'", toolboxRes.Name) + } + + if toolboxRes.Description != "Platform tools with typed definitions" { + t.Errorf("Expected description, got '%s'", toolboxRes.Description) + } + + if len(toolboxRes.Tools) != 3 { + t.Fatalf("Expected 3 typed tools, got %d", len(toolboxRes.Tools)) + } + + // Helper to get tool as map + tool := func(i int) map[string]any { + m, ok := toolboxRes.Tools[i].(map[string]any) + if !ok { + t.Fatalf("Expected tool[%d] to be map[string]any, got %T", i, toolboxRes.Tools[i]) + } + return m + } + + // Check built-in tool (no target/authType/name) + if tool(0)["id"] != "bing_grounding" { + t.Errorf("Expected first tool id 'bing_grounding', got '%v'", tool(0)["id"]) + } + if tool(0)["target"] != nil { + t.Errorf("Expected no target for built-in tool, got '%v'", tool(0)["target"]) + } + + // Check MCP tool with name and OAuth2 + if tool(1)["id"] != "mcp" { + t.Errorf("Expected second tool id 'mcp', got '%v'", tool(1)["id"]) + } + if tool(1)["name"] != "github-copilot" { + t.Errorf("Expected second tool name 'github-copilot', got '%v'", tool(1)["name"]) + } + if tool(1)["target"] != "https://api.githubcopilot.com/mcp" { + t.Errorf("Expected second tool target, got '%v'", tool(1)["target"]) + } + if tool(1)["authType"] != "OAuth2" { + t.Errorf("Expected second tool authType 'OAuth2', got '%v'", tool(1)["authType"]) + } + creds1, _ := tool(1)["credentials"].(map[string]any) + if creds1["clientId"] != "my-client-id" { + t.Errorf("Expected second tool clientId, got '%v'", creds1["clientId"]) + } + + // Check MCP tool with CustomKeys + if tool(2)["id"] != "mcp" { + t.Errorf("Expected third tool id 'mcp', got '%v'", tool(2)["id"]) + } + if tool(2)["name"] != "custom-api" { + t.Errorf("Expected third tool name 'custom-api', got '%v'", tool(2)["name"]) + } + if tool(2)["authType"] != "CustomKeys" { + t.Errorf("Expected third tool authType 'CustomKeys', got '%v'", tool(2)["authType"]) + } +} + +// TestLoadAndValidateAgentManifest_RecordFormatParameters verifies that the +// record/map format for parameters (canonical agent manifest schema) is parsed +// correctly into PropertySchema.Properties. +func TestLoadAndValidateAgentManifest_RecordFormatParameters(t *testing.T) { + yamlContent := []byte(` +name: test-params +template: + name: test + kind: hosted + protocols: + - protocol: responses +resources: + - kind: model + name: chat + id: gpt-5 + - kind: toolbox + name: tools + tools: + - id: mcp + name: github + target: https://api.githubcopilot.com/mcp + authType: OAuth2 + credentials: + clientId: "{{ github_client_id }}" + clientSecret: "{{ github_client_secret }}" +parameters: + github_client_id: + schema: + type: string + description: OAuth client ID + required: true + github_client_secret: + schema: + type: string + description: OAuth client secret + required: true + model_name: + schema: + type: string + enum: + - gpt-4o + - gpt-4o-mini + default: gpt-4o + required: true +`) + + manifest, err := LoadAndValidateAgentManifest(yamlContent) + if err != nil { + t.Fatalf("LoadAndValidateAgentManifest failed: %v", err) + } + + if len(manifest.Parameters.Properties) != 3 { + t.Fatalf("Expected 3 parameters, got %d", len(manifest.Parameters.Properties)) + } + + // Find parameters by name (map order is not guaranteed) + paramsByName := map[string]Property{} + for _, p := range manifest.Parameters.Properties { + paramsByName[p.Name] = p + } + + // Check github_client_id + p, ok := paramsByName["github_client_id"] + if !ok { + t.Fatal("Missing parameter github_client_id") + } + if p.Kind != "string" { + t.Errorf("Expected kind 'string', got '%s'", p.Kind) + } + if p.Description == nil || *p.Description != "OAuth client ID" { + t.Errorf("Unexpected description: %v", p.Description) + } + if p.Required == nil || !*p.Required { + t.Error("Expected required=true") + } + + // Check model_name with enum and default + p, ok = paramsByName["model_name"] + if !ok { + t.Fatal("Missing parameter model_name") + } + if p.EnumValues == nil || len(*p.EnumValues) != 2 { + t.Fatalf("Expected 2 enum values, got %v", p.EnumValues) + } + if p.Default == nil { + t.Fatal("Expected default value") + } + if defaultStr, ok := (*p.Default).(string); !ok || defaultStr != "gpt-4o" { + t.Errorf("Expected default 'gpt-4o', got %v", *p.Default) + } +} + +// TestLoadAndValidateAgentManifest_ArrayFormatParameters verifies that the +// traditional array format for parameters still works after the UnmarshalYAML change. +func TestLoadAndValidateAgentManifest_ArrayFormatParameters(t *testing.T) { + yamlContent := []byte(` +name: test-array-params +template: + name: test + kind: hosted + protocols: + - protocol: responses +resources: + - kind: model + name: chat + id: gpt-5 +parameters: + properties: + - name: my_param + kind: string + description: A test parameter + required: true +`) + + manifest, err := LoadAndValidateAgentManifest(yamlContent) + if err != nil { + t.Fatalf("LoadAndValidateAgentManifest failed: %v", err) + } + + if len(manifest.Parameters.Properties) != 1 { + t.Fatalf("Expected 1 parameter, got %d", len(manifest.Parameters.Properties)) + } + + p := manifest.Parameters.Properties[0] + if p.Name != "my_param" { + t.Errorf("Expected name 'my_param', got '%s'", p.Name) + } + if p.Kind != "string" { + t.Errorf("Expected kind 'string', got '%s'", p.Kind) + } +} + +// TestLoadAndValidateAgentManifest_SecretParameter verifies that +// secret: true inside the schema block is parsed into Property.Secret. +func TestLoadAndValidateAgentManifest_SecretParameter(t *testing.T) { + yamlContent := []byte(` +name: test-secret +template: + name: test + kind: hosted + protocols: + - protocol: responses +resources: + - kind: model + name: chat + id: gpt-5 +parameters: + api_key: + description: API key for the custom MCP server + schema: + type: string + secret: true + required: true + display_name: + description: A non-secret parameter + schema: + type: string +`) + + manifest, err := LoadAndValidateAgentManifest(yamlContent) + if err != nil { + t.Fatalf("LoadAndValidateAgentManifest failed: %v", err) + } + + if len(manifest.Parameters.Properties) != 2 { + t.Fatalf("Expected 2 parameters, got %d", len(manifest.Parameters.Properties)) + } + + paramsByName := map[string]Property{} + for _, p := range manifest.Parameters.Properties { + paramsByName[p.Name] = p + } + + // api_key should be secret + apiKey, ok := paramsByName["api_key"] + if !ok { + t.Fatal("Missing parameter api_key") + } + if apiKey.Secret == nil || !*apiKey.Secret { + t.Error("Expected api_key to have secret=true") + } + + // display_name should NOT be secret + displayName, ok := paramsByName["display_name"] + if !ok { + t.Fatal("Missing parameter display_name") + } + if displayName.Secret != nil { + t.Errorf("Expected display_name secret to be nil, got %v", *displayName.Secret) + } +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml.go index b98afe0f862..a113d7649ca 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml.go +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml.go @@ -4,7 +4,10 @@ package agent_yaml import ( + "fmt" "slices" + + "go.yaml.in/yaml/v3" ) // AgentKind represents the type of agent @@ -44,50 +47,16 @@ type ToolKind string const ( ToolKindFunction ToolKind = "function" ToolKindCustom ToolKind = "custom" - ToolKindWebSearch ToolKind = "webSearch" - ToolKindBingGrounding ToolKind = "bingGrounding" - ToolKindFileSearch ToolKind = "fileSearch" + ToolKindWebSearch ToolKind = "web_search" + ToolKindBingGrounding ToolKind = "bing_grounding" + ToolKindFileSearch ToolKind = "file_search" ToolKindMcp ToolKind = "mcp" - ToolKindOpenApi ToolKind = "openApi" - ToolKindCodeInterpreter ToolKind = "codeInterpreter" - ToolKindAzureAiSearch ToolKind = "azureAiSearch" - ToolKindA2APreview ToolKind = "a2aPreview" + ToolKindOpenApi ToolKind = "openapi" + ToolKindCodeInterpreter ToolKind = "code_interpreter" + ToolKindAzureAiSearch ToolKind = "azure_ai_search" + ToolKindA2APreview ToolKind = "a2a_preview" ) -// toolKindAPITypeMap maps camelCase ToolKind values to snake_case API type values. -var toolKindAPITypeMap = map[ToolKind]string{ - ToolKindFunction: "function", - ToolKindCustom: "custom", - ToolKindWebSearch: "web_search", - ToolKindBingGrounding: "bing_grounding", - ToolKindFileSearch: "file_search", - ToolKindMcp: "mcp", - ToolKindOpenApi: "openapi", - ToolKindCodeInterpreter: "code_interpreter", - ToolKindAzureAiSearch: "azure_ai_search", - ToolKindA2APreview: "a2a_preview", -} - -// ToolKindToAPIType converts a camelCase ToolKind to its snake_case API type. -// Returns the input as-is if no mapping is found. -func ToolKindToAPIType(kind ToolKind) string { - if apiType, ok := toolKindAPITypeMap[kind]; ok { - return apiType - } - return string(kind) -} - -// APITypeToToolKind converts a snake_case API type to its camelCase ToolKind. -// Returns the input as a ToolKind if no mapping is found. -func APITypeToToolKind(apiType string) ToolKind { - for kind, mapped := range toolKindAPITypeMap { - if mapped == apiType { - return kind - } - } - return ToolKind(apiType) -} - // AuthType represents the authentication type for a connection. type AuthType string @@ -350,6 +319,7 @@ type Property struct { Default *any `json:"default,omitempty" yaml:"default,omitempty"` Example *any `json:"example,omitempty" yaml:"example,omitempty"` EnumValues *[]any `json:"enumValues,omitempty" yaml:"enumValues,omitempty"` + Secret *bool `json:"secret,omitempty" yaml:"secret,omitempty"` } // ArrayProperty Represents an array property. @@ -370,10 +340,239 @@ type ObjectProperty struct { // PropertySchema Definition for the property schema of a model. // This includes the properties and example records. +// +// The schema supports two YAML layouts for Properties: +// +// Array format (explicit): +// +// properties: +// - name: foo +// kind: string +// +// Record/map format (canonical agent manifest shorthand): +// +// parameters: +// foo: +// schema: { type: string } +// description: a foo param +// required: true +// +// UnmarshalYAML detects which layout is present and normalises to []Property. type PropertySchema struct { - Examples *[]map[string]any `json:"examples,omitempty" yaml:"examples,omitempty"` - Strict *bool `json:"strict,omitempty" yaml:"strict,omitempty"` - Properties []Property `json:"properties" yaml:"properties"` + Examples *[]map[string]any `json:"examples,omitempty" yaml:"-"` + Strict *bool `json:"strict,omitempty" yaml:"-"` + Properties []Property `json:"properties" yaml:"-"` +} + +// UnmarshalYAML supports both the array format (properties: []) and the +// record/map format where parameter names are direct YAML keys. +func (ps *PropertySchema) UnmarshalYAML(value *yaml.Node) error { + // The node should be a mapping. + if value.Kind != yaml.MappingNode { + return fmt.Errorf("PropertySchema: expected mapping node, got %d", value.Kind) + } + + // First pass: look for known struct keys (examples, strict, properties). + // Anything else is treated as a record-format parameter name. + var ( + propertiesNode *yaml.Node + extraKeys []string + extraValues []*yaml.Node + ) + + for i := 0; i < len(value.Content)-1; i += 2 { + key := value.Content[i].Value + val := value.Content[i+1] + + switch key { + case "examples": + var examples []map[string]any + if err := val.Decode(&examples); err != nil { + return fmt.Errorf("PropertySchema.examples: %w", err) + } + ps.Examples = &examples + case "strict": + var strict bool + if err := val.Decode(&strict); err != nil { + return fmt.Errorf("PropertySchema.strict: %w", err) + } + ps.Strict = &strict + case "properties": + propertiesNode = val + default: + extraKeys = append(extraKeys, key) + extraValues = append(extraValues, val) + } + } + + // If an explicit "properties" key was found, decode it (array or map). + if propertiesNode != nil { + props, err := decodePropertiesNode(propertiesNode) + if err != nil { + return fmt.Errorf("PropertySchema.properties: %w", err) + } + ps.Properties = props + return nil + } + + // No explicit "properties" key — treat extra keys as record-format params. + if len(extraKeys) > 0 { + for i, name := range extraKeys { + prop, err := decodeRecordProperty(name, extraValues[i]) + if err != nil { + return fmt.Errorf("PropertySchema parameter %q: %w", name, err) + } + ps.Properties = append(ps.Properties, prop) + } + } + + return nil +} + +// decodePropertiesNode handles "properties:" as either an array or a map. +func decodePropertiesNode(node *yaml.Node) ([]Property, error) { + switch node.Kind { + case yaml.SequenceNode: + var props []Property + if err := node.Decode(&props); err != nil { + return nil, err + } + return props, nil + case yaml.MappingNode: + var props []Property + for i := 0; i < len(node.Content)-1; i += 2 { + name := node.Content[i].Value + prop, err := decodeRecordProperty(name, node.Content[i+1]) + if err != nil { + return nil, fmt.Errorf("property %q: %w", name, err) + } + props = append(props, prop) + } + return props, nil + default: + return nil, fmt.Errorf("expected sequence or mapping, got %d", node.Kind) + } +} + +// recordEntry is the intermediate structure for parsing a single record-format +// parameter entry like: +// +// param_name: +// schema: { type: string, enum: [...], default: ... } +// description: some text +// required: true +type recordEntry struct { + Schema *recordSchema `yaml:"schema"` + Description string `yaml:"description"` + Required bool `yaml:"required"` + Default *any `yaml:"default"` + Example *any `yaml:"example"` + EnumValues *[]any `yaml:"enumValues"` +} + +type recordSchema struct { + Type string `yaml:"type"` + Enum []any `yaml:"enum"` + Default *any `yaml:"default"` + Secret bool `yaml:"secret"` +} + +// decodeRecordProperty converts a record-format parameter entry into a Property. +func decodeRecordProperty(name string, node *yaml.Node) (Property, error) { + var entry recordEntry + if err := node.Decode(&entry); err != nil { + return Property{}, err + } + + prop := Property{Name: name} + if entry.Description != "" { + prop.Description = &entry.Description + } + if entry.Required { + prop.Required = &entry.Required + } + if entry.Default != nil { + prop.Default = entry.Default + } + if entry.Example != nil { + prop.Example = entry.Example + } + if entry.EnumValues != nil { + prop.EnumValues = entry.EnumValues + } + + // Extract kind/default/enum/secret from nested schema if present + if entry.Schema != nil { + prop.Kind = entry.Schema.Type + if entry.Schema.Default != nil && prop.Default == nil { + prop.Default = entry.Schema.Default + } + if len(entry.Schema.Enum) > 0 && prop.EnumValues == nil { + prop.EnumValues = &entry.Schema.Enum + } + if entry.Schema.Secret { + prop.Secret = new(true) + } + } + + return prop, nil +} + +// MarshalYAML writes PropertySchema back as the record/map format so that +// {{param}} placeholders elsewhere in the document survive a marshal→unmarshal +// round-trip through InjectParameterValuesIntoManifest. +func (ps PropertySchema) MarshalYAML() (any, error) { + out := make(map[string]any) + + if ps.Examples != nil { + out["examples"] = *ps.Examples + } + if ps.Strict != nil { + out["strict"] = *ps.Strict + } + + // Emit each property as a record-format entry. + props := make(map[string]any, len(ps.Properties)) + for _, p := range ps.Properties { + entry := map[string]any{} + schema := map[string]any{} + + if p.Kind != "" { + schema["type"] = p.Kind + } + if p.Default != nil { + schema["default"] = *p.Default + } + if p.EnumValues != nil { + schema["enum"] = *p.EnumValues + } + if p.Secret != nil && *p.Secret { + schema["secret"] = true + } + if len(schema) > 0 { + entry["schema"] = schema + } + + if p.Description != nil { + entry["description"] = *p.Description + } + if p.Required != nil { + entry["required"] = *p.Required + } + if p.Example != nil { + entry["example"] = *p.Example + } + props[p.Name] = entry + } + + if len(props) > 0 { + // Merge property keys at the top level (record format) + for k, v := range props { + out[k] = v + } + } + + return out, nil } // ProtocolVersionRecord represents a protocolversionrecord. @@ -431,6 +630,35 @@ type ConnectionResource struct { UseWorkspaceManagedIdentity *bool `json:"useWorkspaceManagedIdentity,omitempty" yaml:"useWorkspaceManagedIdentity,omitempty"` //nolint:lll } +// ValidateConnectionResource checks that required fields are present and valid. +func ValidateConnectionResource(c *ConnectionResource) error { + if c.Name == "" { + return fmt.Errorf("connection resource is missing required 'name'") + } + if c.Category == "" { + return fmt.Errorf( + "connection resource '%s': category is required", c.Name, + ) + } + if c.Target == "" { + return fmt.Errorf( + "connection resource '%s': target is required", c.Name, + ) + } + if c.AuthType == "" { + return fmt.Errorf( + "connection resource '%s': authType is required", c.Name, + ) + } + if !IsValidAuthType(c.AuthType) { + return fmt.Errorf( + "connection resource '%s': authType must be one of %v, got '%s'", + c.Name, ValidAuthTypes(), c.AuthType, + ) + } + return nil +} + // Template Template model for defining prompt templates. // This model specifies the rendering engine used for slot filling prompts, // the parser used to process the rendered template into API-compatible format, diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml_test.go index 25ac5cd6f9d..85256a54d0d 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml_test.go +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml_test.go @@ -24,62 +24,6 @@ func TestArrayProperty_BasicSerialization(t *testing.T) { } } -// TestToolKindToAPIType tests camelCase to snake_case tool kind conversion -func TestToolKindToAPIType(t *testing.T) { - tests := []struct { - input ToolKind - expected string - }{ - {ToolKindFunction, "function"}, - {ToolKindCustom, "custom"}, - {ToolKindWebSearch, "web_search"}, - {ToolKindBingGrounding, "bing_grounding"}, - {ToolKindFileSearch, "file_search"}, - {ToolKindMcp, "mcp"}, - {ToolKindOpenApi, "openapi"}, - {ToolKindCodeInterpreter, "code_interpreter"}, - {ToolKindAzureAiSearch, "azure_ai_search"}, - {ToolKindA2APreview, "a2a_preview"}, - {ToolKind("unknown"), "unknown"}, - } - - for _, tc := range tests { - t.Run(string(tc.input), func(t *testing.T) { - result := ToolKindToAPIType(tc.input) - if result != tc.expected { - t.Errorf("ToolKindToAPIType(%s) = %s, want %s", tc.input, result, tc.expected) - } - }) - } -} - -// TestAPITypeToToolKind tests snake_case to camelCase tool kind conversion -func TestAPITypeToToolKind(t *testing.T) { - tests := []struct { - input string - expected ToolKind - }{ - {"web_search", ToolKindWebSearch}, - {"bing_grounding", ToolKindBingGrounding}, - {"file_search", ToolKindFileSearch}, - {"mcp", ToolKindMcp}, - {"openapi", ToolKindOpenApi}, - {"code_interpreter", ToolKindCodeInterpreter}, - {"azure_ai_search", ToolKindAzureAiSearch}, - {"a2a_preview", ToolKindA2APreview}, - {"unknown", ToolKind("unknown")}, - } - - for _, tc := range tests { - t.Run(tc.input, func(t *testing.T) { - result := APITypeToToolKind(tc.input) - if result != tc.expected { - t.Errorf("APITypeToToolKind(%s) = %s, want %s", tc.input, result, tc.expected) - } - }) - } -} - // TestIsValidAuthType tests auth type validation func TestIsValidAuthType(t *testing.T) { validTypes := []AuthType{AuthTypeAAD, AuthTypeApiKey, AuthTypeCustomKeys, AuthTypeNone, AuthTypeOAuth2, AuthTypePAT} @@ -212,7 +156,7 @@ func TestAzureAISearchToolSerialization(t *testing.T) { } if tool2.Kind != ToolKindAzureAiSearch { - t.Errorf("Expected kind 'azureAiSearch', got '%s'", tool2.Kind) + t.Errorf("Expected kind 'azure_ai_search', got '%s'", tool2.Kind) } if len(tool2.Indexes) != 1 { t.Fatalf("Expected 1 index, got %d", len(tool2.Indexes)) @@ -246,7 +190,7 @@ func TestA2APreviewToolSerialization(t *testing.T) { } if tool2.Kind != ToolKindA2APreview { - t.Errorf("Expected kind 'a2aPreview', got '%s'", tool2.Kind) + t.Errorf("Expected kind 'a2a_preview', got '%s'", tool2.Kind) } if tool2.BaseUrl != "https://agent.example.com" { t.Errorf("Expected baseUrl 'https://agent.example.com', got '%s'", tool2.BaseUrl) diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/registry_api/helpers.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/registry_api/helpers.go index 77586036d92..be342bf8b27 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/registry_api/helpers.go +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/registry_api/helpers.go @@ -6,6 +6,7 @@ package registry_api import ( "context" "fmt" + "os" "reflect" "strconv" "strings" @@ -15,6 +16,7 @@ import ( "github.com/azure/azure-dev/cli/azd/pkg/azdext" "github.com/braydonk/yaml" + "golang.org/x/term" ) // ParameterValues represents the user-provided values for manifest parameters @@ -346,9 +348,13 @@ func promptForYamlParameterValues( var value any var err error isRequired := property.Required != nil && *property.Required + isSecret := property.Secret != nil && *property.Secret if len(enumValues) > 0 { // Use selection for enum parameters value, err = promptForEnumValue(ctx, property.Name, enumValues, defaultValue, azdClient, noPrompt) + } else if isSecret { + // Use masked input for secret parameters + value, err = promptForSecretValue(property.Name, isRequired) } else { // Use text input for other parameters value, err = promptForTextValue(ctx, property.Name, defaultValue, isRequired, azdClient) @@ -451,7 +457,7 @@ func promptForTextValue( defaultStr = fmt.Sprintf("%v", defaultValue) } - message := fmt.Sprintf("Enter value for parameter '%s':", paramName) + message := fmt.Sprintf("Enter value for parameter '%s'", paramName) if defaultStr != "" { message += fmt.Sprintf(" (default: %s)", defaultStr) } @@ -480,6 +486,24 @@ func promptForTextValue( return resp.Value, nil } +// promptForSecretValue reads a secret value from the terminal with masked input. +func promptForSecretValue(paramName string, required bool) (any, error) { + fmt.Printf("Enter value for parameter '%s': ", paramName) + + secret, err := term.ReadPassword(int(os.Stdin.Fd())) + fmt.Println() // newline after masked input + if err != nil { + return nil, fmt.Errorf("failed to read secret value: %w", err) + } + + value := strings.TrimSpace(string(secret)) + if value == "" && required { + return nil, fmt.Errorf("parameter '%s' is required but no value was provided", paramName) + } + + return value, nil +} + // injectParameterValues replaces parameter placeholders in the template with actual values func injectParameterValues(template string, paramValues ParameterValues) ([]byte, error) { // Replace each parameter placeholder with its value diff --git a/cli/azd/extensions/azure.ai.agents/internal/project/config.go b/cli/azd/extensions/azure.ai.agents/internal/project/config.go index 41da94f5222..78109795bc2 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/project/config.go +++ b/cli/azd/extensions/azure.ai.agents/internal/project/config.go @@ -44,13 +44,14 @@ var ResourceTiers = []ResourceTier{ // ServiceTargetAgentConfig provides custom configuration for the Azure AI Service target type ServiceTargetAgentConfig struct { - Environment map[string]string `json:"env,omitempty"` - Container *ContainerSettings `json:"container,omitempty"` - Deployments []Deployment `json:"deployments,omitempty"` - Resources []Resource `json:"resources,omitempty"` - Toolboxes []Toolbox `json:"toolboxes,omitempty"` - Connections []Connection `json:"connections,omitempty"` - StartupCommand string `json:"startupCommand,omitempty"` + Environment map[string]string `json:"env,omitempty"` + Container *ContainerSettings `json:"container,omitempty"` + Deployments []Deployment `json:"deployments,omitempty"` + Resources []Resource `json:"resources,omitempty"` + ToolConnections []ToolConnection `json:"toolConnections,omitempty"` + Toolboxes []Toolbox `json:"toolboxes,omitempty"` + Connections []Connection `json:"connections,omitempty"` + StartupCommand string `json:"startupCommand,omitempty"` } // ContainerSettings provides container configuration for the Azure AI Service target @@ -134,6 +135,17 @@ type Connection struct { Error string `json:"error,omitempty"` } +// ToolConnection represents a connection to an external service (MCP tool, A2A, custom API) +// that must be created in the Foundry project during provisioning via Bicep. +type ToolConnection struct { + Name string `json:"name"` + Category string `json:"category"` + Target string `json:"target"` + AuthType string `json:"authType"` + Credentials map[string]any `json:"credentials,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` +} + // UnmarshalStruct converts a structpb.Struct to a Go struct of type T func UnmarshalStruct[T any](s *structpb.Struct, out *T) error { structBytes, err := protojson.Marshal(s) diff --git a/cli/azd/extensions/azure.ai.agents/internal/project/config_test.go b/cli/azd/extensions/azure.ai.agents/internal/project/config_test.go index e6012afe8df..5026950b173 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/project/config_test.go +++ b/cli/azd/extensions/azure.ai.agents/internal/project/config_test.go @@ -203,3 +203,102 @@ func TestServiceTargetAgentConfig_WithOtherFields(t *testing.T) { t.Errorf("Expected toolbox name 'my-toolbox', got '%s'", roundTripped.Toolboxes[0].Name) } } + +// TestServiceTargetAgentConfig_WithToolConnections tests MarshalStruct/UnmarshalStruct +// round-trip with the ToolConnections field populated. +func TestServiceTargetAgentConfig_WithToolConnections(t *testing.T) { + original := ServiceTargetAgentConfig{ + ToolConnections: []ToolConnection{ + { + Name: "github-mcp", + Category: "RemoteTool", + Target: "https://api.githubcopilot.com/mcp", + AuthType: "OAuth2", + Credentials: map[string]any{ //nolint:gosec // test data, not real credentials + "clientId": "${GITHUB_CLIENT_ID}", + "clientSecret": "${GITHUB_CLIENT_SECRET}", + }, + Metadata: map[string]string{ + "ApiType": "Azure", + }, + }, + }, + Toolboxes: []Toolbox{ + { + Name: "platform-tools", + Tools: []map[string]any{ + { + "type": "mcp", + "project_connection_id": "github-mcp", + "server_url": "https://api.githubcopilot.com/mcp", + }, + }, + }, + }, + } + + s, err := MarshalStruct(&original) + if err != nil { + t.Fatalf("MarshalStruct failed: %v", err) + } + + var roundTripped ServiceTargetAgentConfig + if err := UnmarshalStruct(s, &roundTripped); err != nil { + t.Fatalf("UnmarshalStruct failed: %v", err) + } + + if len(roundTripped.ToolConnections) != 1 { + t.Fatalf("Expected 1 tool connection, got %d", len(roundTripped.ToolConnections)) + } + + conn := roundTripped.ToolConnections[0] + if conn.Name != "github-mcp" { + t.Errorf("Expected connection name 'github-mcp', got '%s'", conn.Name) + } + if conn.Category != "RemoteTool" { + t.Errorf("Expected category 'RemoteTool', got '%s'", conn.Category) + } + if conn.Target != "https://api.githubcopilot.com/mcp" { + t.Errorf("Expected target 'https://api.githubcopilot.com/mcp', got '%s'", conn.Target) + } + if conn.AuthType != "OAuth2" { + t.Errorf("Expected authType 'OAuth2', got '%s'", conn.AuthType) + } + if conn.Credentials["clientId"] != "${GITHUB_CLIENT_ID}" { + t.Errorf("Expected clientId '${GITHUB_CLIENT_ID}', got '%v'", conn.Credentials["clientId"]) + } + if conn.Metadata["ApiType"] != "Azure" { + t.Errorf("Expected metadata ApiType 'Azure', got '%s'", conn.Metadata["ApiType"]) + } + + // Verify toolbox is also preserved + if len(roundTripped.Toolboxes) != 1 { + t.Fatalf("Expected 1 toolbox, got %d", len(roundTripped.Toolboxes)) + } + if roundTripped.Toolboxes[0].Tools[0]["project_connection_id"] != "github-mcp" { + t.Errorf("Expected project_connection_id 'github-mcp', got '%v'", + roundTripped.Toolboxes[0].Tools[0]["project_connection_id"]) + } +} + +// TestServiceTargetAgentConfig_EmptyToolConnections tests that an empty ToolConnections +// slice is omitted and doesn't break round-trip. +func TestServiceTargetAgentConfig_EmptyToolConnections(t *testing.T) { + original := ServiceTargetAgentConfig{ + ToolConnections: []ToolConnection{}, + } + + s, err := MarshalStruct(&original) + if err != nil { + t.Fatalf("MarshalStruct failed: %v", err) + } + + var roundTripped ServiceTargetAgentConfig + if err := UnmarshalStruct(s, &roundTripped); err != nil { + t.Fatalf("UnmarshalStruct failed: %v", err) + } + + if len(roundTripped.ToolConnections) != 0 { + t.Errorf("Expected 0 tool connections, got %d", len(roundTripped.ToolConnections)) + } +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/project/service_target_agent.go b/cli/azd/extensions/azure.ai.agents/internal/project/service_target_agent.go index 4b6f88a78fd..f2bc0682c60 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/project/service_target_agent.go +++ b/cli/azd/extensions/azure.ai.agents/internal/project/service_target_agent.go @@ -6,10 +6,8 @@ package project import ( "context" "encoding/base64" - "errors" "fmt" "math" - "net/http" "os" "path/filepath" "strconv" @@ -21,7 +19,6 @@ import ( "azureaiagent/internal/pkg/agents/agent_yaml" "azureaiagent/internal/pkg/azure" - "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/cognitiveservices/armcognitiveservices/v2" @@ -654,6 +651,21 @@ func (p *AgentServiceTargetProvider) deployHostedAgent( memory = foundryAgentConfig.Container.Resources.Memory } + // Auto-inject toolbox MCP endpoint env vars so hosted agents can reach their toolboxes + // without requiring users to manually add them to agent.yaml's environment_variables. + if foundryAgentConfig != nil { + projectEndpoint := strings.TrimRight(azdEnv["AZURE_AI_PROJECT_ENDPOINT"], "/") + for _, toolbox := range foundryAgentConfig.Toolboxes { + toolboxKey := p.getServiceKey(toolbox.Name) + envKey := fmt.Sprintf("FOUNDRY_TOOLBOX_%s_MCP_ENDPOINT", toolboxKey) + if _, exists := resolvedEnvVars[envKey]; !exists { + resolvedEnvVars[envKey] = fmt.Sprintf( + "%s/toolsets/%s/mcp", projectEndpoint, toolbox.Name, + ) + } + } + } + // Build options list starting with required options options := []agent_yaml.AgentBuildOption{ agent_yaml.WithImageURL(fullImageURL), @@ -1078,110 +1090,6 @@ func (p *AgentServiceTargetProvider) resolveEnvironmentVariables(value string, a return resolved } -// deployToolboxes creates or updates Foundry Toolsets for each toolbox in the config. -// For each toolbox, it calls the Foundry Toolsets API to upsert the toolset, then -// sets environment variables with the MCP endpoints from the toolset's MCP tools. -func (p *AgentServiceTargetProvider) deployToolboxes( - ctx context.Context, - serviceTargetConfig *ServiceTargetAgentConfig, - azdEnv map[string]string, -) error { - if serviceTargetConfig == nil || len(serviceTargetConfig.Toolboxes) == 0 { - return nil - } - - projectEndpoint := azdEnv["AZURE_AI_PROJECT_ENDPOINT"] - if projectEndpoint == "" { - return exterrors.Dependency( - exterrors.CodeMissingAiProjectEndpoint, - "AZURE_AI_PROJECT_ENDPOINT is required for toolbox deployment", - "run 'azd provision' or connect to an existing project", - ) - } - - toolsetsClient := azure.NewFoundryToolsetsClient(projectEndpoint, p.credential) - - for _, toolbox := range serviceTargetConfig.Toolboxes { - fmt.Fprintf(os.Stderr, "Deploying toolbox: %s\n", toolbox.Name) - - _, err := p.upsertToolset(ctx, toolsetsClient, toolbox) - if err != nil { - return err - } - - // Set the MCP endpoint env var now that the toolbox is confirmed to exist - if err := p.registerToolboxEnvironmentVariables( - ctx, projectEndpoint, toolbox.Name, - ); err != nil { - return err - } - - fmt.Fprintf(os.Stderr, "Toolbox '%s' deployed successfully\n", toolbox.Name) - } - - return nil -} - -// upsertToolset creates a toolset, or updates it if it already exists. -// A 409 Conflict on create means the toolset already exists, which is treated as success. -func (p *AgentServiceTargetProvider) upsertToolset( - ctx context.Context, - client *azure.FoundryToolsetsClient, - toolbox Toolbox, -) (*azure.ToolsetObject, error) { - createReq := &azure.CreateToolsetRequest{ - Name: toolbox.Name, - Description: toolbox.Description, - Tools: toolbox.Tools, - } - - toolset, err := client.CreateToolset(ctx, createReq) - if err == nil { - return toolset, nil - } - - // 409 Conflict means the toolset already exists — treat as success - var respErr *azcore.ResponseError - if errors.As(err, &respErr) && respErr.StatusCode == http.StatusConflict { - fmt.Fprintf(os.Stderr, " Toolset '%s' already exists, skipping update\n", toolbox.Name) - return nil, nil - } - - return nil, exterrors.Internal( - exterrors.CodeCreateToolsetFailed, - fmt.Sprintf("failed to create toolset '%s': %s", toolbox.Name, err), - ) -} - -// registerToolboxEnvironmentVariables sets the FOUNDRY_TOOLBOX_{NAME}_MCP_ENDPOINT env var -// with the constructed MCP endpoint URL: {projectEndpoint}/toolsets/{toolboxName}/mcp -func (p *AgentServiceTargetProvider) registerToolboxEnvironmentVariables( - ctx context.Context, - projectEndpoint string, - toolboxName string, -) error { - toolboxKey := p.getServiceKey(toolboxName) - envKey := fmt.Sprintf("FOUNDRY_TOOLBOX_%s_MCP_ENDPOINT", toolboxKey) - - endpoint := strings.TrimRight(projectEndpoint, "/") - mcpEndpoint := fmt.Sprintf("%s/toolsets/%s/mcp", endpoint, toolboxName) - - _, err := p.azdClient.Environment().SetValue(ctx, &azdext.SetEnvRequest{ - EnvName: p.env.Name, - Key: envKey, - Value: mcpEndpoint, - }) - if err != nil { - return fmt.Errorf( - "failed to set environment variable %s: %w", - envKey, err, - ) - } - - fmt.Fprintf(os.Stderr, " Set %s=%s\n", envKey, mcpEndpoint) - return nil -} - // ensureFoundryProject ensures the Foundry project resource ID is parsed and stored. // Checks for AZURE_AI_PROJECT_ID environment variable. func (p *AgentServiceTargetProvider) ensureFoundryProject(ctx context.Context) error { diff --git a/cli/azd/extensions/azure.ai.agents/internal/project/service_target_agent_test.go b/cli/azd/extensions/azure.ai.agents/internal/project/service_target_agent_test.go index f23f15d7348..6c11e7ce59e 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/project/service_target_agent_test.go +++ b/cli/azd/extensions/azure.ai.agents/internal/project/service_target_agent_test.go @@ -108,3 +108,30 @@ func TestApplyVnextMetadata(t *testing.T) { }) } } + +func TestGetServiceKey_NormalizesToolboxNames(t *testing.T) { + t.Parallel() + + p := &AgentServiceTargetProvider{} + + tests := []struct { + name string + input string + expected string + }{ + {"hyphens", "agent-tools", "AGENT_TOOLS"}, + {"spaces", "agent tools", "AGENT_TOOLS"}, + {"mixed", "my-agent tools", "MY_AGENT_TOOLS"}, + {"already upper", "TOOLS", "TOOLS"}, + {"lowercase", "tools", "TOOLS"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := p.getServiceKey(tt.input) + if got != tt.expected { + t.Errorf("getServiceKey(%q) = %q, want %q", tt.input, got, tt.expected) + } + }) + } +} diff --git a/cli/azd/extensions/azure.ai.agents/schemas/azure.ai.agent.json b/cli/azd/extensions/azure.ai.agents/schemas/azure.ai.agent.json index 7365e9f8f1a..bf8b113ecda 100644 --- a/cli/azd/extensions/azure.ai.agents/schemas/azure.ai.agent.json +++ b/cli/azd/extensions/azure.ai.agents/schemas/azure.ai.agent.json @@ -24,6 +24,11 @@ "description": "List of external resources for agent execution.", "items": { "$ref": "#/definitions/Resource" } }, + "toolConnections": { + "type": "array", + "description": "List of tool connections to external services (MCP tools, A2A, custom APIs) created during provisioning.", + "items": { "$ref": "#/definitions/ToolConnection" } + }, "toolboxes": { "type": "array", "description": "List of toolboxes (Foundry Toolsets) to deploy.", @@ -130,6 +135,31 @@ "required": ["resource", "connectionName"], "additionalProperties": false }, + "ToolConnection": { + "type": "object", + "description": "A connection to an external service (MCP tool, A2A, custom API) created via Bicep during provisioning.", + "properties": { + "name": { "type": "string", "description": "Connection name used as project_connection_id in toolbox tools." }, + "category": { "type": "string", "description": "Connection category (e.g., 'RemoteTool')." }, + "target": { "type": "string", "description": "Target endpoint URL for the connection." }, + "authType": { + "type": "string", + "description": "Authentication type for the connection.", + "enum": ["AAD", "AccessKey", "AccountKey", "ApiKey", "CustomKeys", "ManagedIdentity", "None", "OAuth2", "PAT", "ServicePrincipal", "UsernamePassword"] + }, + "credentials": { + "type": "object", + "description": "Credentials for the connection. Values may contain ${ENV_VAR} references resolved at provision time." + }, + "metadata": { + "type": "object", + "description": "Additional metadata for the connection.", + "additionalProperties": { "type": "string" } + } + }, + "required": ["name", "category", "target", "authType"], + "additionalProperties": false + }, "Toolbox": { "type": "object", "description": "A reusable collection of tools deployed as a Foundry Toolset.", From c594debe3da0352ebddf99c18d1cd7c7d907abbf Mon Sep 17 00:00:00 2001 From: trangevi Date: Wed, 8 Apr 2026 08:43:28 -0700 Subject: [PATCH 04/14] Missing test file Signed-off-by: trangevi --- .../internal/cmd/listen_test.go | 85 +++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 cli/azd/extensions/azure.ai.agents/internal/cmd/listen_test.go diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/listen_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/listen_test.go new file mode 100644 index 00000000000..7861f076950 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/listen_test.go @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "testing" +) + +func TestNormalizeCredentials_CustomKeys_FlatToNested(t *testing.T) { + t.Parallel() + + // Old-format flat credentials should be wrapped under "keys" + creds := map[string]any{"key": "${FOUNDRY_TOOL_CONTEXT7_KEY}"} + result := normalizeCredentials("CustomKeys", creds) + + keysRaw, ok := result["keys"] + if !ok { + t.Fatal("Expected 'keys' wrapper in normalized credentials") + } + keys, ok := keysRaw.(map[string]any) + if !ok { + t.Fatalf("Expected keys to be map[string]any, got %T", keysRaw) + } + if keys["key"] != "${FOUNDRY_TOOL_CONTEXT7_KEY}" { + t.Errorf("Expected key value preserved, got %v", keys["key"]) + } +} + +func TestNormalizeCredentials_CustomKeys_AlreadyNested(t *testing.T) { + t.Parallel() + + // Already-correct nested credentials should be returned as-is + creds := map[string]any{ + "keys": map[string]any{"key": "${FOUNDRY_TOOL_CONTEXT7_KEY}"}, + } + result := normalizeCredentials("CustomKeys", creds) + + keysRaw, ok := result["keys"] + if !ok { + t.Fatal("Expected 'keys' wrapper preserved") + } + keys, ok := keysRaw.(map[string]any) + if !ok { + t.Fatalf("Expected keys to be map[string]any, got %T", keysRaw) + } + if keys["key"] != "${FOUNDRY_TOOL_CONTEXT7_KEY}" { + t.Errorf("Expected key value preserved, got %v", keys["key"]) + } + if len(result) != 1 { + t.Errorf("Expected only 'keys' in result, got %d entries", len(result)) + } +} + +func TestNormalizeCredentials_OAuth2_Unchanged(t *testing.T) { + t.Parallel() + + // Non-CustomKeys auth types should be returned unchanged + creds := map[string]any{ + "clientId": "${VAR_ID}", + "clientSecret": "${VAR_SECRET}", + } + result := normalizeCredentials("OAuth2", creds) + + if _, hasKeys := result["keys"]; hasKeys { + t.Error("OAuth2 credentials should not be wrapped in 'keys'") + } + if result["clientId"] != "${VAR_ID}" { + t.Errorf("Expected clientId preserved, got %v", result["clientId"]) + } +} + +func TestNormalizeCredentials_EmptyCredentials(t *testing.T) { + t.Parallel() + + result := normalizeCredentials("CustomKeys", nil) + if result != nil { + t.Errorf("Expected nil for nil input, got %v", result) + } + + result = normalizeCredentials("CustomKeys", map[string]any{}) + if len(result) != 0 { + t.Errorf("Expected empty map for empty input, got %v", result) + } +} From f4a1a122185435a70a7b42a91b3b6da05ae4acf3 Mon Sep 17 00:00:00 2001 From: trangevi Date: Wed, 8 Apr 2026 11:01:35 -0700 Subject: [PATCH 05/14] Some more fixes Signed-off-by: trangevi --- .../azure.ai.agents/internal/cmd/init.go | 9 ++ .../azure.ai.agents/internal/cmd/init_test.go | 90 ++++++++++++++++++ .../azure.ai.agents/internal/cmd/listen.go | 52 +++++++++++ .../internal/cmd/listen_test.go | 91 +++++++++++++++++++ .../internal/pkg/agents/agent_yaml/parse.go | 3 - .../pkg/agents/agent_yaml/parse_test.go | 91 ------------------- .../internal/pkg/agents/agent_yaml/yaml.go | 46 ---------- .../pkg/agents/agent_yaml/yaml_test.go | 71 --------------- 8 files changed, 242 insertions(+), 211 deletions(-) diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/init.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/init.go index 7f18812079f..78e41399981 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/init.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/init.go @@ -1806,6 +1806,15 @@ func extractConnectionConfigs(manifest *agent_yaml.AgentManifest) ([]project.Con Error: connResource.Error, } + // Surface credentials.type to top-level authType when not explicitly set. + // The API expects authType at the connection level, not nested in credentials. + if conn.AuthType == "" { + if credType, ok := conn.Credentials["type"].(string); ok && credType != "" { + conn.AuthType = credType + delete(conn.Credentials, "type") + } + } + connections = append(connections, conn) } diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_test.go index eacb211a10f..3daac47fb61 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_test.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_test.go @@ -850,3 +850,93 @@ func TestToolboxMCPEndpointEnvKey(t *testing.T) { }) } } + +func TestExtractConnectionConfigs_SurfacesCredentialsType(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + connResource agent_yaml.ConnectionResource + wantAuthType string + wantCredHasType bool + wantCredKeyCount int + }{ + { + name: "surfaces credentials.type to authType when authType is empty", + connResource: agent_yaml.ConnectionResource{ + Resource: agent_yaml.Resource{ + Name: "my-conn", + Kind: agent_yaml.ResourceKindConnection, + }, + Target: "https://example.com", + Credentials: map[string]any{ + "type": "CustomKeys", + "key": "secret-value", + }, + }, + wantAuthType: "CustomKeys", + wantCredHasType: false, + wantCredKeyCount: 1, + }, + { + name: "preserves explicit authType even if credentials.type differs", + connResource: agent_yaml.ConnectionResource{ + Resource: agent_yaml.Resource{ + Name: "my-conn", + Kind: agent_yaml.ResourceKindConnection, + }, + Target: "https://example.com", + AuthType: agent_yaml.AuthTypeAAD, + Credentials: map[string]any{ + "type": "CustomKeys", + "key": "val", + }, + }, + wantAuthType: string(agent_yaml.AuthTypeAAD), + wantCredHasType: true, + wantCredKeyCount: 2, + }, + { + name: "no credentials.type and no authType stays empty", + connResource: agent_yaml.ConnectionResource{ + Resource: agent_yaml.Resource{ + Name: "my-conn", + Kind: agent_yaml.ResourceKindConnection, + }, + Target: "https://example.com", + Credentials: map[string]any{"key": "val"}, + }, + wantAuthType: "", + wantCredHasType: false, + wantCredKeyCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + manifest := &agent_yaml.AgentManifest{ + Resources: []any{tt.connResource}, + } + conns, err := extractConnectionConfigs(manifest) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(conns) != 1 { + t.Fatalf("expected 1 connection, got %d", len(conns)) + } + conn := conns[0] + if conn.AuthType != tt.wantAuthType { + t.Errorf("AuthType = %q, want %q", conn.AuthType, tt.wantAuthType) + } + _, hasType := conn.Credentials["type"] + if hasType != tt.wantCredHasType { + t.Errorf("credentials has 'type' = %v, want %v", + hasType, tt.wantCredHasType) + } + if len(conn.Credentials) != tt.wantCredKeyCount { + t.Errorf("credentials key count = %d, want %d", + len(conn.Credentials), tt.wantCredKeyCount) + } + }) + } +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go index a4ab4e33022..991d8f8701f 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go @@ -523,6 +523,9 @@ func provisionToolboxes( return fmt.Errorf("failed to load environment variables: %w", err) } + // Build connection ID lookup from bicep outputs (name → ARM resource ID) + connIDMap := parseConnectionIDs(azdEnv["AI_PROJECT_CONNECTION_IDS_JSON"]) + // Build connection lookup for enriching tool entries with server_url/server_label connByName := map[string]project.ToolConnection{} for _, c := range config.ToolConnections { @@ -540,6 +543,9 @@ func provisionToolboxes( // Fill in server_url/server_label from connection data enrichToolboxFromConnections(&toolbox, connByName) + // Replace project_connection_id friendly names with ARM resource IDs + resolveToolboxConnectionIDs(&toolbox, connIDMap) + if err := upsertToolset( ctx, toolsetsClient, toolbox, ); err != nil { @@ -671,6 +677,52 @@ func enrichToolboxFromConnections( } } +// parseConnectionIDs parses the AI_PROJECT_CONNECTION_IDS_JSON env var +// (a JSON array of {name, id} objects) into a map of name → ARM resource ID. +func parseConnectionIDs(jsonStr string) map[string]string { + result := map[string]string{} + if jsonStr == "" { + return result + } + + var entries []struct { + Name string `json:"name"` + ID string `json:"id"` + } + if err := json.Unmarshal([]byte(jsonStr), &entries); err != nil { + fmt.Fprintf(os.Stderr, + "Warning: failed to parse AI_PROJECT_CONNECTION_IDS_JSON: %s\n", err) + return result + } + + for _, e := range entries { + if e.Name != "" && e.ID != "" { + result[e.Name] = e.ID + } + } + return result +} + +// resolveToolboxConnectionIDs replaces project_connection_id friendly names +// with their actual ARM resource IDs from bicep provisioning outputs. +func resolveToolboxConnectionIDs( + toolbox *project.Toolbox, + connIDs map[string]string, +) { + if len(connIDs) == 0 { + return + } + for i, tool := range toolbox.Tools { + connName, _ := tool["project_connection_id"].(string) + if connName == "" { + continue + } + if armID, ok := connIDs[connName]; ok { + toolbox.Tools[i]["project_connection_id"] = armID + } + } +} + // resolveEnvValue resolves ${VAR} references in a string using envsubst. func resolveEnvValue(value string, azdEnv map[string]string) string { resolved, err := envsubst.Eval(value, func(varName string) string { diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/listen_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/listen_test.go index 7861f076950..81aeffbf02b 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/listen_test.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/listen_test.go @@ -5,6 +5,8 @@ package cmd import ( "testing" + + "azureaiagent/internal/project" ) func TestNormalizeCredentials_CustomKeys_FlatToNested(t *testing.T) { @@ -83,3 +85,92 @@ func TestNormalizeCredentials_EmptyCredentials(t *testing.T) { t.Errorf("Expected empty map for empty input, got %v", result) } } + +func TestParseConnectionIDs(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + json string + expected map[string]string + }{ + { + name: "valid array", + json: `[{"name":"my-conn","id":"/subscriptions/123/resourceGroups/rg/providers/Microsoft.CognitiveServices/accounts/ai/projects/proj/connections/my-conn"}]`, + expected: map[string]string{"my-conn": "/subscriptions/123/resourceGroups/rg/providers/Microsoft.CognitiveServices/accounts/ai/projects/proj/connections/my-conn"}, + }, + { + name: "empty string", + json: "", + expected: map[string]string{}, + }, + { + name: "empty array", + json: "[]", + expected: map[string]string{}, + }, + { + name: "invalid JSON", + json: "not-json", + expected: map[string]string{}, + }, + { + name: "multiple connections", + json: `[{"name":"conn-a","id":"id-a"},{"name":"conn-b","id":"id-b"}]`, + expected: map[string]string{ + "conn-a": "id-a", + "conn-b": "id-b", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parseConnectionIDs(tt.json) + if len(result) != len(tt.expected) { + t.Fatalf("got %d entries, want %d", len(result), len(tt.expected)) + } + for k, v := range tt.expected { + if result[k] != v { + t.Errorf("key %q: got %q, want %q", k, result[k], v) + } + } + }) + } +} + +func TestResolveToolboxConnectionIDs(t *testing.T) { + t.Parallel() + + connIDs := map[string]string{ + "github_mcp_connection": "/subscriptions/123/connections/github_mcp_connection", + } + + toolbox := project.Toolbox{ + Name: "test", + Tools: []map[string]any{ + {"type": "web_search"}, + {"type": "mcp", "project_connection_id": "github_mcp_connection"}, + {"type": "mcp", "project_connection_id": "unknown_conn"}, + }, + } + + resolveToolboxConnectionIDs(&toolbox, connIDs) + + // Tool without project_connection_id: unchanged + if _, has := toolbox.Tools[0]["project_connection_id"]; has { + t.Error("tool 0 should not have project_connection_id") + } + + // Known connection: replaced with ARM ID + if toolbox.Tools[1]["project_connection_id"] != "/subscriptions/123/connections/github_mcp_connection" { + t.Errorf("tool 1 project_connection_id = %v, want ARM ID", + toolbox.Tools[1]["project_connection_id"]) + } + + // Unknown connection: left as-is + if toolbox.Tools[2]["project_connection_id"] != "unknown_conn" { + t.Errorf("tool 2 project_connection_id = %v, want 'unknown_conn'", + toolbox.Tools[2]["project_connection_id"]) + } +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse.go index 0a56af702f4..3048bed1f6e 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse.go +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse.go @@ -181,9 +181,6 @@ func ExtractResourceDefinitions(manifestYamlContent []byte) ([]any, error) { if err := yaml.Unmarshal(resourceBytes, &connDef); err != nil { return nil, fmt.Errorf("failed to unmarshal to ConnectionResource: %w", err) } - if err := ValidateConnectionResource(&connDef); err != nil { - return nil, err - } resourceDefs = append(resourceDefs, connDef) default: return nil, fmt.Errorf("unrecognized resource kind: %s", resourceDef.Kind) diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse_test.go index a07db730013..e666a1e83d3 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse_test.go +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse_test.go @@ -375,97 +375,6 @@ resources: } } -// TestExtractResourceDefinitions_ConnectionMissingRequired tests that missing required fields fail validation -func TestExtractResourceDefinitions_ConnectionMissingRequired(t *testing.T) { - testCases := []struct { - name string - yaml string - errMatch string - }{ - { - name: "missing category", - yaml: ` -name: test-manifest -template: - kind: prompt - name: test-agent - model: - id: gpt-4.1-mini -resources: - - kind: connection - name: my-conn - target: https://example.com - authType: None -`, - errMatch: "category is required", - }, - { - name: "missing target", - yaml: ` -name: test-manifest -template: - kind: prompt - name: test-agent - model: - id: gpt-4.1-mini -resources: - - kind: connection - name: my-conn - category: CustomKeys - authType: None -`, - errMatch: "target is required", - }, - { - name: "missing authType", - yaml: ` -name: test-manifest -template: - kind: prompt - name: test-agent - model: - id: gpt-4.1-mini -resources: - - kind: connection - name: my-conn - category: CustomKeys - target: https://example.com -`, - errMatch: "authType is required", - }, - { - name: "invalid authType", - yaml: ` -name: test-manifest -template: - kind: prompt - name: test-agent - model: - id: gpt-4.1-mini -resources: - - kind: connection - name: my-conn - category: CustomKeys - target: https://example.com - authType: InvalidAuth -`, - errMatch: "authType must be one of", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - _, err := ExtractResourceDefinitions([]byte(tc.yaml)) - if err == nil { - t.Fatal("Expected validation error, got nil") - } - if !strings.Contains(err.Error(), tc.errMatch) { - t.Errorf("Expected error to contain '%s', got '%s'", tc.errMatch, err.Error()) - } - }) - } -} - // TestExtractResourceDefinitions_AllResourceKinds tests model + toolbox + connection together func TestExtractResourceDefinitions_AllResourceKinds(t *testing.T) { yamlContent := []byte(` diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml.go index a113d7649ca..13038ee5ccd 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml.go +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml.go @@ -69,23 +69,6 @@ const ( AuthTypePAT AuthType = "PAT" ) -// ValidAuthTypes returns a slice of all supported AuthType values. -func ValidAuthTypes() []AuthType { - return []AuthType{ - AuthTypeAAD, - AuthTypeApiKey, - AuthTypeCustomKeys, - AuthTypeNone, - AuthTypeOAuth2, - AuthTypePAT, - } -} - -// IsValidAuthType checks if the provided AuthType is valid. -func IsValidAuthType(authType AuthType) bool { - return slices.Contains(ValidAuthTypes(), authType) -} - // CategoryKind represents the category of a connection resource. type CategoryKind string @@ -630,35 +613,6 @@ type ConnectionResource struct { UseWorkspaceManagedIdentity *bool `json:"useWorkspaceManagedIdentity,omitempty" yaml:"useWorkspaceManagedIdentity,omitempty"` //nolint:lll } -// ValidateConnectionResource checks that required fields are present and valid. -func ValidateConnectionResource(c *ConnectionResource) error { - if c.Name == "" { - return fmt.Errorf("connection resource is missing required 'name'") - } - if c.Category == "" { - return fmt.Errorf( - "connection resource '%s': category is required", c.Name, - ) - } - if c.Target == "" { - return fmt.Errorf( - "connection resource '%s': target is required", c.Name, - ) - } - if c.AuthType == "" { - return fmt.Errorf( - "connection resource '%s': authType is required", c.Name, - ) - } - if !IsValidAuthType(c.AuthType) { - return fmt.Errorf( - "connection resource '%s': authType must be one of %v, got '%s'", - c.Name, ValidAuthTypes(), c.AuthType, - ) - } - return nil -} - // Template Template model for defining prompt templates. // This model specifies the rendering engine used for slot filling prompts, // the parser used to process the rendered template into API-compatible format, diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml_test.go index 85256a54d0d..398cceb368e 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml_test.go +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml_test.go @@ -24,77 +24,6 @@ func TestArrayProperty_BasicSerialization(t *testing.T) { } } -// TestIsValidAuthType tests auth type validation -func TestIsValidAuthType(t *testing.T) { - validTypes := []AuthType{AuthTypeAAD, AuthTypeApiKey, AuthTypeCustomKeys, AuthTypeNone, AuthTypeOAuth2, AuthTypePAT} - for _, at := range validTypes { - if !IsValidAuthType(at) { - t.Errorf("IsValidAuthType(%s) should be true", at) - } - } - - if IsValidAuthType("InvalidType") { - t.Error("IsValidAuthType(InvalidType) should be false") - } -} - -// TestValidateConnectionResource tests connection resource validation -func TestValidateConnectionResource(t *testing.T) { - tests := []struct { - name string - conn ConnectionResource - wantErr bool - }{ - { - name: "valid connection", - conn: ConnectionResource{ - Resource: Resource{Name: "test", Kind: ResourceKindConnection}, - Category: CategoryCustomKeys, - Target: "https://example.com", - AuthType: AuthTypeCustomKeys, - }, - wantErr: false, - }, - { - name: "missing name", - conn: ConnectionResource{ - Category: CategoryCustomKeys, - Target: "https://example.com", - AuthType: AuthTypeCustomKeys, - }, - wantErr: true, - }, - { - name: "missing category", - conn: ConnectionResource{ - Resource: Resource{Name: "test", Kind: ResourceKindConnection}, - Target: "https://example.com", - AuthType: AuthTypeCustomKeys, - }, - wantErr: true, - }, - { - name: "invalid auth type", - conn: ConnectionResource{ - Resource: Resource{Name: "test", Kind: ResourceKindConnection}, - Category: CategoryCustomKeys, - Target: "https://example.com", - AuthType: "BadAuth", - }, - wantErr: true, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - err := ValidateConnectionResource(&tc.conn) - if (err != nil) != tc.wantErr { - t.Errorf("ValidateConnectionResource() error = %v, wantErr %v", err, tc.wantErr) - } - }) - } -} - // TestConnectionResourceSerialization tests JSON round-trip for ConnectionResource func TestConnectionResourceSerialization(t *testing.T) { conn := ConnectionResource{ From b238af1313f0fc9d299a7e1caec734f84fee8f5e Mon Sep 17 00:00:00 2001 From: trangevi Date: Wed, 8 Apr 2026 11:47:13 -0700 Subject: [PATCH 06/14] Fix merge Signed-off-by: trangevi --- .../init_foundry_resources_helpers_test.go | 43 ------------------- .../azure.ai.agents/internal/cmd/listen.go | 26 +---------- 2 files changed, 1 insertion(+), 68 deletions(-) diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_foundry_resources_helpers_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_foundry_resources_helpers_test.go index 83f8dcc5f78..5a1153b014d 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_foundry_resources_helpers_test.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_foundry_resources_helpers_test.go @@ -97,49 +97,6 @@ func TestExtractProjectDetails(t *testing.T) { } } -func TestNormalizeLoginServer(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - input string - want string - }{ - { - name: "plain hostname", - input: "myregistry.azurecr.io", - want: "myregistry.azurecr.io", - }, - { - name: "https prefix", - input: "https://myregistry.azurecr.io", - want: "myregistry.azurecr.io", - }, - { - name: "http prefix", - input: "http://myregistry.azurecr.io", - want: "myregistry.azurecr.io", - }, - { - name: "https with trailing slash", - input: "https://myregistry.azurecr.io/", - want: "myregistry.azurecr.io", - }, - { - name: "empty string", - input: "", - want: "", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - require.Equal(t, tt.want, normalizeLoginServer(tt.input)) - }) - } -} - func TestFoundryProjectInfoResourceIdConstruction(t *testing.T) { t.Parallel() diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go index 2fdf1c48f64..bc3ad2ce282 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go @@ -75,11 +75,7 @@ func newListenCommand() *cobra.Command { } } -func preprovisionHandler(ctx context.Context, azdClient *azdext.AzdClient, projectParser *project.FoundryParser, args *azdext.ProjectEventArgs) error { - if err := projectParser.SetIdentity(ctx, args); err != nil { - return fmt.Errorf("failed to set identity: %w", err) - } - +func preprovisionHandler(ctx context.Context, azdClient *azdext.AzdClient, args *azdext.ProjectEventArgs) error { for _, svc := range args.Project.Services { switch svc.Host { case AiAgentHost: @@ -116,26 +112,6 @@ func postprovisionHandler( return nil } -func predeployHandler(ctx context.Context, azdClient *azdext.AzdClient, projectParser *project.FoundryParser, args *azdext.ProjectEventArgs) error { - if err := projectParser.SetIdentity(ctx, args); err != nil { - return fmt.Errorf("failed to set identity: %w", err) - } - - for _, svc := range args.Project.Services { - switch svc.Host { - case AiAgentHost: - if err := populateContainerSettings(ctx, azdClient, svc); err != nil { - return fmt.Errorf("failed to populate container settings for service %q: %w", svc.Name, err) - } - if err := envUpdate(ctx, azdClient, args.Project, svc); err != nil { - return fmt.Errorf("failed to update environment for service %q: %w", svc.Name, err) - } - } - } - - return nil -} - func predeployHandler(ctx context.Context, azdClient *azdext.AzdClient, args *azdext.ProjectEventArgs) error { for _, svc := range args.Project.Services { switch svc.Host { From 224c27ccdb5ecb6039eac15b58545141a6fac695 Mon Sep 17 00:00:00 2001 From: trangevi Date: Wed, 8 Apr 2026 15:28:10 -0700 Subject: [PATCH 07/14] More fixes Signed-off-by: trangevi --- .../azure.ai.agents/internal/cmd/init.go | 6 +- .../internal/cmd/init_from_code.go | 8 +- .../internal/cmd/init_from_code_test.go | 24 ++-- .../azure.ai.agents/internal/cmd/init_test.go | 40 +++--- .../azure.ai.agents/internal/cmd/listen.go | 60 ++++++++- .../internal/cmd/listen_test.go | 114 +++++++++++++++++- .../internal/pkg/agents/agent_yaml/map.go | 2 +- .../internal/project/service_target_agent.go | 2 +- 8 files changed, 205 insertions(+), 51 deletions(-) diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/init.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/init.go index f6c9dc87e3d..b0a99a75c84 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/init.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/init.go @@ -1798,13 +1798,13 @@ func extractToolboxAndConnectionConfigs( } // credentialEnvVarName builds a deterministic env var name for a connection -// credential key, e.g. ("github-copilot", "clientSecret") → "FOUNDRY_TOOL_GITHUB_COPILOT_CLIENTSECRET". +// credential key, e.g. ("github-copilot", "clientSecret") → "TOOL_GITHUB_COPILOT_CLIENTSECRET". func credentialEnvVarName(connName, key string) string { - s := "FOUNDRY_TOOL_" + strings.ToUpper(connName) + "_" + strings.ToUpper(key) + s := "TOOL_" + strings.ToUpper(connName) + "_" + strings.ToUpper(key) return strings.ReplaceAll(s, "-", "_") } -// injectToolboxEnvVarsIntoDefinition adds FOUNDRY_TOOLBOX_{NAME}_MCP_ENDPOINT entries +// injectToolboxEnvVarsIntoDefinition adds TOOLBOX_{NAME}_MCP_ENDPOINT entries // to the environment_variables section of a hosted agent definition for each toolbox // resource in the manifest. Entries already present in the definition are not overwritten. func injectToolboxEnvVarsIntoDefinition(manifest *agent_yaml.AgentManifest) { diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_from_code.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_from_code.go index 166f02b7601..f7df342ca52 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_from_code.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_from_code.go @@ -848,13 +848,13 @@ type protocolInfo struct { // knownProtocols lists the protocols offered during init, in display order. var knownProtocols = []protocolInfo{ - {Name: "responses", Version: "v1"}, - {Name: "invocations", Version: "v0.0.1"}, + {Name: "responses", Version: "1.0.0"}, + {Name: "invocations", Version: "1.0.0"}, } // promptProtocols asks the user which protocols their agent supports. // When flagProtocols is non-empty the prompt is skipped and those values are used directly. -// When noPrompt is true and no flag values are provided, defaults to [responses/v1]. +// When noPrompt is true and no flag values are provided, defaults to [responses/1.0.0]. func promptProtocols( ctx context.Context, promptClient azdext.PromptServiceClient, @@ -897,7 +897,7 @@ func promptProtocols( // Non-interactive mode: default to responses. if noPrompt { return []agent_yaml.ProtocolVersionRecord{ - {Protocol: "responses", Version: "v1"}, + {Protocol: "responses", Version: "1.0.0"}, }, nil } diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_from_code_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_from_code_test.go index 968eabdefa7..e32bda5bddc 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_from_code_test.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_from_code_test.go @@ -412,7 +412,7 @@ func TestWriteDefinitionToSrcDir(t *testing.T) { Kind: agent_yaml.AgentKindHosted, }, Protocols: []agent_yaml.ProtocolVersionRecord{ - {Protocol: "responses", Version: "v1"}, + {Protocol: "responses", Version: "1.0.0"}, }, EnvironmentVariables: &[]agent_yaml.EnvironmentVariable{ {Name: "AZURE_OPENAI_ENDPOINT", Value: "${AZURE_OPENAI_ENDPOINT}"}, @@ -601,22 +601,22 @@ func TestPromptProtocols_FlagValues(t *testing.T) { name: "responses only", flagProtocols: []string{"responses"}, wantProtocols: []agent_yaml.ProtocolVersionRecord{ - {Protocol: "responses", Version: "v1"}, + {Protocol: "responses", Version: "1.0.0"}, }, }, { name: "invocations only", flagProtocols: []string{"invocations"}, wantProtocols: []agent_yaml.ProtocolVersionRecord{ - {Protocol: "invocations", Version: "v0.0.1"}, + {Protocol: "invocations", Version: "1.0.0"}, }, }, { name: "both protocols", flagProtocols: []string{"responses", "invocations"}, wantProtocols: []agent_yaml.ProtocolVersionRecord{ - {Protocol: "responses", Version: "v1"}, - {Protocol: "invocations", Version: "v0.0.1"}, + {Protocol: "responses", Version: "1.0.0"}, + {Protocol: "invocations", Version: "1.0.0"}, }, }, { @@ -629,8 +629,8 @@ func TestPromptProtocols_FlagValues(t *testing.T) { name: "duplicates are removed", flagProtocols: []string{"responses", "responses", "invocations"}, wantProtocols: []agent_yaml.ProtocolVersionRecord{ - {Protocol: "responses", Version: "v1"}, - {Protocol: "invocations", Version: "v0.0.1"}, + {Protocol: "responses", Version: "1.0.0"}, + {Protocol: "invocations", Version: "1.0.0"}, }, }, } @@ -680,8 +680,8 @@ func TestPromptProtocols_NoPromptDefault(t *testing.T) { if got[0].Protocol != "responses" { t.Errorf("protocol = %q, want %q", got[0].Protocol, "responses") } - if got[0].Version != "v1" { - t.Errorf("version = %q, want %q", got[0].Version, "v1") + if got[0].Version != "1.0.0" { + t.Errorf("version = %q, want %q", got[0].Version, "1.0.0") } } @@ -736,8 +736,8 @@ func TestPromptProtocols_Interactive(t *testing.T) { }, nil }, wantProtocols: []agent_yaml.ProtocolVersionRecord{ - {Protocol: "responses", Version: "v1"}, - {Protocol: "invocations", Version: "v0.0.1"}, + {Protocol: "responses", Version: "1.0.0"}, + {Protocol: "invocations", Version: "1.0.0"}, }, }, { @@ -751,7 +751,7 @@ func TestPromptProtocols_Interactive(t *testing.T) { }, nil }, wantProtocols: []agent_yaml.ProtocolVersionRecord{ - {Protocol: "responses", Version: "v1"}, + {Protocol: "responses", Version: "1.0.0"}, }, }, { diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_test.go index 3daac47fb61..d7e723dc323 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_test.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_test.go @@ -453,21 +453,21 @@ func TestExtractToolboxAndConnectionConfigs_TypedTools(t *testing.T) { } // Credentials should be ${VAR} references, not raw values - if conn.Credentials["clientId"] != "${FOUNDRY_TOOL_GITHUB_COPILOT_CLIENTID}" { + if conn.Credentials["clientId"] != "${TOOL_GITHUB_COPILOT_CLIENTID}" { t.Errorf("Expected env var ref for clientId, got '%v'", conn.Credentials["clientId"]) } - if conn.Credentials["clientSecret"] != "${FOUNDRY_TOOL_GITHUB_COPILOT_CLIENTSECRET}" { + if conn.Credentials["clientSecret"] != "${TOOL_GITHUB_COPILOT_CLIENTSECRET}" { t.Errorf("Expected env var ref for clientSecret, got '%v'", conn.Credentials["clientSecret"]) } // Raw values should be in the credEnvVars map - if credEnvVars["FOUNDRY_TOOL_GITHUB_COPILOT_CLIENTID"] != "my-client-id" { + if credEnvVars["TOOL_GITHUB_COPILOT_CLIENTID"] != "my-client-id" { t.Errorf("Expected env var value 'my-client-id', got '%s'", - credEnvVars["FOUNDRY_TOOL_GITHUB_COPILOT_CLIENTID"]) + credEnvVars["TOOL_GITHUB_COPILOT_CLIENTID"]) } - if credEnvVars["FOUNDRY_TOOL_GITHUB_COPILOT_CLIENTSECRET"] != "my-secret" { + if credEnvVars["TOOL_GITHUB_COPILOT_CLIENTSECRET"] != "my-secret" { t.Errorf("Expected env var value 'my-secret', got '%s'", - credEnvVars["FOUNDRY_TOOL_GITHUB_COPILOT_CLIENTSECRET"]) + credEnvVars["TOOL_GITHUB_COPILOT_CLIENTSECRET"]) } // Verify toolbox has both tools @@ -620,7 +620,7 @@ func TestExtractToolboxAndConnectionConfigs_CustomKeysCredentials(t *testing.T) if !ok { t.Fatalf("Expected 'keys' to be map[string]any, got %T", keysRaw) } - if keys["key"] != "${FOUNDRY_TOOL_CUSTOM_API_KEY}" { + if keys["key"] != "${TOOL_CUSTOM_API_KEY}" { t.Errorf("Expected env var ref for key, got '%v'", keys["key"]) } @@ -629,7 +629,7 @@ func TestExtractToolboxAndConnectionConfigs_CustomKeysCredentials(t *testing.T) if _, hasKeys := oauthConn.Credentials["keys"]; hasKeys { t.Error("OAuth2 connection should not have 'keys' wrapper") } - if oauthConn.Credentials["clientId"] != "${FOUNDRY_TOOL_OAUTH_TOOL_CLIENTID}" { + if oauthConn.Credentials["clientId"] != "${TOOL_OAUTH_TOOL_CLIENTID}" { t.Errorf("Expected flat clientId ref, got '%v'", oauthConn.Credentials["clientId"]) } } @@ -644,7 +644,7 @@ func TestInjectToolboxEnvVarsIntoDefinition_AddsEnvVars(t *testing.T) { Name: "my-agent", }, Protocols: []agent_yaml.ProtocolVersionRecord{ - {Protocol: "responses", Version: "v1"}, + {Protocol: "responses", Version: "1.0.0"}, }, EnvironmentVariables: &[]agent_yaml.EnvironmentVariable{ {Name: "AZURE_OPENAI_ENDPOINT", Value: "${AZURE_OPENAI_ENDPOINT}"}, @@ -678,10 +678,10 @@ func TestInjectToolboxEnvVarsIntoDefinition_AddsEnvVars(t *testing.T) { } // Toolbox env var is injected - if envVars[1].Name != "FOUNDRY_TOOLBOX_AGENT_TOOLS_MCP_ENDPOINT" { + if envVars[1].Name != "TOOLBOX_AGENT_TOOLS_MCP_ENDPOINT" { t.Errorf("Expected injected env var name, got %s", envVars[1].Name) } - if envVars[1].Value != "${FOUNDRY_TOOLBOX_AGENT_TOOLS_MCP_ENDPOINT}" { + if envVars[1].Value != "${TOOLBOX_AGENT_TOOLS_MCP_ENDPOINT}" { t.Errorf("Expected env var reference value, got %s", envVars[1].Value) } } @@ -696,10 +696,10 @@ func TestInjectToolboxEnvVarsIntoDefinition_SkipsExisting(t *testing.T) { Name: "my-agent", }, Protocols: []agent_yaml.ProtocolVersionRecord{ - {Protocol: "responses", Version: "v1"}, + {Protocol: "responses", Version: "1.0.0"}, }, EnvironmentVariables: &[]agent_yaml.EnvironmentVariable{ - {Name: "FOUNDRY_TOOLBOX_MY_TOOLS_MCP_ENDPOINT", Value: "custom-value"}, + {Name: "TOOLBOX_MY_TOOLS_MCP_ENDPOINT", Value: "custom-value"}, }, }, Resources: []any{ @@ -739,7 +739,7 @@ func TestInjectToolboxEnvVarsIntoDefinition_MultipleToolboxes(t *testing.T) { Name: "my-agent", }, Protocols: []agent_yaml.ProtocolVersionRecord{ - {Protocol: "responses", Version: "v1"}, + {Protocol: "responses", Version: "1.0.0"}, }, }, Resources: []any{ @@ -762,10 +762,10 @@ func TestInjectToolboxEnvVarsIntoDefinition_MultipleToolboxes(t *testing.T) { if len(envVars) != 2 { t.Fatalf("Expected 2 env vars, got %d", len(envVars)) } - if envVars[0].Name != "FOUNDRY_TOOLBOX_SEARCH_TOOLS_MCP_ENDPOINT" { + if envVars[0].Name != "TOOLBOX_SEARCH_TOOLS_MCP_ENDPOINT" { t.Errorf("Expected first toolbox env var, got %s", envVars[0].Name) } - if envVars[1].Name != "FOUNDRY_TOOLBOX_GITHUB_TOOLS_MCP_ENDPOINT" { + if envVars[1].Name != "TOOLBOX_GITHUB_TOOLS_MCP_ENDPOINT" { t.Errorf("Expected second toolbox env var, got %s", envVars[1].Name) } } @@ -835,10 +835,10 @@ func TestToolboxMCPEndpointEnvKey(t *testing.T) { input string expected string }{ - {"simple", "my-tools", "FOUNDRY_TOOLBOX_MY_TOOLS_MCP_ENDPOINT"}, - {"spaces", "my tools", "FOUNDRY_TOOLBOX_MY_TOOLS_MCP_ENDPOINT"}, - {"mixed", "agent-tools v2", "FOUNDRY_TOOLBOX_AGENT_TOOLS_V2_MCP_ENDPOINT"}, - {"already upper", "TOOLS", "FOUNDRY_TOOLBOX_TOOLS_MCP_ENDPOINT"}, + {"simple", "my-tools", "TOOLBOX_MY_TOOLS_MCP_ENDPOINT"}, + {"spaces", "my tools", "TOOLBOX_MY_TOOLS_MCP_ENDPOINT"}, + {"mixed", "agent-tools v2", "TOOLBOX_AGENT_TOOLS_V2_MCP_ENDPOINT"}, + {"already upper", "TOOLS", "TOOLBOX_TOOLS_MCP_ENDPOINT"}, } for _, tt := range tests { diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go index bc3ad2ce282..ac3a584fb71 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go @@ -282,7 +282,49 @@ func connectionsEnvUpdate( azdClient *azdext.AzdClient, envName string, ) error { - return marshalAndSetEnvVar(ctx, azdClient, envName, "AI_PROJECT_CONNECTIONS", connections) + // Strip credentials from the connections env var — Bicep's ConnectionConfig + // type doesn't include credentials (they're a separate @secure param). + // Including them causes "unable to deserialize request body" errors. + stripped := make([]project.Connection, len(connections)) + copy(stripped, connections) + for i := range stripped { + stripped[i].Credentials = nil + } + + if err := marshalAndSetEnvVar(ctx, azdClient, envName, "AI_PROJECT_CONNECTIONS", stripped); err != nil { + return err + } + + return connectionCredentialsEnvUpdate(ctx, connections, azdClient, envName) +} + +// connectionCredentialsEnvUpdate builds a dictionary of connection name → credentials +// and serializes it to AI_PROJECT_CONNECTION_CREDENTIALS. +func connectionCredentialsEnvUpdate( + ctx context.Context, + connections []project.Connection, + azdClient *azdext.AzdClient, + envName string, +) error { + credMap := buildConnectionCredentials(connections) + if len(credMap) == 0 { + return nil + } + + return marshalAndSetEnvVar(ctx, azdClient, envName, "AI_PROJECT_CONNECTION_CREDENTIALS", credMap) +} + +// buildConnectionCredentials returns a map of connection name → credentials object +// for all connections that have non-empty credentials. +func buildConnectionCredentials(connections []project.Connection) map[string]map[string]any { + result := map[string]map[string]any{} + for _, conn := range connections { + if len(conn.Credentials) > 0 { + result[conn.Name] = conn.Credentials + } + } + + return result } // toolConnectionsEnvUpdate serializes tool connections to AI_PROJECT_TOOL_CONNECTIONS env var. @@ -596,7 +638,7 @@ func upsertToolset( ) } -// registerToolboxEnvVars sets FOUNDRY_TOOLBOX_{NAME}_MCP_ENDPOINT. +// registerToolboxEnvVars sets TOOLBOX_{NAME}_MCP_ENDPOINT. func registerToolboxEnvVars( ctx context.Context, azdClient *azdext.AzdClient, @@ -616,11 +658,11 @@ func registerToolboxEnvVars( ) } -// toolboxMCPEndpointEnvKey builds the FOUNDRY_TOOLBOX_{NAME}_MCP_ENDPOINT env var key. +// toolboxMCPEndpointEnvKey builds the TOOLBOX_{NAME}_MCP_ENDPOINT env var key. func toolboxMCPEndpointEnvKey(toolboxName string) string { key := strings.ReplaceAll(toolboxName, " ", "_") key = strings.ReplaceAll(key, "-", "_") - return fmt.Sprintf("FOUNDRY_TOOLBOX_%s_MCP_ENDPOINT", strings.ToUpper(key)) + return fmt.Sprintf("TOOLBOX_%s_MCP_ENDPOINT", strings.ToUpper(key)) } // resolveToolboxEnvVars resolves ${VAR} references in toolbox name, description, @@ -698,12 +740,22 @@ func resolveToolboxConnectionIDs( if connName == "" { continue } + connName = resolveTemplateRef(connName) if armID, ok := connIDs[connName]; ok { toolbox.Tools[i]["project_connection_id"] = armID } } } +// resolveTemplateRef strips {{ }} template wrapping and trims whitespace. +// "{{ my_conn }}" → "my_conn", "my_conn" → "my_conn" (unchanged). +func resolveTemplateRef(s string) string { + if strings.HasPrefix(s, "{{") && strings.HasSuffix(s, "}}") { + return strings.TrimSpace(s[2 : len(s)-2]) + } + return s +} + // resolveEnvValue resolves ${VAR} references in a string using envsubst. func resolveEnvValue(value string, azdEnv map[string]string) string { resolved, err := envsubst.Eval(value, func(varName string) string { diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/listen_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/listen_test.go index 81aeffbf02b..9131ce2b5bd 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/listen_test.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/listen_test.go @@ -13,7 +13,7 @@ func TestNormalizeCredentials_CustomKeys_FlatToNested(t *testing.T) { t.Parallel() // Old-format flat credentials should be wrapped under "keys" - creds := map[string]any{"key": "${FOUNDRY_TOOL_CONTEXT7_KEY}"} + creds := map[string]any{"key": "${TOOL_CONTEXT7_KEY}"} result := normalizeCredentials("CustomKeys", creds) keysRaw, ok := result["keys"] @@ -24,7 +24,7 @@ func TestNormalizeCredentials_CustomKeys_FlatToNested(t *testing.T) { if !ok { t.Fatalf("Expected keys to be map[string]any, got %T", keysRaw) } - if keys["key"] != "${FOUNDRY_TOOL_CONTEXT7_KEY}" { + if keys["key"] != "${TOOL_CONTEXT7_KEY}" { t.Errorf("Expected key value preserved, got %v", keys["key"]) } } @@ -34,7 +34,7 @@ func TestNormalizeCredentials_CustomKeys_AlreadyNested(t *testing.T) { // Already-correct nested credentials should be returned as-is creds := map[string]any{ - "keys": map[string]any{"key": "${FOUNDRY_TOOL_CONTEXT7_KEY}"}, + "keys": map[string]any{"key": "${TOOL_CONTEXT7_KEY}"}, } result := normalizeCredentials("CustomKeys", creds) @@ -46,7 +46,7 @@ func TestNormalizeCredentials_CustomKeys_AlreadyNested(t *testing.T) { if !ok { t.Fatalf("Expected keys to be map[string]any, got %T", keysRaw) } - if keys["key"] != "${FOUNDRY_TOOL_CONTEXT7_KEY}" { + if keys["key"] != "${TOOL_CONTEXT7_KEY}" { t.Errorf("Expected key value preserved, got %v", keys["key"]) } if len(result) != 1 { @@ -150,8 +150,9 @@ func TestResolveToolboxConnectionIDs(t *testing.T) { Name: "test", Tools: []map[string]any{ {"type": "web_search"}, - {"type": "mcp", "project_connection_id": "github_mcp_connection"}, + {"type": "mcp", "project_connection_id": "{{ github_mcp_connection }}"}, {"type": "mcp", "project_connection_id": "unknown_conn"}, + {"type": "mcp", "project_connection_id": "github_mcp_connection"}, }, } @@ -162,7 +163,7 @@ func TestResolveToolboxConnectionIDs(t *testing.T) { t.Error("tool 0 should not have project_connection_id") } - // Known connection: replaced with ARM ID + // Template ref {{ name }}: resolved to ARM ID if toolbox.Tools[1]["project_connection_id"] != "/subscriptions/123/connections/github_mcp_connection" { t.Errorf("tool 1 project_connection_id = %v, want ARM ID", toolbox.Tools[1]["project_connection_id"]) @@ -173,4 +174,105 @@ func TestResolveToolboxConnectionIDs(t *testing.T) { t.Errorf("tool 2 project_connection_id = %v, want 'unknown_conn'", toolbox.Tools[2]["project_connection_id"]) } + + // Bare name (no braces): also resolved + if toolbox.Tools[3]["project_connection_id"] != "/subscriptions/123/connections/github_mcp_connection" { + t.Errorf("tool 3 project_connection_id = %v, want ARM ID", + toolbox.Tools[3]["project_connection_id"]) + } +} + +func TestResolveTemplateRef(t *testing.T) { + t.Parallel() + + tests := []struct { + input string + want string + }{ + {"{{ my_conn }}", "my_conn"}, + {"{{my_conn}}", "my_conn"}, + {"{{ spaced }}", "spaced"}, + {"my_conn", "my_conn"}, + {"", ""}, + {"{not_template}", "{not_template}"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + t.Parallel() + if got := resolveTemplateRef(tt.input); got != tt.want { + t.Errorf("resolveTemplateRef(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestBuildConnectionCredentials(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + connections []project.Connection + wantKeys []string + wantEmpty bool + }{ + { + name: "empty connections", + wantEmpty: true, + }, + { + name: "connections with credentials", + connections: []project.Connection{ + { + Name: "my-openai", + Credentials: map[string]any{"key": "${OPENAI_API_KEY}"}, + }, + { + Name: "github-mcp", + Credentials: map[string]any{"pat": "${GITHUB_PAT}"}, + }, + }, + wantKeys: []string{"my-openai", "github-mcp"}, + }, + { + name: "skips connections without credentials", + connections: []project.Connection{ + { + Name: "no-creds", + Credentials: nil, + }, + { + Name: "has-creds", + Credentials: map[string]any{"secret": "val"}, + }, + }, + wantKeys: []string{"has-creds"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + result := buildConnectionCredentials(tt.connections) + + if tt.wantEmpty { + if len(result) != 0 { + t.Fatalf("expected empty map, got %v", result) + } + return + } + + if len(result) != len(tt.wantKeys) { + t.Fatalf("expected %d entries, got %d: %v", + len(tt.wantKeys), len(result), result) + } + + for _, key := range tt.wantKeys { + if _, ok := result[key]; !ok { + t.Errorf("expected key %q in result", key) + } + } + }) + } } diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/map.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/map.go index 0f62aede98d..b2e6b6ba578 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/map.go +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/map.go @@ -418,7 +418,7 @@ func CreateHostedAgentAPIRequest(hostedAgent ContainerAgent, buildConfig *AgentB } else { // Set default protocol versions if none specified protocolVersions = []agent_api.ProtocolVersionRecord{ - {Protocol: agent_api.AgentProtocolResponses, Version: "v1"}, + {Protocol: agent_api.AgentProtocolResponses, Version: "1.0.0"}, } } diff --git a/cli/azd/extensions/azure.ai.agents/internal/project/service_target_agent.go b/cli/azd/extensions/azure.ai.agents/internal/project/service_target_agent.go index f2bc0682c60..a218204710f 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/project/service_target_agent.go +++ b/cli/azd/extensions/azure.ai.agents/internal/project/service_target_agent.go @@ -657,7 +657,7 @@ func (p *AgentServiceTargetProvider) deployHostedAgent( projectEndpoint := strings.TrimRight(azdEnv["AZURE_AI_PROJECT_ENDPOINT"], "/") for _, toolbox := range foundryAgentConfig.Toolboxes { toolboxKey := p.getServiceKey(toolbox.Name) - envKey := fmt.Sprintf("FOUNDRY_TOOLBOX_%s_MCP_ENDPOINT", toolboxKey) + envKey := fmt.Sprintf("TOOLBOX_%s_MCP_ENDPOINT", toolboxKey) if _, exists := resolvedEnvVars[envKey]; !exists { resolvedEnvVars[envKey] = fmt.Sprintf( "%s/toolsets/%s/mcp", projectEndpoint, toolbox.Name, From 9154f4d939929e766aff24e27aa08e68986b803d Mon Sep 17 00:00:00 2001 From: trangevi Date: Wed, 8 Apr 2026 15:46:31 -0700 Subject: [PATCH 08/14] Add connection resource type for provisioning Signed-off-by: trangevi --- cli/azd/pkg/azapi/azure_resource_types.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cli/azd/pkg/azapi/azure_resource_types.go b/cli/azd/pkg/azapi/azure_resource_types.go index 9354854c0d0..57869e4ffab 100644 --- a/cli/azd/pkg/azapi/azure_resource_types.go +++ b/cli/azd/pkg/azapi/azure_resource_types.go @@ -57,6 +57,8 @@ const ( //nolint:lll AzureResourceTypeCognitiveServiceAccountProject AzureResourceType = "Microsoft.CognitiveServices/accounts/projects" //nolint:lll + AzureResourceTypeCognitiveServiceAccountProjectConnection AzureResourceType = "Microsoft.CognitiveServices/accounts/projects/connections" + //nolint:lll AzureResourceTypeCognitiveServiceAccountCapabilityHost AzureResourceType = "Microsoft.CognitiveServices/accounts/capabilityHosts" ) @@ -135,6 +137,8 @@ func GetResourceTypeDisplayName(resourceType AzureResourceType) string { return "Azure AI Services Model Deployment" case AzureResourceTypeCognitiveServiceAccountProject: return "Foundry project" + case AzureResourceTypeCognitiveServiceAccountProjectConnection: + return "Foundry project connection" case AzureResourceTypeCognitiveServiceAccountCapabilityHost: return "Foundry capability host" case AzureResourceTypeSearchService: From 3c7b319b05258ccec023b192caa9d11391512227 Mon Sep 17 00:00:00 2001 From: trangevi Date: Fri, 10 Apr 2026 09:13:26 -0700 Subject: [PATCH 09/14] Move to /toolboxes api Signed-off-by: trangevi --- .../azure.ai.agents/internal/cmd/listen.go | 56 ++----- .../internal/exterrors/codes.go | 12 +- .../pkg/azure/foundry_toolsets_client.go | 143 +++++++----------- 3 files changed, 76 insertions(+), 135 deletions(-) diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go index ac3a584fb71..e0d1a5bbe15 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go @@ -6,9 +6,7 @@ package cmd import ( "context" "encoding/json" - "errors" "fmt" - "net/http" "os" "path/filepath" "strconv" @@ -19,7 +17,6 @@ import ( "azureaiagent/internal/pkg/azure" "azureaiagent/internal/project" - "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/azure/azure-dev/cli/azd/pkg/azdext" "github.com/braydonk/yaml" @@ -536,7 +533,7 @@ func provisionToolboxes( ) } - toolsetsClient := azure.NewFoundryToolsetsClient( + toolboxClient := azure.NewFoundryToolboxClient( projectEndpoint, cred, ) @@ -569,8 +566,8 @@ func provisionToolboxes( // Replace project_connection_id friendly names with ARM resource IDs resolveToolboxConnectionIDs(&toolbox, connIDMap) - if err := upsertToolset( - ctx, toolsetsClient, toolbox, + if err := createToolboxVersion( + ctx, toolboxClient, toolbox, ); err != nil { return err } @@ -591,51 +588,26 @@ func provisionToolboxes( return nil } -// upsertToolset creates a toolset, or updates it if it already exists. -func upsertToolset( +// createToolboxVersion creates a new version of a toolbox. +// If the toolbox does not exist, it will be created automatically. +func createToolboxVersion( ctx context.Context, - client *azure.FoundryToolsetsClient, + client *azure.FoundryToolboxClient, toolbox project.Toolbox, ) error { - createReq := &azure.CreateToolsetRequest{ - Name: toolbox.Name, + req := &azure.CreateToolboxVersionRequest{ Description: toolbox.Description, Tools: toolbox.Tools, } - _, err := client.CreateToolset(ctx, createReq) - if err == nil { - return nil - } - - // 409 Conflict means the toolset already exists — update it - var respErr *azcore.ResponseError - if errors.As(err, &respErr) && respErr.StatusCode == http.StatusConflict { - fmt.Fprintf( - os.Stderr, - " Toolset '%s' already exists, updating...\n", - toolbox.Name, + if _, err := client.CreateToolboxVersion(ctx, toolbox.Name, req); err != nil { + return exterrors.Internal( + exterrors.CodeCreateToolboxVersionFailed, + fmt.Sprintf("failed to create toolbox version '%s': %s", toolbox.Name, err), ) - updateReq := &azure.UpdateToolsetRequest{ - Description: toolbox.Description, - Tools: toolbox.Tools, - } - if _, updateErr := client.UpdateToolset(ctx, toolbox.Name, updateReq); updateErr != nil { - return exterrors.Internal( - exterrors.CodeCreateToolsetFailed, - fmt.Sprintf("failed to update toolset '%s': %s", toolbox.Name, updateErr), - ) - } - return nil } - return exterrors.Internal( - exterrors.CodeCreateToolsetFailed, - fmt.Sprintf( - "failed to create toolset '%s': %s", - toolbox.Name, err, - ), - ) + return nil } // registerToolboxEnvVars sets TOOLBOX_{NAME}_MCP_ENDPOINT. @@ -650,7 +622,7 @@ func registerToolboxEnvVars( endpoint := strings.TrimRight(projectEndpoint, "/") mcpEndpoint := fmt.Sprintf( - "%s/toolsets/%s/mcp", endpoint, toolboxName, + "%s/toolboxes/%s/mcp?api-version=v1", endpoint, toolboxName, ) return setEnvVar( diff --git a/cli/azd/extensions/azure.ai.agents/internal/exterrors/codes.go b/cli/azd/extensions/azure.ai.agents/internal/exterrors/codes.go index a1c0fc73387..bfbb7101fe9 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/exterrors/codes.go +++ b/cli/azd/extensions/azure.ai.agents/internal/exterrors/codes.go @@ -88,11 +88,10 @@ const ( CodeInvalidFilePath = "invalid_file_path" ) -// Error codes for toolbox/toolset operations. +// Error codes for toolbox operations. const ( - CodeInvalidToolbox = "invalid_toolbox" - CodeCreateToolsetFailed = "create_toolset_failed" - CodeUpdateToolsetFailed = "update_toolset_failed" + CodeInvalidToolbox = "invalid_toolbox" + CodeCreateToolboxVersionFailed = "create_toolbox_version_failed" ) // Error codes for connection operations. @@ -123,7 +122,6 @@ const ( OpCreateAgent = "create_agent" OpStartContainer = "start_container" OpGetContainerOperation = "get_container_operation" - OpCreateToolset = "create_toolset" - OpUpdateToolset = "update_toolset" - OpGetToolset = "get_toolset" + OpCreateToolboxVersion = "create_toolbox_version" + OpGetToolbox = "get_toolbox" ) diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_toolsets_client.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_toolsets_client.go index 2d6b2d6e6b2..b18ea91356c 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_toolsets_client.go +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_toolsets_client.go @@ -21,26 +21,26 @@ import ( ) const ( - toolsetsApiVersion = "v1" - toolsetsFeatureHeader = "Toolsets=V1Preview" + toolboxesApiVersion = "v1" + toolboxesFeatureHeader = "Toolsets=V1Preview" ) -// FoundryToolsetsClient provides methods for interacting with the Foundry Toolsets API -type FoundryToolsetsClient struct { +// FoundryToolboxClient provides methods for interacting with the Foundry Toolboxes API. +type FoundryToolboxClient struct { endpoint string pipeline runtime.Pipeline } -// NewFoundryToolsetsClient creates a new FoundryToolsetsClient -func NewFoundryToolsetsClient( +// NewFoundryToolboxClient creates a new FoundryToolboxClient. +func NewFoundryToolboxClient( endpoint string, cred azcore.TokenCredential, -) *FoundryToolsetsClient { +) *FoundryToolboxClient { userAgent := fmt.Sprintf("azd-ext-azure-ai-agents/%s", version.Version) clientOptions := &policy.ClientOptions{ Logging: policy.LogOptions{ - AllowedHeaders: []string{azsdk.MsCorrelationIdHeader}, + AllowedHeaders: []string{azsdk.MsCorrelationIdHeader, "X-Request-Id"}, IncludeBody: true, }, PerCallPolicies: []policy.Policy{ @@ -57,54 +57,48 @@ func NewFoundryToolsetsClient( clientOptions, ) - return &FoundryToolsetsClient{ + return &FoundryToolboxClient{ endpoint: endpoint, pipeline: pipeline, } } -// CreateToolsetRequest is the request body for creating a toolset -type CreateToolsetRequest struct { - Name string `json:"name"` +// CreateToolboxVersionRequest is the request body for creating a new toolbox version. +// The toolbox name is provided in the URL path, not in the body. +type CreateToolboxVersionRequest struct { Description string `json:"description,omitempty"` Metadata map[string]string `json:"metadata,omitempty"` Tools []map[string]any `json:"tools"` } -// UpdateToolsetRequest is the request body for updating a toolset -type UpdateToolsetRequest struct { - Description string `json:"description,omitempty"` - Metadata map[string]string `json:"metadata,omitempty"` - Tools []map[string]any `json:"tools"` +// ToolboxObject is the lightweight response for a toolbox (no tools list). +type ToolboxObject struct { + Id string `json:"id"` + Name string `json:"name"` + DefaultVersion string `json:"default_version"` } -// ToolsetObject is the response object for a toolset -type ToolsetObject struct { - Object string `json:"object"` +// ToolboxVersionObject is the response for a specific toolbox version. +type ToolboxVersionObject struct { Id string `json:"id"` - CreatedAt int64 `json:"created_at"` - UpdatedAt int64 `json:"updated_at"` Name string `json:"name"` + Version string `json:"version"` Description string `json:"description,omitempty"` + CreatedAt int64 `json:"created_at"` Metadata map[string]string `json:"metadata,omitempty"` Tools []map[string]any `json:"tools"` } -// DeleteToolsetResponse is the response for deleting a toolset -type DeleteToolsetResponse struct { - Object string `json:"object"` - Name string `json:"name"` - Deleted bool `json:"deleted"` -} - -// CreateToolset creates a new toolset -func (c *FoundryToolsetsClient) CreateToolset( +// CreateToolboxVersion creates a new version of a toolbox. +// If the toolbox does not exist, it will be created automatically. +func (c *FoundryToolboxClient) CreateToolboxVersion( ctx context.Context, - request *CreateToolsetRequest, -) (*ToolsetObject, error) { + toolboxName string, + request *CreateToolboxVersionRequest, +) (*ToolboxVersionObject, error) { targetUrl := fmt.Sprintf( - "%s/toolsets?api-version=%s", - c.endpoint, toolsetsApiVersion, + "%s/toolboxes/%s/versions?api-version=%s", + c.endpoint, url.PathEscape(toolboxName), toolboxesApiVersion, ) payload, err := json.Marshal(request) @@ -117,7 +111,7 @@ func (c *FoundryToolsetsClient) CreateToolset( return nil, fmt.Errorf("failed to create request: %w", err) } - req.Raw().Header.Set("Foundry-Features", toolsetsFeatureHeader) + req.Raw().Header.Set("Foundry-Features", toolboxesFeatureHeader) if err := req.SetBody( streaming.NopCloser(bytes.NewReader(payload)), @@ -141,43 +135,30 @@ func (c *FoundryToolsetsClient) CreateToolset( return nil, fmt.Errorf("failed to read response body: %w", err) } - var toolset ToolsetObject - if err := json.Unmarshal(body, &toolset); err != nil { + var result ToolboxVersionObject + if err := json.Unmarshal(body, &result); err != nil { return nil, fmt.Errorf("failed to parse response: %w", err) } - return &toolset, nil + return &result, nil } -// UpdateToolset updates an existing toolset -func (c *FoundryToolsetsClient) UpdateToolset( +// GetToolbox retrieves a toolbox by name. +func (c *FoundryToolboxClient) GetToolbox( ctx context.Context, - toolsetName string, - request *UpdateToolsetRequest, -) (*ToolsetObject, error) { + toolboxName string, +) (*ToolboxObject, error) { targetUrl := fmt.Sprintf( - "%s/toolsets/%s?api-version=%s", - c.endpoint, url.PathEscape(toolsetName), toolsetsApiVersion, + "%s/toolboxes/%s?api-version=%s", + c.endpoint, url.PathEscape(toolboxName), toolboxesApiVersion, ) - payload, err := json.Marshal(request) - if err != nil { - return nil, fmt.Errorf("failed to marshal request: %w", err) - } - - req, err := runtime.NewRequest(ctx, http.MethodPost, targetUrl) + req, err := runtime.NewRequest(ctx, http.MethodGet, targetUrl) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } - req.Raw().Header.Set("Foundry-Features", toolsetsFeatureHeader) - - if err := req.SetBody( - streaming.NopCloser(bytes.NewReader(payload)), - "application/json", - ); err != nil { - return nil, fmt.Errorf("failed to set request body: %w", err) - } + req.Raw().Header.Set("Foundry-Features", toolboxesFeatureHeader) resp, err := c.pipeline.Do(req) if err != nil { @@ -194,50 +175,40 @@ func (c *FoundryToolsetsClient) UpdateToolset( return nil, fmt.Errorf("failed to read response body: %w", err) } - var toolset ToolsetObject - if err := json.Unmarshal(body, &toolset); err != nil { + var result ToolboxObject + if err := json.Unmarshal(body, &result); err != nil { return nil, fmt.Errorf("failed to parse response: %w", err) } - return &toolset, nil + return &result, nil } -// GetToolset retrieves a toolset by name -func (c *FoundryToolsetsClient) GetToolset( +// DeleteToolbox deletes a toolbox and all its versions. +func (c *FoundryToolboxClient) DeleteToolbox( ctx context.Context, - toolsetName string, -) (*ToolsetObject, error) { + toolboxName string, +) error { targetUrl := fmt.Sprintf( - "%s/toolsets/%s?api-version=%s", - c.endpoint, url.PathEscape(toolsetName), toolsetsApiVersion, + "%s/toolboxes/%s?api-version=%s", + c.endpoint, url.PathEscape(toolboxName), toolboxesApiVersion, ) - req, err := runtime.NewRequest(ctx, http.MethodGet, targetUrl) + req, err := runtime.NewRequest(ctx, http.MethodDelete, targetUrl) if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) + return fmt.Errorf("failed to create request: %w", err) } - req.Raw().Header.Set("Foundry-Features", toolsetsFeatureHeader) + req.Raw().Header.Set("Foundry-Features", toolboxesFeatureHeader) resp, err := c.pipeline.Do(req) if err != nil { - return nil, fmt.Errorf("HTTP request failed: %w", err) + return fmt.Errorf("HTTP request failed: %w", err) } defer resp.Body.Close() - if !runtime.HasStatusCode(resp, http.StatusOK) { - return nil, runtime.NewResponseError(resp) - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - - var toolset ToolsetObject - if err := json.Unmarshal(body, &toolset); err != nil { - return nil, fmt.Errorf("failed to parse response: %w", err) + if !runtime.HasStatusCode(resp, http.StatusOK, http.StatusNoContent) { + return runtime.NewResponseError(resp) } - return &toolset, nil + return nil } From 9ec878dea50b9c467588c2d6ca9f22361ef1378a Mon Sep 17 00:00:00 2001 From: trangevi Date: Fri, 10 Apr 2026 11:00:22 -0700 Subject: [PATCH 10/14] Add version Signed-off-by: trangevi --- .../azure.ai.agents/internal/cmd/listen.go | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go index e0d1a5bbe15..ab7a2cd8ba8 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go @@ -566,16 +566,17 @@ func provisionToolboxes( // Replace project_connection_id friendly names with ARM resource IDs resolveToolboxConnectionIDs(&toolbox, connIDMap) - if err := createToolboxVersion( + version, err := createToolboxVersion( ctx, toolboxClient, toolbox, - ); err != nil { + ) + if err != nil { return err } if err := registerToolboxEnvVars( ctx, azdClient, currentEnv.Environment.Name, - projectEndpoint, toolbox.Name, + projectEndpoint, toolbox.Name, version, ); err != nil { return err } @@ -590,39 +591,43 @@ func provisionToolboxes( // createToolboxVersion creates a new version of a toolbox. // If the toolbox does not exist, it will be created automatically. +// Returns the version identifier of the newly created version. func createToolboxVersion( ctx context.Context, client *azure.FoundryToolboxClient, toolbox project.Toolbox, -) error { +) (string, error) { req := &azure.CreateToolboxVersionRequest{ Description: toolbox.Description, Tools: toolbox.Tools, } - if _, err := client.CreateToolboxVersion(ctx, toolbox.Name, req); err != nil { - return exterrors.Internal( + result, err := client.CreateToolboxVersion(ctx, toolbox.Name, req) + if err != nil { + return "", exterrors.Internal( exterrors.CodeCreateToolboxVersionFailed, fmt.Sprintf("failed to create toolbox version '%s': %s", toolbox.Name, err), ) } - return nil + return result.Version, nil } -// registerToolboxEnvVars sets TOOLBOX_{NAME}_MCP_ENDPOINT. +// registerToolboxEnvVars sets TOOLBOX_{NAME}_MCP_ENDPOINT with the versioned URL. func registerToolboxEnvVars( ctx context.Context, azdClient *azdext.AzdClient, envName string, projectEndpoint string, toolboxName string, + toolboxVersion string, ) error { envKey := toolboxMCPEndpointEnvKey(toolboxName) endpoint := strings.TrimRight(projectEndpoint, "/") mcpEndpoint := fmt.Sprintf( - "%s/toolboxes/%s/mcp?api-version=v1", endpoint, toolboxName, + "%s/toolboxes/%s/versions/%s/mcp?api-version=v1", + endpoint, toolboxName, toolboxVersion, ) return setEnvVar( From af85a8fd218444b2b0521080483ef4af42bbcd61 Mon Sep 17 00:00:00 2001 From: Glenn Harper <64209257+glharper@users.noreply.github.com> Date: Fri, 10 Apr 2026 14:18:12 -0400 Subject: [PATCH 11/14] Add unit tests and testdata for azure.ai.agents extension (#7634) * Add unit tests and testdata for azure.ai.agents extension Add 86 new unit tests across 5 previously untested or undertested packages in the azure.ai.agents extension, raising total test count from 183 to 269. Coverage improvements: - agent_yaml: 23.1% -> 53.8% (map.go YAML-to-API mapping fully tested) - registry_api: 0% -> 28.8% (tool conversion, parameter conversion, merge) - agent_api: 0% -> tested (JSON round-trip for all model types) - cmd: 23.0% -> 23.6% (copyDirectory, copyFile, buildAgentEndpoint) New test files: - agent_yaml/map_test.go: 44 tests for YAML-to-API transform functions - registry_api/helpers_test.go: 35 tests for pure conversion helpers - agent_api/models_test.go: 24 JSON serialization round-trip tests - cmd/init_copy_test.go: directory/file copy logic tests - cmd/agent_context_test.go: endpoint construction test - agent_yaml/testdata_test.go: fixture-based parsing + regression tests New testdata fixtures (7 YAML files): - 3 valid agents (minimal prompt, full prompt, hosted) - 1 MCP tools agent - 3 invalid manifests (no kind, no model, empty template) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * refactor: replace ptr[T] helper with Go 1.26 new(val) in tests Replace the generic ptr[T](v T) *T helper function with Go 1.26's built-in new(val) pattern in models_test.go and helpers_test.go, consistent with map_test.go and AGENTS.md conventions. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * remove outdated comment --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../internal/cmd/agent_context_test.go | 23 + .../internal/cmd/init_copy_test.go | 119 ++ .../pkg/agents/agent_api/models_test.go | 1036 ++++++++++++++ .../pkg/agents/agent_yaml/map_test.go | 1199 +++++++++++++++++ .../agent_yaml/testdata/hosted-agent.yaml | 9 + .../testdata/invalid-empty-template.yaml | 1 + .../agent_yaml/testdata/invalid-no-kind.yaml | 4 + .../agent_yaml/testdata/invalid-no-model.yaml | 4 + .../agent_yaml/testdata/mcp-tools-agent.yaml | 18 + .../testdata/prompt-agent-full.yaml | 39 + .../testdata/prompt-agent-minimal.yaml | 5 + .../pkg/agents/agent_yaml/testdata_test.go | 286 ++++ .../pkg/agents/registry_api/helpers_test.go | 863 ++++++++++++ 13 files changed, 3606 insertions(+) create mode 100644 cli/azd/extensions/azure.ai.agents/internal/cmd/agent_context_test.go create mode 100644 cli/azd/extensions/azure.ai.agents/internal/cmd/init_copy_test.go create mode 100644 cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_api/models_test.go create mode 100644 cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/map_test.go create mode 100644 cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/hosted-agent.yaml create mode 100644 cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/invalid-empty-template.yaml create mode 100644 cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/invalid-no-kind.yaml create mode 100644 cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/invalid-no-model.yaml create mode 100644 cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/mcp-tools-agent.yaml create mode 100644 cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/prompt-agent-full.yaml create mode 100644 cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/prompt-agent-minimal.yaml create mode 100644 cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata_test.go create mode 100644 cli/azd/extensions/azure.ai.agents/internal/pkg/agents/registry_api/helpers_test.go diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/agent_context_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/agent_context_test.go new file mode 100644 index 00000000000..3df51697171 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/agent_context_test.go @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import "testing" + +func TestBuildAgentEndpoint_Cases(t *testing.T) { + t.Parallel() + tests := []struct{ account, project, want string }{ + {"myaccount", "myproject", "https://myaccount.services.ai.azure.com/api/projects/myproject"}, + {"a", "b", "https://a.services.ai.azure.com/api/projects/b"}, + } + for _, tt := range tests { + t.Run(tt.account+"/"+tt.project, func(t *testing.T) { + t.Parallel() + got := buildAgentEndpoint(tt.account, tt.project) + if got != tt.want { + t.Errorf("got %q, want %q", got, tt.want) + } + }) + } +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_copy_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_copy_test.go new file mode 100644 index 00000000000..c6240acacf5 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_copy_test.go @@ -0,0 +1,119 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestCopyDirectory(t *testing.T) { + t.Parallel() + + t.Run("happy_path", func(t *testing.T) { + t.Parallel() + src := t.TempDir() + + // Create a small tree: file.txt, sub/nested.txt + if err := os.WriteFile(filepath.Join(src, "file.txt"), []byte("hello"), 0644); err != nil { + t.Fatal(err) + } + subDir := filepath.Join(src, "sub") + if err := os.MkdirAll(subDir, 0755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(subDir, "nested.txt"), []byte("world"), 0644); err != nil { + t.Fatal(err) + } + + dst := filepath.Join(t.TempDir(), "out") + if err := copyDirectory(src, dst); err != nil { + t.Fatal(err) + } + + // Verify top-level file + assertFileContents(t, filepath.Join(dst, "file.txt"), "hello") + // Verify nested file + assertFileContents(t, filepath.Join(dst, "sub", "nested.txt"), "world") + }) + + t.Run("same_path_noop", func(t *testing.T) { + t.Parallel() + dir := t.TempDir() + if err := copyDirectory(dir, dir); err != nil { + t.Fatalf("expected nil error for same path, got %v", err) + } + }) + + t.Run("subpath_error", func(t *testing.T) { + t.Parallel() + src := t.TempDir() + dst := filepath.Join(src, "child") + if err := os.MkdirAll(dst, 0755); err != nil { + t.Fatal(err) + } + + err := copyDirectory(src, dst) + if err == nil { + t.Fatal("expected error when dst is subpath of src") + } + if !strings.Contains(err.Error(), "refusing to copy") { + t.Errorf("unexpected error message: %v", err) + } + }) + + t.Run("missing_source_error", func(t *testing.T) { + t.Parallel() + src := filepath.Join(t.TempDir(), "nonexistent") + dst := t.TempDir() + + err := copyDirectory(src, dst) + if err == nil { + t.Fatal("expected error for missing source") + } + }) +} + +func TestCopyFile(t *testing.T) { + t.Parallel() + + t.Run("happy_path", func(t *testing.T) { + t.Parallel() + src := filepath.Join(t.TempDir(), "src.txt") + if err := os.WriteFile(src, []byte("data"), 0644); err != nil { + t.Fatal(err) + } + + dst := filepath.Join(t.TempDir(), "dst.txt") + if err := copyFile(src, dst); err != nil { + t.Fatal(err) + } + assertFileContents(t, dst, "data") + }) + + t.Run("missing_source_error", func(t *testing.T) { + t.Parallel() + src := filepath.Join(t.TempDir(), "nope.txt") + dst := filepath.Join(t.TempDir(), "dst.txt") + + if err := copyFile(src, dst); err == nil { + t.Fatal("expected error for missing source file") + } + }) +} + +// assertFileContents is a test helper that reads a file and compares its contents. +func assertFileContents(t *testing.T, path, want string) { + t.Helper() + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("reading %s: %v", path, err) + } + if got := string(data); got != want { + t.Errorf("file %s: got %q, want %q", path, got, want) + } +} + diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_api/models_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_api/models_test.go new file mode 100644 index 00000000000..487d46305ae --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_api/models_test.go @@ -0,0 +1,1036 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package agent_api + +import ( + "encoding/json" + "strings" + "testing" +) + +func TestCreateAgentRequest_RoundTrip(t *testing.T) { + t.Parallel() + + original := CreateAgentRequest{ + Name: "test-agent", + CreateAgentVersionRequest: CreateAgentVersionRequest{ + Description: new("A test agent"), + Metadata: map[string]string{"env": "test"}, + Definition: PromptAgentDefinition{ + AgentDefinition: AgentDefinition{ + Kind: AgentKindPrompt, + RaiConfig: &RaiConfig{RaiPolicyName: "default"}, + }, + Model: "gpt-4o", + Instructions: new("You are helpful"), + Temperature: new(float32(0.7)), + TopP: new(float32(0.9)), + }, + }, + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + // Verify JSON tag names + s := string(data) + for _, field := range []string{`"name"`, `"description"`, `"metadata"`, `"definition"`} { + if !strings.Contains(s, field) { + t.Errorf("expected JSON to contain %s, got: %s", field, s) + } + } + + var got CreateAgentRequest + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.Name != original.Name { + t.Errorf("Name = %q, want %q", got.Name, original.Name) + } + if got.Description == nil || *got.Description != *original.Description { + t.Errorf("Description mismatch") + } + if got.Metadata["env"] != "test" { + t.Errorf("Metadata[env] = %q, want %q", got.Metadata["env"], "test") + } +} + +func TestAgentObject_RoundTrip(t *testing.T) { + t.Parallel() + + original := AgentObject{ + Object: "agent", + ID: "agent-123", + Name: "my-agent", + Versions: struct { + Latest AgentVersionObject `json:"latest"` + }{ + Latest: AgentVersionObject{ + Object: "agent_version", + ID: "ver-1", + Name: "my-agent", + Version: "1", + Description: new("version one"), + Metadata: map[string]string{"release": "stable"}, + CreatedAt: 1700000000, + }, + }, + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + s := string(data) + for _, field := range []string{`"object"`, `"id"`, `"name"`, `"versions"`, `"latest"`} { + if !strings.Contains(s, field) { + t.Errorf("expected JSON to contain %s", field) + } + } + + var got AgentObject + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.ID != original.ID { + t.Errorf("ID = %q, want %q", got.ID, original.ID) + } + if got.Versions.Latest.Version != "1" { + t.Errorf("Latest.Version = %q, want %q", got.Versions.Latest.Version, "1") + } + if got.Versions.Latest.CreatedAt != 1700000000 { + t.Errorf("Latest.CreatedAt = %d, want %d", got.Versions.Latest.CreatedAt, int64(1700000000)) + } +} + +func TestAgentContainerObject_RoundTrip(t *testing.T) { + t.Parallel() + + original := AgentContainerObject{ + Object: "container", + ID: "ctr-1", + Status: AgentContainerStatusRunning, + MaxReplicas: new(int32(3)), + MinReplicas: new(int32(1)), + ErrorMessage: new("partial failure"), + CreatedAt: "2024-01-01T00:00:00Z", + UpdatedAt: "2024-06-01T00:00:00Z", + Container: &AgentContainerDetails{ + HealthState: "healthy", + ProvisioningState: "Succeeded", + State: "Running", + UpdatedOn: "2024-06-01T00:00:00Z", + Replicas: []AgentContainerReplicaState{ + {Name: "replica-0", State: "Running", ContainerState: "started"}, + }, + }, + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + s := string(data) + for _, field := range []string{ + `"max_replicas"`, `"min_replicas"`, `"error_message"`, + `"created_at"`, `"updated_at"`, `"container"`, + `"health_state"`, `"provisioning_state"`, `"container_state"`, + } { + if !strings.Contains(s, field) { + t.Errorf("expected JSON to contain %s", field) + } + } + + var got AgentContainerObject + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.Status != AgentContainerStatusRunning { + t.Errorf("Status = %q, want %q", got.Status, AgentContainerStatusRunning) + } + if got.MaxReplicas == nil || *got.MaxReplicas != 3 { + t.Error("MaxReplicas mismatch") + } + if got.MinReplicas == nil || *got.MinReplicas != 1 { + t.Error("MinReplicas mismatch") + } + if got.ErrorMessage == nil || *got.ErrorMessage != "partial failure" { + t.Error("ErrorMessage mismatch") + } + if got.Container == nil || len(got.Container.Replicas) != 1 { + t.Error("Container.Replicas mismatch") + } +} + +func TestPromptAgentDefinition_RoundTrip(t *testing.T) { + t.Parallel() + + original := PromptAgentDefinition{ + AgentDefinition: AgentDefinition{ + Kind: AgentKindPrompt, + RaiConfig: &RaiConfig{RaiPolicyName: "strict"}, + }, + Model: "gpt-4o", + Instructions: new("Be concise"), + Temperature: new(float32(0.5)), + TopP: new(float32(0.95)), + Reasoning: &Reasoning{Effort: "high"}, + Text: &ResponseTextFormatConfiguration{Type: "text"}, + StructuredInputs: map[string]StructuredInputDefinition{ + "query": { + Description: new("user query"), + Required: new(true), + }, + }, + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + s := string(data) + for _, field := range []string{ + `"kind"`, `"model"`, `"instructions"`, `"temperature"`, + `"top_p"`, `"reasoning"`, `"text"`, `"structured_inputs"`, + `"rai_config"`, `"rai_policy_name"`, + } { + if !strings.Contains(s, field) { + t.Errorf("expected JSON to contain %s", field) + } + } + + var got PromptAgentDefinition + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.Kind != AgentKindPrompt { + t.Errorf("Kind = %q, want %q", got.Kind, AgentKindPrompt) + } + if got.Model != "gpt-4o" { + t.Errorf("Model = %q, want %q", got.Model, "gpt-4o") + } + if got.Instructions == nil || *got.Instructions != "Be concise" { + t.Error("Instructions mismatch") + } + if got.Temperature == nil || *got.Temperature != 0.5 { + t.Error("Temperature mismatch") + } + if got.Reasoning == nil || got.Reasoning.Effort != "high" { + t.Error("Reasoning mismatch") + } + if si, ok := got.StructuredInputs["query"]; !ok || si.Description == nil || *si.Description != "user query" { + t.Error("StructuredInputs mismatch") + } +} + +func TestHostedAgentDefinition_RoundTrip(t *testing.T) { + t.Parallel() + + original := HostedAgentDefinition{ + AgentDefinition: AgentDefinition{Kind: AgentKindHosted}, + ContainerProtocolVersions: []ProtocolVersionRecord{ + {Protocol: AgentProtocolResponses, Version: "2024-07-01"}, + }, + CPU: "1.0", + Memory: "2Gi", + EnvironmentVariables: map[string]string{"LOG_LEVEL": "debug"}, + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + s := string(data) + for _, field := range []string{ + `"container_protocol_versions"`, `"cpu"`, `"memory"`, `"environment_variables"`, + } { + if !strings.Contains(s, field) { + t.Errorf("expected JSON to contain %s", field) + } + } + + var got HostedAgentDefinition + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.Kind != AgentKindHosted { + t.Errorf("Kind = %q, want %q", got.Kind, AgentKindHosted) + } + if len(got.ContainerProtocolVersions) != 1 || got.ContainerProtocolVersions[0].Version != "2024-07-01" { + t.Error("ContainerProtocolVersions mismatch") + } + if got.EnvironmentVariables["LOG_LEVEL"] != "debug" { + t.Error("EnvironmentVariables mismatch") + } +} + +func TestImageBasedHostedAgentDefinition_RoundTrip(t *testing.T) { + t.Parallel() + + original := ImageBasedHostedAgentDefinition{ + HostedAgentDefinition: HostedAgentDefinition{ + AgentDefinition: AgentDefinition{Kind: AgentKindHosted}, + ContainerProtocolVersions: []ProtocolVersionRecord{ + {Protocol: AgentProtocolActivityProtocol, Version: "1.0"}, + }, + CPU: "0.5", + Memory: "1Gi", + }, + Image: "myregistry.azurecr.io/agent:latest", + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + s := string(data) + if !strings.Contains(s, `"image"`) { + t.Error("expected JSON to contain \"image\"") + } + + var got ImageBasedHostedAgentDefinition + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.Image != original.Image { + t.Errorf("Image = %q, want %q", got.Image, original.Image) + } + if got.CPU != "0.5" { + t.Errorf("CPU = %q, want %q", got.CPU, "0.5") + } +} + +func TestAgentVersionObject_RoundTrip(t *testing.T) { + t.Parallel() + + original := AgentVersionObject{ + Object: "agent_version", + ID: "ver-abc", + Name: "my-agent", + Version: "3", + Description: new("third version"), + Metadata: map[string]string{"stage": "prod"}, + CreatedAt: 1710000000, + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + s := string(data) + for _, field := range []string{`"object"`, `"id"`, `"version"`, `"created_at"`} { + if !strings.Contains(s, field) { + t.Errorf("expected JSON to contain %s", field) + } + } + + var got AgentVersionObject + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.Version != "3" { + t.Errorf("Version = %q, want %q", got.Version, "3") + } + if got.CreatedAt != 1710000000 { + t.Errorf("CreatedAt = %d, want %d", got.CreatedAt, int64(1710000000)) + } + if got.Metadata["stage"] != "prod" { + t.Errorf("Metadata[stage] = %q, want %q", got.Metadata["stage"], "prod") + } +} + +func TestDeleteAgentResponse_RoundTrip(t *testing.T) { + t.Parallel() + + original := DeleteAgentResponse{ + Object: "agent", + Name: "old-agent", + Deleted: true, + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var got DeleteAgentResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.Name != "old-agent" { + t.Errorf("Name = %q, want %q", got.Name, "old-agent") + } + if !got.Deleted { + t.Error("Deleted = false, want true") + } +} + +func TestDeleteAgentVersionResponse_RoundTrip(t *testing.T) { + t.Parallel() + + original := DeleteAgentVersionResponse{ + Object: "agent_version", + Name: "my-agent", + Version: "2", + Deleted: true, + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + s := string(data) + if !strings.Contains(s, `"version"`) { + t.Error("expected JSON to contain \"version\"") + } + + var got DeleteAgentVersionResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.Version != "2" { + t.Errorf("Version = %q, want %q", got.Version, "2") + } + if !got.Deleted { + t.Error("Deleted = false, want true") + } +} + +func TestAgentEventHandlerRequest_RoundTrip(t *testing.T) { + t.Parallel() + + original := AgentEventHandlerRequest{ + Name: "eval-handler", + Metadata: map[string]string{"purpose": "eval"}, + EventTypes: []AgentEventType{AgentEventTypeResponseCompleted}, + Filter: &AgentEventHandlerFilter{ + AgentVersions: []string{"v1", "v2"}, + }, + Destination: AgentEventHandlerDestination{ + Type: AgentEventHandlerDestinationTypeEvals, + }, + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + s := string(data) + for _, field := range []string{`"event_types"`, `"filter"`, `"destination"`, `"agent_versions"`} { + if !strings.Contains(s, field) { + t.Errorf("expected JSON to contain %s", field) + } + } + + var got AgentEventHandlerRequest + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.Name != "eval-handler" { + t.Errorf("Name = %q, want %q", got.Name, "eval-handler") + } + if len(got.EventTypes) != 1 || got.EventTypes[0] != AgentEventTypeResponseCompleted { + t.Error("EventTypes mismatch") + } + if got.Filter == nil || len(got.Filter.AgentVersions) != 2 { + t.Error("Filter.AgentVersions mismatch") + } +} + +func TestAgentEventHandlerObject_RoundTrip(t *testing.T) { + t.Parallel() + + original := AgentEventHandlerObject{ + Object: "event_handler", + ID: "eh-1", + Name: "my-handler", + Metadata: map[string]string{"team": "platform"}, + CreatedAt: 1720000000, + EventTypes: []AgentEventType{AgentEventTypeResponseCompleted}, + Destination: AgentEventHandlerDestination{ + Type: AgentEventHandlerDestinationTypeEvals, + }, + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var got AgentEventHandlerObject + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.ID != "eh-1" { + t.Errorf("ID = %q, want %q", got.ID, "eh-1") + } + if got.CreatedAt != 1720000000 { + t.Errorf("CreatedAt = %d, want %d", got.CreatedAt, int64(1720000000)) + } +} + +func TestFunctionTool_RoundTrip(t *testing.T) { + t.Parallel() + + original := FunctionTool{ + Tool: Tool{Type: ToolTypeFunction}, + Name: "get_weather", + Description: new("Gets weather data"), + Parameters: map[string]any{"type": "object"}, + Strict: new(true), + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + s := string(data) + for _, field := range []string{`"type"`, `"name"`, `"description"`, `"parameters"`, `"strict"`} { + if !strings.Contains(s, field) { + t.Errorf("expected JSON to contain %s", field) + } + } + + var got FunctionTool + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.Type != ToolTypeFunction { + t.Errorf("Type = %q, want %q", got.Type, ToolTypeFunction) + } + if got.Name != "get_weather" { + t.Errorf("Name = %q, want %q", got.Name, "get_weather") + } + if got.Strict == nil || !*got.Strict { + t.Error("Strict mismatch") + } +} + +func TestMCPTool_RoundTrip(t *testing.T) { + t.Parallel() + + original := MCPTool{ + Tool: Tool{Type: ToolTypeMCP}, + ServerLabel: "my-server", + ServerURL: "https://mcp.example.com", + Headers: map[string]string{"Authorization": "Bearer tok"}, + ProjectConnectionID: new("conn-abc"), + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + s := string(data) + for _, field := range []string{`"server_label"`, `"server_url"`, `"project_connection_id"`} { + if !strings.Contains(s, field) { + t.Errorf("expected JSON to contain %s", field) + } + } + + var got MCPTool + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.ServerLabel != "my-server" { + t.Errorf("ServerLabel = %q, want %q", got.ServerLabel, "my-server") + } + if got.ServerURL != "https://mcp.example.com" { + t.Errorf("ServerURL = %q, want %q", got.ServerURL, "https://mcp.example.com") + } + if got.ProjectConnectionID == nil || *got.ProjectConnectionID != "conn-abc" { + t.Error("ProjectConnectionID mismatch") + } +} + +func TestFileSearchTool_RoundTrip(t *testing.T) { + t.Parallel() + + original := FileSearchTool{ + Tool: Tool{Type: ToolTypeFileSearch}, + VectorStoreIds: []string{"vs-1", "vs-2"}, + MaxNumResults: new(int32(10)), + RankingOptions: &RankingOptions{ + Ranker: new("auto"), + ScoreThreshold: new(float32(0.8)), + }, + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + s := string(data) + for _, field := range []string{ + `"vector_store_ids"`, `"max_num_results"`, `"ranking_options"`, + `"ranker"`, `"score_threshold"`, + } { + if !strings.Contains(s, field) { + t.Errorf("expected JSON to contain %s", field) + } + } + + var got FileSearchTool + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if len(got.VectorStoreIds) != 2 { + t.Errorf("VectorStoreIds length = %d, want 2", len(got.VectorStoreIds)) + } + if got.MaxNumResults == nil || *got.MaxNumResults != 10 { + t.Error("MaxNumResults mismatch") + } + if got.RankingOptions == nil || got.RankingOptions.Ranker == nil || *got.RankingOptions.Ranker != "auto" { + t.Error("RankingOptions.Ranker mismatch") + } +} + +func TestWebSearchPreviewTool_RoundTrip(t *testing.T) { + t.Parallel() + + original := WebSearchPreviewTool{ + Tool: Tool{Type: ToolTypeWebSearchPreview}, + SearchContextSize: new("medium"), + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + s := string(data) + if !strings.Contains(s, `"search_context_size"`) { + t.Error("expected JSON to contain \"search_context_size\"") + } + + var got WebSearchPreviewTool + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.Type != ToolTypeWebSearchPreview { + t.Errorf("Type = %q, want %q", got.Type, ToolTypeWebSearchPreview) + } + if got.SearchContextSize == nil || *got.SearchContextSize != "medium" { + t.Error("SearchContextSize mismatch") + } +} + +func TestCodeInterpreterTool_RoundTrip(t *testing.T) { + t.Parallel() + + original := CodeInterpreterTool{ + Tool: Tool{Type: ToolTypeCodeInterpreter}, + Container: "container-id-123", + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + s := string(data) + if !strings.Contains(s, `"container"`) { + t.Error("expected JSON to contain \"container\"") + } + + var got CodeInterpreterTool + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.Type != ToolTypeCodeInterpreter { + t.Errorf("Type = %q, want %q", got.Type, ToolTypeCodeInterpreter) + } + // Container is `any`, so after round-trip it comes back as string + if got.Container != "container-id-123" { + t.Errorf("Container = %v, want %q", got.Container, "container-id-123") + } +} + +func TestBingGroundingAgentTool_RoundTrip(t *testing.T) { + t.Parallel() + + original := BingGroundingAgentTool{ + Tool: Tool{Type: ToolTypeBingGrounding}, + BingGrounding: BingGroundingSearchToolParameters{ + ProjectConnections: ToolProjectConnectionList{ + ProjectConnections: []ToolProjectConnection{{ID: "conn-1"}}, + }, + SearchConfigurations: []BingGroundingSearchConfiguration{ + { + ProjectConnectionID: "conn-1", + Market: new("en-US"), + }, + }, + }, + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + s := string(data) + if !strings.Contains(s, `"bing_grounding"`) { + t.Error("expected JSON to contain \"bing_grounding\"") + } + if !strings.Contains(s, `"project_connections"`) { + t.Error("expected JSON to contain \"project_connections\"") + } + + var got BingGroundingAgentTool + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if len(got.BingGrounding.ProjectConnections.ProjectConnections) != 1 { + t.Error("ProjectConnections length mismatch") + } + if len(got.BingGrounding.SearchConfigurations) != 1 { + t.Error("SearchConfigurations length mismatch") + } +} + +func TestOpenApiAgentTool_RoundTrip(t *testing.T) { + t.Parallel() + + original := OpenApiAgentTool{ + Tool: Tool{Type: ToolTypeOpenAPI}, + OpenAPI: OpenApiFunctionDefinition{ + Name: "petstore", + Description: new("Pet store API"), + Spec: map[string]any{"openapi": "3.0.0"}, + Auth: OpenApiAuthDetails{ + Type: OpenApiAuthTypeAnonymous, + }, + DefaultParams: []string{"api_version=v1"}, + Functions: []OpenApiFunction{ + { + Name: "listPets", + Description: new("List all pets"), + Parameters: map[string]any{"type": "object"}, + }, + }, + }, + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + s := string(data) + if !strings.Contains(s, `"openapi"`) { + t.Error("expected JSON to contain \"openapi\"") + } + + var got OpenApiAgentTool + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.OpenAPI.Name != "petstore" { + t.Errorf("OpenAPI.Name = %q, want %q", got.OpenAPI.Name, "petstore") + } + if got.OpenAPI.Auth.Type != OpenApiAuthTypeAnonymous { + t.Errorf("Auth.Type = %q, want %q", got.OpenAPI.Auth.Type, OpenApiAuthTypeAnonymous) + } + if len(got.OpenAPI.Functions) != 1 { + t.Errorf("Functions length = %d, want 1", len(got.OpenAPI.Functions)) + } +} + +func TestSessionFileInfo_RoundTrip(t *testing.T) { + t.Parallel() + + original := SessionFileInfo{ + Name: "data.csv", + Path: "/workspace/data.csv", + IsDirectory: false, + Size: 2048, + Mode: 0644, + LastModified: new("2024-06-15T10:30:00Z"), + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + s := string(data) + for _, field := range []string{`"name"`, `"path"`, `"is_dir"`, `"size"`, `"mode"`, `"modified_time"`} { + if !strings.Contains(s, field) { + t.Errorf("expected JSON to contain %s", field) + } + } + + var got SessionFileInfo + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.Name != "data.csv" { + t.Errorf("Name = %q, want %q", got.Name, "data.csv") + } + if got.IsDirectory { + t.Error("IsDirectory = true, want false") + } + if got.Size != 2048 { + t.Errorf("Size = %d, want %d", got.Size, int64(2048)) + } + if got.LastModified == nil || *got.LastModified != "2024-06-15T10:30:00Z" { + t.Error("LastModified mismatch") + } +} + +func TestSessionFileList_RoundTrip(t *testing.T) { + t.Parallel() + + original := SessionFileList{ + Path: "/workspace", + Entries: []SessionFileInfo{ + {Name: "file1.txt", Path: "/workspace/file1.txt", IsDirectory: false, Size: 100}, + {Name: "subdir", Path: "/workspace/subdir", IsDirectory: true}, + }, + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + s := string(data) + if !strings.Contains(s, `"entries"`) { + t.Error("expected JSON to contain \"entries\"") + } + + var got SessionFileList + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.Path != "/workspace" { + t.Errorf("Path = %q, want %q", got.Path, "/workspace") + } + if len(got.Entries) != 2 { + t.Fatalf("Entries length = %d, want 2", len(got.Entries)) + } + if !got.Entries[1].IsDirectory { + t.Error("Entries[1].IsDirectory = false, want true") + } +} + +func TestEvalsDestination_RoundTrip(t *testing.T) { + t.Parallel() + + original := EvalsDestination{ + AgentEventHandlerDestination: AgentEventHandlerDestination{ + Type: AgentEventHandlerDestinationTypeEvals, + }, + EvalID: "eval-123", + MaxHourlyRuns: new(int32(10)), + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + s := string(data) + for _, field := range []string{`"eval_id"`, `"max_hourly_runs"`} { + if !strings.Contains(s, field) { + t.Errorf("expected JSON to contain %s", field) + } + } + + var got EvalsDestination + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.EvalID != "eval-123" { + t.Errorf("EvalID = %q, want %q", got.EvalID, "eval-123") + } + if got.MaxHourlyRuns == nil || *got.MaxHourlyRuns != 10 { + t.Error("MaxHourlyRuns mismatch") + } +} + +func TestContainerAppAgentDefinition_RoundTrip(t *testing.T) { + t.Parallel() + + original := ContainerAppAgentDefinition{ + AgentDefinition: AgentDefinition{Kind: AgentKindContainerApp}, + ContainerProtocolVersions: []ProtocolVersionRecord{ + {Protocol: AgentProtocolInvocations, Version: "2024-01-01"}, + }, + ContainerAppResourceID: "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.App/containerApps/app", + IngressSubdomainSuffix: "myapp", + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + s := string(data) + for _, field := range []string{ + `"container_app_resource_id"`, `"ingress_subdomain_suffix"`, `"container_protocol_versions"`, + } { + if !strings.Contains(s, field) { + t.Errorf("expected JSON to contain %s", field) + } + } + + var got ContainerAppAgentDefinition + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.Kind != AgentKindContainerApp { + t.Errorf("Kind = %q, want %q", got.Kind, AgentKindContainerApp) + } + if got.ContainerAppResourceID != original.ContainerAppResourceID { + t.Errorf("ContainerAppResourceID mismatch") + } +} + +func TestWorkflowDefinition_RoundTrip(t *testing.T) { + t.Parallel() + + original := WorkflowDefinition{ + AgentDefinition: AgentDefinition{Kind: AgentKindWorkflow}, + Trigger: map[string]any{"type": "schedule", "cron": "0 * * * *"}, + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + s := string(data) + if !strings.Contains(s, `"trigger"`) { + t.Error("expected JSON to contain \"trigger\"") + } + + var got WorkflowDefinition + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.Kind != AgentKindWorkflow { + t.Errorf("Kind = %q, want %q", got.Kind, AgentKindWorkflow) + } + if got.Trigger["type"] != "schedule" { + t.Errorf("Trigger[type] = %v, want %q", got.Trigger["type"], "schedule") + } +} + +func TestAgentContainerOperationObject_RoundTrip(t *testing.T) { + t.Parallel() + + original := AgentContainerOperationObject{ + ID: "op-1", + AgentID: "agent-1", + AgentVersionID: "ver-1", + Status: AgentContainerOperationStatusSucceeded, + Error: &AgentContainerOperationError{ + Code: "E001", + Type: "runtime", + Message: "something went wrong", + }, + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + s := string(data) + for _, field := range []string{`"agent_id"`, `"agent_version_id"`, `"status"`} { + if !strings.Contains(s, field) { + t.Errorf("expected JSON to contain %s", field) + } + } + + var got AgentContainerOperationObject + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if got.Status != AgentContainerOperationStatusSucceeded { + t.Errorf("Status = %q, want %q", got.Status, AgentContainerOperationStatusSucceeded) + } + if got.Error == nil || got.Error.Message != "something went wrong" { + t.Error("Error.Message mismatch") + } +} + +func TestCommonListObjectProperties_RoundTrip(t *testing.T) { + t.Parallel() + + original := AgentList{ + Data: []AgentObject{ + {Object: "agent", ID: "a1", Name: "agent-one"}, + }, + CommonListObjectProperties: CommonListObjectProperties{ + Object: "list", + FirstID: "a1", + LastID: "a1", + HasMore: false, + }, + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + s := string(data) + for _, field := range []string{`"first_id"`, `"last_id"`, `"has_more"`} { + if !strings.Contains(s, field) { + t.Errorf("expected JSON to contain %s", field) + } + } + + var got AgentList + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if len(got.Data) != 1 || got.Data[0].ID != "a1" { + t.Error("Data mismatch") + } + if got.HasMore { + t.Error("HasMore = true, want false") + } +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/map_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/map_test.go new file mode 100644 index 00000000000..f53452355c9 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/map_test.go @@ -0,0 +1,1199 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package agent_yaml + +import ( + "math" + "strings" + "testing" + + "azureaiagent/internal/pkg/agents/agent_api" +) + +// --------------------------------------------------------------------------- +// constructBuildConfig +// --------------------------------------------------------------------------- + +func TestConstructBuildConfig_NoOptions(t *testing.T) { + t.Parallel() + cfg := constructBuildConfig() + if cfg == nil { + t.Fatal("expected non-nil config") + } + if cfg.ImageURL != "" { + t.Errorf("expected empty ImageURL, got %q", cfg.ImageURL) + } + if cfg.CPU != "" { + t.Errorf("expected empty CPU, got %q", cfg.CPU) + } + if cfg.Memory != "" { + t.Errorf("expected empty Memory, got %q", cfg.Memory) + } + if cfg.EnvironmentVariables != nil { + t.Errorf("expected nil EnvironmentVariables, got %v", cfg.EnvironmentVariables) + } +} + +func TestConstructBuildConfig_AllOptions(t *testing.T) { + t.Parallel() + cfg := constructBuildConfig( + WithImageURL("myregistry.azurecr.io/myimage:latest"), + WithCPU("2"), + WithMemory("4Gi"), + WithEnvironmentVariable("KEY1", "val1"), + WithEnvironmentVariables(map[string]string{"KEY2": "val2", "KEY3": "val3"}), + ) + if cfg.ImageURL != "myregistry.azurecr.io/myimage:latest" { + t.Errorf("ImageURL = %q", cfg.ImageURL) + } + if cfg.CPU != "2" { + t.Errorf("CPU = %q", cfg.CPU) + } + if cfg.Memory != "4Gi" { + t.Errorf("Memory = %q", cfg.Memory) + } + if len(cfg.EnvironmentVariables) != 3 { + t.Fatalf("expected 3 env vars, got %d", len(cfg.EnvironmentVariables)) + } + for _, k := range []string{"KEY1", "KEY2", "KEY3"} { + if _, ok := cfg.EnvironmentVariables[k]; !ok { + t.Errorf("missing env var %q", k) + } + } +} + +// --------------------------------------------------------------------------- +// WithEnvironmentVariable / WithEnvironmentVariables +// --------------------------------------------------------------------------- + +func TestWithEnvironmentVariable_InitializesMap(t *testing.T) { + t.Parallel() + cfg := &AgentBuildConfig{} + WithEnvironmentVariable("A", "1")(cfg) + if cfg.EnvironmentVariables["A"] != "1" { + t.Errorf("expected A=1, got %q", cfg.EnvironmentVariables["A"]) + } +} + +func TestWithEnvironmentVariables_MergesIntoExisting(t *testing.T) { + t.Parallel() + cfg := &AgentBuildConfig{EnvironmentVariables: map[string]string{"EXISTING": "x"}} + WithEnvironmentVariables(map[string]string{"NEW": "y"})(cfg) + if cfg.EnvironmentVariables["EXISTING"] != "x" { + t.Error("existing env var was lost") + } + if cfg.EnvironmentVariables["NEW"] != "y" { + t.Error("new env var not set") + } +} + +func TestWithEnvironmentVariables_InitializesNilMap(t *testing.T) { + t.Parallel() + cfg := &AgentBuildConfig{} + WithEnvironmentVariables(map[string]string{"K": "V"})(cfg) + if cfg.EnvironmentVariables["K"] != "V" { + t.Errorf("expected K=V") + } +} + +// --------------------------------------------------------------------------- +// convertIntToInt32 +// --------------------------------------------------------------------------- + +func TestConvertIntToInt32(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input *int + want *int32 + wantErr bool + }{ + { + name: "nil input", + input: nil, + want: nil, + }, + { + name: "zero", + input: new(0), + want: new(int32(0)), + }, + { + name: "positive value", + input: new(42), + want: new(int32(42)), + }, + { + name: "negative value", + input: new(-10), + want: new(int32(-10)), + }, + { + name: "max int32", + input: new(math.MaxInt32), + want: new(int32(math.MaxInt32)), + }, + { + name: "min int32", + input: new(math.MinInt32), + want: new(int32(math.MinInt32)), + }, + { + name: "overflow positive", + input: new(math.MaxInt32 + 1), + wantErr: true, + }, + { + name: "overflow negative", + input: new(math.MinInt32 - 1), + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got, err := convertIntToInt32(tc.input) + if tc.wantErr { + if err == nil { + t.Fatal("expected error, got nil") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tc.want == nil { + if got != nil { + t.Fatalf("expected nil, got %v", *got) + } + return + } + if got == nil { + t.Fatal("expected non-nil result") + } + if *got != *tc.want { + t.Errorf("got %d, want %d", *got, *tc.want) + } + }) + } +} + +// --------------------------------------------------------------------------- +// convertFloat64ToFloat32 +// --------------------------------------------------------------------------- + +func TestConvertFloat64ToFloat32(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input *float64 + isNil bool + }{ + {name: "nil input", input: nil, isNil: true}, + {name: "zero", input: new(0.0)}, + {name: "typical temperature", input: new(0.7)}, + {name: "one", input: new(1.0)}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got := convertFloat64ToFloat32(tc.input) + if tc.isNil { + if got != nil { + t.Fatalf("expected nil, got %v", *got) + } + return + } + if got == nil { + t.Fatal("expected non-nil result") + } + expected := float32(*tc.input) + if *got != expected { + t.Errorf("got %v, want %v", *got, expected) + } + }) + } +} + +// --------------------------------------------------------------------------- +// convertYamlToolToApiTool +// --------------------------------------------------------------------------- + +func TestConvertYamlToolToApiTool_Nil(t *testing.T) { + t.Parallel() + _, err := convertYamlToolToApiTool(nil) + if err == nil { + t.Fatal("expected error for nil tool") + } + if !strings.Contains(err.Error(), "nil") { + t.Errorf("error should mention nil, got: %s", err.Error()) + } +} + +func TestConvertYamlToolToApiTool_UnknownType(t *testing.T) { + t.Parallel() + _, err := convertYamlToolToApiTool("not-a-tool") + if err == nil { + t.Fatal("expected error for unknown tool type") + } + if !strings.Contains(err.Error(), "unsupported") { + t.Errorf("error should mention unsupported, got: %s", err.Error()) + } +} + +func TestConvertYamlToolToApiTool_Function(t *testing.T) { + t.Parallel() + desc := "adds two numbers" + yamlTool := FunctionTool{ + Tool: Tool{ + Name: "add", + Kind: ToolKindFunction, + Description: &desc, + }, + Parameters: PropertySchema{}, + Strict: new(true), + } + + result, err := convertYamlToolToApiTool(yamlTool) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + ft, ok := result.(agent_api.FunctionTool) + if !ok { + t.Fatalf("expected agent_api.FunctionTool, got %T", result) + } + if ft.Tool.Type != agent_api.ToolTypeFunction { + t.Errorf("type = %q, want %q", ft.Tool.Type, agent_api.ToolTypeFunction) + } + if ft.Name != "add" { + t.Errorf("name = %q, want %q", ft.Name, "add") + } + if ft.Description == nil || *ft.Description != desc { + t.Errorf("description mismatch") + } + if ft.Strict == nil || !*ft.Strict { + t.Error("strict should be true") + } +} + +func TestConvertYamlToolToApiTool_FunctionNilDescription(t *testing.T) { + t.Parallel() + yamlTool := FunctionTool{ + Tool: Tool{ + Name: "noop", + Kind: ToolKindFunction, + }, + Parameters: PropertySchema{}, + } + + result, err := convertYamlToolToApiTool(yamlTool) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + ft := result.(agent_api.FunctionTool) + if ft.Description != nil { + t.Errorf("expected nil description, got %v", ft.Description) + } + if ft.Strict != nil { + t.Errorf("expected nil strict, got %v", ft.Strict) + } +} + +func TestConvertYamlToolToApiTool_WebSearch(t *testing.T) { + t.Parallel() + yamlTool := WebSearchTool{ + Tool: Tool{Name: "websearch", Kind: ToolKindWebSearch}, + Options: map[string]any{ + "searchContextSize": "high", + }, + } + + result, err := convertYamlToolToApiTool(yamlTool) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + ws, ok := result.(agent_api.WebSearchPreviewTool) + if !ok { + t.Fatalf("expected agent_api.WebSearchPreviewTool, got %T", result) + } + if ws.Tool.Type != agent_api.ToolTypeWebSearchPreview { + t.Errorf("type = %q, want %q", ws.Tool.Type, agent_api.ToolTypeWebSearchPreview) + } + if ws.SearchContextSize == nil || *ws.SearchContextSize != "high" { + t.Errorf("searchContextSize mismatch") + } +} + +func TestConvertYamlToolToApiTool_WebSearchNoOptions(t *testing.T) { + t.Parallel() + yamlTool := WebSearchTool{ + Tool: Tool{Name: "ws", Kind: ToolKindWebSearch}, + } + + result, err := convertYamlToolToApiTool(yamlTool) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + ws := result.(agent_api.WebSearchPreviewTool) + if ws.UserLocation != nil { + t.Error("expected nil UserLocation") + } + if ws.SearchContextSize != nil { + t.Error("expected nil SearchContextSize") + } +} + +func TestConvertYamlToolToApiTool_BingGrounding(t *testing.T) { + t.Parallel() + bgParams := agent_api.BingGroundingSearchToolParameters{ + ProjectConnections: agent_api.ToolProjectConnectionList{ + ProjectConnections: []agent_api.ToolProjectConnection{{ID: "conn-1"}}, + }, + } + yamlTool := BingGroundingTool{ + Tool: Tool{Name: "bing", Kind: ToolKindBingGrounding}, + Options: map[string]any{ + "bingGrounding": bgParams, + }, + } + + result, err := convertYamlToolToApiTool(yamlTool) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + bg, ok := result.(agent_api.BingGroundingAgentTool) + if !ok { + t.Fatalf("expected agent_api.BingGroundingAgentTool, got %T", result) + } + if bg.Tool.Type != agent_api.ToolTypeBingGrounding { + t.Errorf("type = %q, want %q", bg.Tool.Type, agent_api.ToolTypeBingGrounding) + } + if len(bg.BingGrounding.ProjectConnections.ProjectConnections) != 1 { + t.Errorf("expected 1 project connection") + } +} + +func TestConvertYamlToolToApiTool_BingGroundingNoOptions(t *testing.T) { + t.Parallel() + yamlTool := BingGroundingTool{ + Tool: Tool{Name: "bing", Kind: ToolKindBingGrounding}, + } + + result, err := convertYamlToolToApiTool(yamlTool) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + bg := result.(agent_api.BingGroundingAgentTool) + if bg.Tool.Type != agent_api.ToolTypeBingGrounding { + t.Errorf("type = %q", bg.Tool.Type) + } +} + +func TestConvertYamlToolToApiTool_FileSearch(t *testing.T) { + t.Parallel() + ranker := "default-2024-11-15" + threshold := 0.8 + maxResults := 10 + yamlTool := FileSearchTool{ + Tool: Tool{Name: "fs", Kind: ToolKindFileSearch}, + VectorStoreIds: []string{"vs-1", "vs-2"}, + MaximumResultCount: &maxResults, + Ranker: &ranker, + ScoreThreshold: &threshold, + Options: map[string]any{ + "filters": map[string]any{"type": "eq", "key": "status", "value": "active"}, + }, + } + + result, err := convertYamlToolToApiTool(yamlTool) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + fs, ok := result.(agent_api.FileSearchTool) + if !ok { + t.Fatalf("expected agent_api.FileSearchTool, got %T", result) + } + if fs.Tool.Type != agent_api.ToolTypeFileSearch { + t.Errorf("type = %q", fs.Tool.Type) + } + if len(fs.VectorStoreIds) != 2 { + t.Errorf("expected 2 vector store ids, got %d", len(fs.VectorStoreIds)) + } + if fs.MaxNumResults == nil || *fs.MaxNumResults != 10 { + t.Errorf("MaxNumResults mismatch") + } + if fs.RankingOptions == nil { + t.Fatal("expected non-nil RankingOptions") + } + if fs.RankingOptions.Ranker == nil || *fs.RankingOptions.Ranker != ranker { + t.Errorf("ranker mismatch") + } + if fs.RankingOptions.ScoreThreshold == nil || *fs.RankingOptions.ScoreThreshold != float32(threshold) { + t.Errorf("score threshold mismatch") + } + if fs.Filters == nil { + t.Error("expected filters to be set") + } +} + +func TestConvertYamlToolToApiTool_FileSearchMinimal(t *testing.T) { + t.Parallel() + yamlTool := FileSearchTool{ + Tool: Tool{Name: "fs", Kind: ToolKindFileSearch}, + VectorStoreIds: []string{"vs-1"}, + } + + result, err := convertYamlToolToApiTool(yamlTool) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + fs := result.(agent_api.FileSearchTool) + if fs.MaxNumResults != nil { + t.Error("expected nil MaxNumResults") + } + if fs.RankingOptions != nil { + t.Error("expected nil RankingOptions when ranker and threshold are nil") + } +} + +func TestConvertYamlToolToApiTool_FileSearchOverflow(t *testing.T) { + t.Parallel() + overflow := math.MaxInt32 + 1 + yamlTool := FileSearchTool{ + Tool: Tool{Name: "fs", Kind: ToolKindFileSearch}, + VectorStoreIds: []string{"vs-1"}, + MaximumResultCount: &overflow, + } + + _, err := convertYamlToolToApiTool(yamlTool) + if err == nil { + t.Fatal("expected error for int32 overflow") + } + if !strings.Contains(err.Error(), "overflow") { + t.Errorf("error should mention overflow, got: %s", err.Error()) + } +} + +func TestConvertYamlToolToApiTool_MCP(t *testing.T) { + t.Parallel() + yamlTool := McpTool{ + Tool: Tool{Name: "mcp-server", Kind: ToolKindMcp}, + ServerName: "my-mcp-server", + Options: map[string]any{ + "serverUrl": "https://mcp.example.com", + "headers": map[string]string{"Authorization": "Bearer tok"}, + "allowedTools": []string{"tool_a", "tool_b"}, + "requireApproval": "always", + "projectConnectionId": "conn-123", + }, + } + + result, err := convertYamlToolToApiTool(yamlTool) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + mcp, ok := result.(agent_api.MCPTool) + if !ok { + t.Fatalf("expected agent_api.MCPTool, got %T", result) + } + if mcp.Tool.Type != agent_api.ToolTypeMCP { + t.Errorf("type = %q", mcp.Tool.Type) + } + if mcp.ServerLabel != "my-mcp-server" { + t.Errorf("ServerLabel = %q", mcp.ServerLabel) + } + if mcp.ServerURL != "https://mcp.example.com" { + t.Errorf("ServerURL = %q", mcp.ServerURL) + } + if mcp.Headers["Authorization"] != "Bearer tok" { + t.Errorf("headers mismatch") + } + if mcp.ProjectConnectionID == nil || *mcp.ProjectConnectionID != "conn-123" { + t.Errorf("ProjectConnectionID mismatch") + } +} + +func TestConvertYamlToolToApiTool_MCPNoOptions(t *testing.T) { + t.Parallel() + yamlTool := McpTool{ + Tool: Tool{Name: "mcp", Kind: ToolKindMcp}, + ServerName: "srv", + } + + result, err := convertYamlToolToApiTool(yamlTool) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + mcp := result.(agent_api.MCPTool) + if mcp.ServerURL != "" { + t.Errorf("expected empty ServerURL, got %q", mcp.ServerURL) + } + if mcp.ProjectConnectionID != nil { + t.Error("expected nil ProjectConnectionID") + } +} + +func TestConvertYamlToolToApiTool_OpenApi(t *testing.T) { + t.Parallel() + openApiDef := agent_api.OpenApiFunctionDefinition{ + Name: "petstore", + Auth: agent_api.OpenApiAuthDetails{Type: agent_api.OpenApiAuthTypeAnonymous}, + } + yamlTool := OpenApiTool{ + Tool: Tool{Name: "petstore", Kind: ToolKindOpenApi}, + Options: map[string]any{ + "openapi": openApiDef, + }, + } + + result, err := convertYamlToolToApiTool(yamlTool) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + oa, ok := result.(agent_api.OpenApiAgentTool) + if !ok { + t.Fatalf("expected agent_api.OpenApiAgentTool, got %T", result) + } + if oa.Tool.Type != agent_api.ToolTypeOpenAPI { + t.Errorf("type = %q", oa.Tool.Type) + } + if oa.OpenAPI.Name != "petstore" { + t.Errorf("OpenAPI.Name = %q", oa.OpenAPI.Name) + } +} + +func TestConvertYamlToolToApiTool_OpenApiNoOptions(t *testing.T) { + t.Parallel() + yamlTool := OpenApiTool{ + Tool: Tool{Name: "api", Kind: ToolKindOpenApi}, + } + + result, err := convertYamlToolToApiTool(yamlTool) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + oa := result.(agent_api.OpenApiAgentTool) + if oa.Tool.Type != agent_api.ToolTypeOpenAPI { + t.Errorf("type = %q", oa.Tool.Type) + } +} + +func TestConvertYamlToolToApiTool_CodeInterpreter(t *testing.T) { + t.Parallel() + yamlTool := CodeInterpreterTool{ + Tool: Tool{Name: "ci", Kind: ToolKindCodeInterpreter}, + Options: map[string]any{ + "container": "auto", + }, + } + + result, err := convertYamlToolToApiTool(yamlTool) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + ci, ok := result.(agent_api.CodeInterpreterTool) + if !ok { + t.Fatalf("expected agent_api.CodeInterpreterTool, got %T", result) + } + if ci.Tool.Type != agent_api.ToolTypeCodeInterpreter { + t.Errorf("type = %q", ci.Tool.Type) + } + if ci.Container != "auto" { + t.Errorf("Container = %v", ci.Container) + } +} + +func TestConvertYamlToolToApiTool_CodeInterpreterNoOptions(t *testing.T) { + t.Parallel() + yamlTool := CodeInterpreterTool{ + Tool: Tool{Name: "ci", Kind: ToolKindCodeInterpreter}, + } + + result, err := convertYamlToolToApiTool(yamlTool) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + ci := result.(agent_api.CodeInterpreterTool) + if ci.Container != nil { + t.Errorf("expected nil Container, got %v", ci.Container) + } +} + +// --------------------------------------------------------------------------- +// convertYamlToolsToApiTools +// --------------------------------------------------------------------------- + +func TestConvertYamlToolsToApiTools_MixedTools(t *testing.T) { + t.Parallel() + yamlTools := []any{ + FunctionTool{Tool: Tool{Name: "fn1", Kind: ToolKindFunction}, Parameters: PropertySchema{}}, + WebSearchTool{Tool: Tool{Name: "ws", Kind: ToolKindWebSearch}}, + CodeInterpreterTool{Tool: Tool{Name: "ci", Kind: ToolKindCodeInterpreter}}, + } + + result := convertYamlToolsToApiTools(yamlTools) + if len(result) != 3 { + t.Fatalf("expected 3 tools, got %d", len(result)) + } + + if _, ok := result[0].(agent_api.FunctionTool); !ok { + t.Errorf("tool[0] should be FunctionTool, got %T", result[0]) + } + if _, ok := result[1].(agent_api.WebSearchPreviewTool); !ok { + t.Errorf("tool[1] should be WebSearchPreviewTool, got %T", result[1]) + } + if _, ok := result[2].(agent_api.CodeInterpreterTool); !ok { + t.Errorf("tool[2] should be CodeInterpreterTool, got %T", result[2]) + } +} + +func TestConvertYamlToolsToApiTools_SkipsUnsupported(t *testing.T) { + t.Parallel() + yamlTools := []any{ + FunctionTool{Tool: Tool{Name: "fn1", Kind: ToolKindFunction}, Parameters: PropertySchema{}}, + "unsupported-string-tool", + WebSearchTool{Tool: Tool{Name: "ws", Kind: ToolKindWebSearch}}, + } + + result := convertYamlToolsToApiTools(yamlTools) + if len(result) != 2 { + t.Fatalf("expected 2 tools (unsupported skipped), got %d", len(result)) + } +} + +func TestConvertYamlToolsToApiTools_Empty(t *testing.T) { + t.Parallel() + result := convertYamlToolsToApiTools([]any{}) + if result != nil { + t.Errorf("expected nil for empty input, got %v", result) + } +} + +// --------------------------------------------------------------------------- +// createAgentAPIRequest (common fields) +// --------------------------------------------------------------------------- + +func TestCreateAgentAPIRequest_AllFields(t *testing.T) { + t.Parallel() + desc := "A helpful agent" + meta := map[string]any{ + "authors": []any{"Alice", "Bob"}, + "version": "1.0", + } + agentDef := AgentDefinition{ + Kind: AgentKindPrompt, + Name: "my-agent", + Description: &desc, + Metadata: &meta, + } + + req, err := createAgentAPIRequest(agentDef, "placeholder-definition") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if req.Name != "my-agent" { + t.Errorf("Name = %q, want %q", req.Name, "my-agent") + } + if req.Description == nil || *req.Description != desc { + t.Errorf("Description mismatch") + } + if req.Metadata["authors"] != "Alice,Bob" { + t.Errorf("authors = %q, want %q", req.Metadata["authors"], "Alice,Bob") + } + if req.Metadata["version"] != "1.0" { + t.Errorf("version metadata = %q", req.Metadata["version"]) + } + if req.Definition != "placeholder-definition" { + t.Errorf("Definition mismatch") + } +} + +func TestCreateAgentAPIRequest_DefaultName(t *testing.T) { + t.Parallel() + agentDef := AgentDefinition{Kind: AgentKindPrompt} + + req, err := createAgentAPIRequest(agentDef, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if req.Name != "unspecified-agent-name" { + t.Errorf("Name = %q, want %q", req.Name, "unspecified-agent-name") + } +} + +func TestCreateAgentAPIRequest_NilMetadata(t *testing.T) { + t.Parallel() + agentDef := AgentDefinition{Kind: AgentKindPrompt, Name: "test"} + + req, err := createAgentAPIRequest(agentDef, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if req.Metadata != nil { + t.Errorf("expected nil Metadata, got %v", req.Metadata) + } +} + +func TestCreateAgentAPIRequest_EmptyDescription(t *testing.T) { + t.Parallel() + empty := "" + agentDef := AgentDefinition{ + Kind: AgentKindPrompt, + Name: "test", + Description: &empty, + } + + req, err := createAgentAPIRequest(agentDef, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if req.Description != nil { + t.Errorf("expected nil Description for empty string, got %v", req.Description) + } +} + +func TestCreateAgentAPIRequest_MetadataWithNonStringValues(t *testing.T) { + t.Parallel() + meta := map[string]any{ + "name": "test", + "numeric": 42, // non-string value should be skipped + } + agentDef := AgentDefinition{ + Kind: AgentKindPrompt, + Name: "test", + Metadata: &meta, + } + + req, err := createAgentAPIRequest(agentDef, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if req.Metadata["name"] != "test" { + t.Errorf("string metadata missing") + } + if _, exists := req.Metadata["numeric"]; exists { + t.Errorf("non-string metadata should be skipped") + } +} + +func TestCreateAgentAPIRequest_AuthorsSingleAuthor(t *testing.T) { + t.Parallel() + meta := map[string]any{ + "authors": []any{"Solo"}, + } + agentDef := AgentDefinition{ + Kind: AgentKindPrompt, + Name: "test", + Metadata: &meta, + } + + req, err := createAgentAPIRequest(agentDef, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if req.Metadata["authors"] != "Solo" { + t.Errorf("authors = %q, want %q", req.Metadata["authors"], "Solo") + } +} + +// --------------------------------------------------------------------------- +// CreatePromptAgentAPIRequest +// --------------------------------------------------------------------------- + +func TestCreatePromptAgentAPIRequest_FullConfig(t *testing.T) { + t.Parallel() + desc := "prompt agent" + instructions := "You are a helpful assistant." + temp := 0.7 + topP := 0.9 + + agent := PromptAgent{ + AgentDefinition: AgentDefinition{ + Kind: AgentKindPrompt, + Name: "my-prompt-agent", + Description: &desc, + }, + Model: Model{ + Id: "gpt-4o", + Options: &ModelOptions{ + Temperature: &temp, + TopP: &topP, + }, + }, + Instructions: &instructions, + Tools: &[]any{ + FunctionTool{ + Tool: Tool{Name: "calc", Kind: ToolKindFunction}, + Parameters: PropertySchema{}, + }, + }, + } + + req, err := CreatePromptAgentAPIRequest(agent, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if req.Name != "my-prompt-agent" { + t.Errorf("Name = %q", req.Name) + } + if req.Description == nil || *req.Description != desc { + t.Errorf("Description mismatch") + } + + promptDef, ok := req.Definition.(agent_api.PromptAgentDefinition) + if !ok { + t.Fatalf("Definition should be PromptAgentDefinition, got %T", req.Definition) + } + if promptDef.Kind != agent_api.AgentKindPrompt { + t.Errorf("Kind = %q", promptDef.Kind) + } + if promptDef.Model != "gpt-4o" { + t.Errorf("Model = %q", promptDef.Model) + } + if promptDef.Instructions == nil || *promptDef.Instructions != instructions { + t.Errorf("Instructions mismatch") + } + if promptDef.Temperature == nil || *promptDef.Temperature != float32(0.7) { + t.Errorf("Temperature mismatch") + } + if promptDef.TopP == nil || *promptDef.TopP != float32(0.9) { + t.Errorf("TopP mismatch") + } + if len(promptDef.Tools) != 1 { + t.Fatalf("expected 1 tool, got %d", len(promptDef.Tools)) + } +} + +func TestCreatePromptAgentAPIRequest_MissingModelId(t *testing.T) { + t.Parallel() + agent := PromptAgent{ + AgentDefinition: AgentDefinition{ + Kind: AgentKindPrompt, + Name: "bad-agent", + }, + Model: Model{Id: ""}, + Tools: &[]any{}, + } + + _, err := CreatePromptAgentAPIRequest(agent, nil) + if err == nil { + t.Fatal("expected error for missing model.id") + } + if !strings.Contains(err.Error(), "model.id") { + t.Errorf("error should mention model.id, got: %s", err.Error()) + } +} + +func TestCreatePromptAgentAPIRequest_NoOptions(t *testing.T) { + t.Parallel() + agent := PromptAgent{ + AgentDefinition: AgentDefinition{ + Kind: AgentKindPrompt, + Name: "simple-agent", + }, + Model: Model{Id: "gpt-4o-mini"}, + Tools: &[]any{}, + } + + req, err := CreatePromptAgentAPIRequest(agent, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + promptDef := req.Definition.(agent_api.PromptAgentDefinition) + if promptDef.Temperature != nil { + t.Errorf("expected nil Temperature, got %v", *promptDef.Temperature) + } + if promptDef.TopP != nil { + t.Errorf("expected nil TopP, got %v", *promptDef.TopP) + } + if promptDef.Instructions != nil { + t.Errorf("expected nil Instructions") + } +} + +func TestCreatePromptAgentAPIRequest_NilToolsSlice(t *testing.T) { + t.Parallel() + emptyTools := []any{} + agent := PromptAgent{ + AgentDefinition: AgentDefinition{ + Kind: AgentKindPrompt, + Name: "no-tools", + }, + Model: Model{Id: "gpt-4o"}, + Tools: &emptyTools, + } + + req, err := CreatePromptAgentAPIRequest(agent, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + promptDef := req.Definition.(agent_api.PromptAgentDefinition) + if promptDef.Tools != nil { + t.Errorf("expected nil Tools for empty input, got %v", promptDef.Tools) + } +} + +// --------------------------------------------------------------------------- +// CreateHostedAgentAPIRequest +// --------------------------------------------------------------------------- + +func TestCreateHostedAgentAPIRequest_FullConfig(t *testing.T) { + t.Parallel() + desc := "hosted agent" + agent := ContainerAgent{ + AgentDefinition: AgentDefinition{ + Kind: AgentKindHosted, + Name: "my-hosted-agent", + Description: &desc, + }, + Protocols: []ProtocolVersionRecord{ + {Protocol: "responses", Version: "2.0.0"}, + {Protocol: "invocations", Version: "1.0.0"}, + }, + } + + buildConfig := &AgentBuildConfig{ + ImageURL: "myregistry.azurecr.io/agent:v1", + CPU: "4", + Memory: "8Gi", + EnvironmentVariables: map[string]string{"ENV1": "val1"}, + } + + req, err := CreateHostedAgentAPIRequest(agent, buildConfig) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if req.Name != "my-hosted-agent" { + t.Errorf("Name = %q", req.Name) + } + if req.Description == nil || *req.Description != desc { + t.Errorf("Description mismatch") + } + + imgDef, ok := req.Definition.(agent_api.ImageBasedHostedAgentDefinition) + if !ok { + t.Fatalf("expected ImageBasedHostedAgentDefinition, got %T", req.Definition) + } + if imgDef.Kind != agent_api.AgentKindHosted { + t.Errorf("Kind = %q", imgDef.Kind) + } + if imgDef.Image != "myregistry.azurecr.io/agent:v1" { + t.Errorf("Image = %q", imgDef.Image) + } + if imgDef.CPU != "4" { + t.Errorf("CPU = %q", imgDef.CPU) + } + if imgDef.Memory != "8Gi" { + t.Errorf("Memory = %q", imgDef.Memory) + } + if imgDef.EnvironmentVariables["ENV1"] != "val1" { + t.Error("env var missing") + } + + // Verify protocol versions + if len(imgDef.ContainerProtocolVersions) != 2 { + t.Fatalf("expected 2 protocol versions, got %d", len(imgDef.ContainerProtocolVersions)) + } + if imgDef.ContainerProtocolVersions[0].Protocol != "responses" { + t.Errorf("protocol[0] = %q", imgDef.ContainerProtocolVersions[0].Protocol) + } + if imgDef.ContainerProtocolVersions[0].Version != "2.0.0" { + t.Errorf("version[0] = %q", imgDef.ContainerProtocolVersions[0].Version) + } +} + +func TestCreateHostedAgentAPIRequest_DefaultProtocols(t *testing.T) { + t.Parallel() + agent := ContainerAgent{ + AgentDefinition: AgentDefinition{ + Kind: AgentKindHosted, + Name: "default-protocols", + }, + } + buildConfig := &AgentBuildConfig{ImageURL: "img:latest"} + + req, err := CreateHostedAgentAPIRequest(agent, buildConfig) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + imgDef := req.Definition.(agent_api.ImageBasedHostedAgentDefinition) + if len(imgDef.ContainerProtocolVersions) != 1 { + t.Fatalf("expected 1 default protocol, got %d", len(imgDef.ContainerProtocolVersions)) + } + if imgDef.ContainerProtocolVersions[0].Protocol != agent_api.AgentProtocolResponses { + t.Errorf("default protocol = %q", imgDef.ContainerProtocolVersions[0].Protocol) + } + if imgDef.ContainerProtocolVersions[0].Version != "1.0.0" { + t.Errorf("default version = %q", imgDef.ContainerProtocolVersions[0].Version) + } +} + +func TestCreateHostedAgentAPIRequest_DefaultCPUAndMemory(t *testing.T) { + t.Parallel() + agent := ContainerAgent{ + AgentDefinition: AgentDefinition{ + Kind: AgentKindHosted, + Name: "defaults", + }, + } + buildConfig := &AgentBuildConfig{ImageURL: "img:latest"} + + req, err := CreateHostedAgentAPIRequest(agent, buildConfig) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + imgDef := req.Definition.(agent_api.ImageBasedHostedAgentDefinition) + if imgDef.CPU != "1" { + t.Errorf("default CPU = %q, want %q", imgDef.CPU, "1") + } + if imgDef.Memory != "2Gi" { + t.Errorf("default Memory = %q, want %q", imgDef.Memory, "2Gi") + } +} + +func TestCreateHostedAgentAPIRequest_MissingImageURL(t *testing.T) { + t.Parallel() + agent := ContainerAgent{ + AgentDefinition: AgentDefinition{ + Kind: AgentKindHosted, + Name: "no-image", + }, + } + + _, err := CreateHostedAgentAPIRequest(agent, &AgentBuildConfig{}) + if err == nil { + t.Fatal("expected error for missing image URL") + } + if !strings.Contains(err.Error(), "image URL") { + t.Errorf("error should mention image URL, got: %s", err.Error()) + } +} + +func TestCreateHostedAgentAPIRequest_NilBuildConfig(t *testing.T) { + t.Parallel() + agent := ContainerAgent{ + AgentDefinition: AgentDefinition{ + Kind: AgentKindHosted, + Name: "nil-config", + }, + } + + _, err := CreateHostedAgentAPIRequest(agent, nil) + if err == nil { + t.Fatal("expected error for nil build config (no image)") + } +} + +// --------------------------------------------------------------------------- +// CreateAgentAPIRequestFromDefinition (routing) +// --------------------------------------------------------------------------- + +func TestCreateAgentAPIRequestFromDefinition_PromptAgent(t *testing.T) { + t.Parallel() + agent := PromptAgent{ + AgentDefinition: AgentDefinition{ + Kind: AgentKindPrompt, + Name: "prompt-routed", + }, + Model: Model{Id: "gpt-4o"}, + Tools: &[]any{}, + } + + req, err := CreateAgentAPIRequestFromDefinition(agent) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if req.Name != "prompt-routed" { + t.Errorf("Name = %q", req.Name) + } + + _, ok := req.Definition.(agent_api.PromptAgentDefinition) + if !ok { + t.Fatalf("expected PromptAgentDefinition, got %T", req.Definition) + } +} + +func TestCreateAgentAPIRequestFromDefinition_HostedAgent(t *testing.T) { + t.Parallel() + agent := ContainerAgent{ + AgentDefinition: AgentDefinition{ + Kind: AgentKindHosted, + Name: "hosted-routed", + }, + } + + req, err := CreateAgentAPIRequestFromDefinition(agent, WithImageURL("img:latest")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if req.Name != "hosted-routed" { + t.Errorf("Name = %q", req.Name) + } + + _, ok := req.Definition.(agent_api.ImageBasedHostedAgentDefinition) + if !ok { + t.Fatalf("expected ImageBasedHostedAgentDefinition, got %T", req.Definition) + } +} + +func TestCreateAgentAPIRequestFromDefinition_UnsupportedKind(t *testing.T) { + t.Parallel() + agent := AgentDefinition{ + Kind: "unknown", + Name: "bad-kind", + } + + _, err := CreateAgentAPIRequestFromDefinition(agent) + if err == nil { + t.Fatal("expected error for unsupported kind") + } + if !strings.Contains(err.Error(), "unsupported agent kind") { + t.Errorf("error should mention unsupported agent kind, got: %s", err.Error()) + } +} + +func TestCreateAgentAPIRequestFromDefinition_HostedWithBuildOptions(t *testing.T) { + t.Parallel() + agent := ContainerAgent{ + AgentDefinition: AgentDefinition{ + Kind: AgentKindHosted, + Name: "hosted-opts", + }, + Protocols: []ProtocolVersionRecord{ + {Protocol: "responses", Version: "1.0.0"}, + }, + } + + req, err := CreateAgentAPIRequestFromDefinition(agent, + WithImageURL("myimg:v2"), + WithCPU("2"), + WithMemory("4Gi"), + WithEnvironmentVariable("FOO", "bar"), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + imgDef := req.Definition.(agent_api.ImageBasedHostedAgentDefinition) + if imgDef.Image != "myimg:v2" { + t.Errorf("Image = %q", imgDef.Image) + } + if imgDef.CPU != "2" { + t.Errorf("CPU = %q", imgDef.CPU) + } + if imgDef.Memory != "4Gi" { + t.Errorf("Memory = %q", imgDef.Memory) + } + if imgDef.EnvironmentVariables["FOO"] != "bar" { + t.Errorf("env var FOO missing or wrong") + } +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/hosted-agent.yaml b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/hosted-agent.yaml new file mode 100644 index 00000000000..a6a5314c78b --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/hosted-agent.yaml @@ -0,0 +1,9 @@ +template: + kind: hosted + name: hosted-test-agent + description: A hosted container agent for testing + protocols: + - protocol: responses + version: "1.0.0" + - protocol: invocations + version: "1.0.0" diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/invalid-empty-template.yaml b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/invalid-empty-template.yaml new file mode 100644 index 00000000000..2676f0d4033 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/invalid-empty-template.yaml @@ -0,0 +1 @@ +template: {} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/invalid-no-kind.yaml b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/invalid-no-kind.yaml new file mode 100644 index 00000000000..5ba3da54847 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/invalid-no-kind.yaml @@ -0,0 +1,4 @@ +template: + name: no-kind-agent + model: + id: gpt-4o diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/invalid-no-model.yaml b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/invalid-no-model.yaml new file mode 100644 index 00000000000..1f5f9536a65 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/invalid-no-model.yaml @@ -0,0 +1,4 @@ +template: + kind: prompt + name: no-model-agent + instructions: Some instructions diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/mcp-tools-agent.yaml b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/mcp-tools-agent.yaml new file mode 100644 index 00000000000..9e2a70ed8fa --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/mcp-tools-agent.yaml @@ -0,0 +1,18 @@ +template: + kind: prompt + name: mcp-tools-agent + description: Agent with MCP tool connections + model: + id: gpt-4o + tools: + - kind: mcp + name: github-mcp + connection: + kind: foundry + endpoint: https://api.githubcopilot.com/mcp/ + name: github-mcp-conn + url: https://api.githubcopilot.com/mcp/ + - kind: code_interpreter + name: code-runner + instructions: | + You have access to GitHub via MCP and can run code. diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/prompt-agent-full.yaml b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/prompt-agent-full.yaml new file mode 100644 index 00000000000..b5643ea337a --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/prompt-agent-full.yaml @@ -0,0 +1,39 @@ +template: + kind: prompt + name: full-prompt-agent + description: A fully configured prompt agent for testing + metadata: + authors: + - testauthor + tags: + - testing + - full + model: + id: gpt-4o + publisher: azure + options: + temperature: 0.8 + maxTokens: 4000 + topP: 0.95 + instructions: | + You are a helpful testing assistant. + Always respond in a structured format. + tools: + - kind: web_search + name: web-search + - kind: function + name: get_weather + description: Get weather for a location + parameters: + properties: + - name: location + kind: string + description: The city name + required: true + - name: unit + kind: string + description: Temperature unit + enumValues: + - celsius + - fahrenheit + default: celsius diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/prompt-agent-minimal.yaml b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/prompt-agent-minimal.yaml new file mode 100644 index 00000000000..7f598558883 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/prompt-agent-minimal.yaml @@ -0,0 +1,5 @@ +template: + kind: prompt + name: minimal-agent + model: + id: gpt-4o-mini diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata_test.go new file mode 100644 index 00000000000..0482af53abf --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata_test.go @@ -0,0 +1,286 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package agent_yaml + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "go.yaml.in/yaml/v3" +) + +// TestFixtures_ValidYAML verifies that valid YAML fixtures parse successfully +// and produce the expected agent kind and name via ExtractAgentDefinition. +func TestFixtures_ValidYAML(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + file string + wantKind AgentKind + wantName string + wantErrSubst string // if non-empty, expect this error instead of success + }{ + { + name: "hosted agent", + file: filepath.Join("testdata", "hosted-agent.yaml"), + wantKind: AgentKindHosted, + wantName: "hosted-test-agent", + }, + { + // Prompt agents are not currently supported by ExtractAgentDefinition. + // This test documents the current expected behavior. + name: "prompt agent minimal", + file: filepath.Join("testdata", "prompt-agent-minimal.yaml"), + wantErrSubst: "prompt agents not currently supported", + }, + { + name: "prompt agent full", + file: filepath.Join("testdata", "prompt-agent-full.yaml"), + wantErrSubst: "prompt agents not currently supported", + }, + { + name: "mcp tools agent", + file: filepath.Join("testdata", "mcp-tools-agent.yaml"), + wantErrSubst: "prompt agents not currently supported", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + data, err := os.ReadFile(tc.file) + if err != nil { + t.Fatalf("failed to read fixture %s: %v", tc.file, err) + } + + agent, err := ExtractAgentDefinition(data) + + if tc.wantErrSubst != "" { + if err == nil { + t.Fatalf("expected error containing %q, got nil", tc.wantErrSubst) + } + if !strings.Contains(err.Error(), tc.wantErrSubst) { + t.Fatalf("expected error containing %q, got %q", tc.wantErrSubst, err.Error()) + } + return + } + + if err != nil { + t.Fatalf("ExtractAgentDefinition failed: %v", err) + } + + containerAgent, ok := agent.(ContainerAgent) + if !ok { + t.Fatalf("expected ContainerAgent, got %T", agent) + } + + if containerAgent.Kind != tc.wantKind { + t.Errorf("kind: got %q, want %q", containerAgent.Kind, tc.wantKind) + } + if containerAgent.Name != tc.wantName { + t.Errorf("name: got %q, want %q", containerAgent.Name, tc.wantName) + } + }) + } +} + +// TestFixtures_ValidatePromptAgents uses ValidateAgentDefinition to confirm +// that prompt agent fixtures have a structurally valid YAML schema, even though +// ExtractAgentDefinition does not yet support prompt agents. +func TestFixtures_ValidatePromptAgents(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + file string + }{ + {name: "prompt agent minimal", file: filepath.Join("testdata", "prompt-agent-minimal.yaml")}, + {name: "prompt agent full", file: filepath.Join("testdata", "prompt-agent-full.yaml")}, + {name: "mcp tools agent", file: filepath.Join("testdata", "mcp-tools-agent.yaml")}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + data, err := os.ReadFile(tc.file) + if err != nil { + t.Fatalf("failed to read fixture %s: %v", tc.file, err) + } + + // Extract the template section to pass to ValidateAgentDefinition, + // which operates on template bytes rather than the full manifest. + templateBytes, err := extractTemplateBytes(data) + if err != nil { + t.Fatalf("failed to extract template bytes: %v", err) + } + + if err := ValidateAgentDefinition(templateBytes); err != nil { + t.Fatalf("ValidateAgentDefinition failed for valid fixture: %v", err) + } + }) + } +} + +// TestFixtures_InvalidYAML verifies that invalid YAML fixtures return appropriate errors. +func TestFixtures_InvalidYAML(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + file string + wantErrSubst string + }{ + { + name: "missing kind field", + file: filepath.Join("testdata", "invalid-no-kind.yaml"), + wantErrSubst: "template.kind must be one of", + }, + { + name: "prompt agent missing model", + file: filepath.Join("testdata", "invalid-no-model.yaml"), + wantErrSubst: "template.model.id is required", + }, + { + name: "empty template", + file: filepath.Join("testdata", "invalid-empty-template.yaml"), + wantErrSubst: "template field is empty", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + data, err := os.ReadFile(tc.file) + if err != nil { + t.Fatalf("failed to read fixture %s: %v", tc.file, err) + } + + // For "empty template", ExtractAgentDefinition catches the error. + // For schema validation errors, ValidateAgentDefinition is used on the template bytes. + _, extractErr := ExtractAgentDefinition(data) + + if extractErr != nil && strings.Contains(extractErr.Error(), tc.wantErrSubst) { + return // error caught at extraction level + } + + // Try validation-level check for schema errors (no-kind, no-model). + templateBytes, err := extractTemplateBytes(data) + if err != nil { + // If we can't even extract template bytes but got an extraction error, that's fine. + if extractErr != nil { + t.Logf("ExtractAgentDefinition error: %v", extractErr) + return + } + t.Fatalf("failed to extract template bytes and no extraction error: %v", err) + } + + validateErr := ValidateAgentDefinition(templateBytes) + if validateErr == nil { + t.Fatalf("expected validation error containing %q, got nil (extractErr=%v)", + tc.wantErrSubst, extractErr) + } + if !strings.Contains(validateErr.Error(), tc.wantErrSubst) { + t.Fatalf("expected error containing %q, got %q", tc.wantErrSubst, validateErr.Error()) + } + }) + } +} + +// TestFixtures_SampleAgents is a regression test that ensures the sample agent +// YAML files in tests/samples/ continue to parse correctly. +func TestFixtures_SampleAgents(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + file string + wantName string + }{ + { + name: "declarativeNoTools sample", + file: filepath.Join("..", "..", "..", "..", "tests", "samples", "declarativeNoTools", "agent.yaml"), + wantName: "Learn French Agent", + }, + { + name: "githubMcpAgent sample", + file: filepath.Join("..", "..", "..", "..", "tests", "samples", "githubMcpAgent", "agent.yaml"), + wantName: "github-agent", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + data, err := os.ReadFile(tc.file) + if err != nil { + t.Fatalf("failed to read sample %s: %v", tc.file, err) + } + + // Both samples are prompt agents, so ExtractAgentDefinition returns + // "prompt agents not currently supported". Validate structure instead. + _, extractErr := ExtractAgentDefinition(data) + if extractErr == nil { + t.Fatal("expected error for prompt agent sample, got nil") + } + if !strings.Contains(extractErr.Error(), "prompt agents not currently supported") { + t.Fatalf("unexpected error: %v", extractErr) + } + + // Validate that the YAML structure is well-formed by unmarshaling + // the template section into the typed structs. + templateBytes, err := extractTemplateBytes(data) + if err != nil { + t.Fatalf("failed to extract template bytes: %v", err) + } + + var agentDef AgentDefinition + if err := yaml.Unmarshal(templateBytes, &agentDef); err != nil { + t.Fatalf("failed to unmarshal AgentDefinition: %v", err) + } + if agentDef.Name != tc.wantName { + t.Errorf("name: got %q, want %q", agentDef.Name, tc.wantName) + } + if agentDef.Kind != AgentKindPrompt { + t.Errorf("kind: got %q, want %q", agentDef.Kind, AgentKindPrompt) + } + + // Also confirm the model is present for these prompt agents. + var promptAgent PromptAgent + if err := yaml.Unmarshal(templateBytes, &promptAgent); err != nil { + t.Fatalf("failed to unmarshal PromptAgent: %v", err) + } + if promptAgent.Model.Id == "" { + t.Error("expected non-empty model.id in sample agent") + } + }) + } +} + +// extractTemplateBytes reads YAML content with a top-level "template" field +// and returns the marshaled bytes of just the template section. +func extractTemplateBytes(manifestYaml []byte) ([]byte, error) { + var generic map[string]any + if err := yaml.Unmarshal(manifestYaml, &generic); err != nil { + return nil, err + } + + templateVal, ok := generic["template"] + if !ok || templateVal == nil { + return nil, os.ErrNotExist + } + + templateMap, ok := templateVal.(map[string]any) + if !ok { + return nil, os.ErrInvalid + } + + return yaml.Marshal(templateMap) +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/registry_api/helpers_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/registry_api/helpers_test.go new file mode 100644 index 00000000000..000e6083839 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/registry_api/helpers_test.go @@ -0,0 +1,863 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package registry_api + +import ( + "strings" + "testing" + + "azureaiagent/internal/pkg/agents/agent_api" + "azureaiagent/internal/pkg/agents/agent_yaml" +) + +// ptr is a generic helper that returns a pointer to the given value. +// --------------------------------------------------------------------------- +// ConvertToolToYaml +// --------------------------------------------------------------------------- + +func TestConvertToolToYaml(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input any + wantErr bool + errSubstr string + validate func(t *testing.T, got any) + }{ + { + name: "nil tool returns error", + input: nil, + wantErr: true, + errSubstr: "tool cannot be nil", + }, + { + name: "unsupported type returns error", + input: "not-a-tool", + wantErr: true, + errSubstr: "unsupported tool type", + }, + { + name: "FunctionTool", + input: agent_api.FunctionTool{ + Tool: agent_api.Tool{Type: "function"}, + Name: "my_func", + Description: new("a helper function"), + Parameters: nil, + Strict: new(true), + }, + validate: func(t *testing.T, got any) { + ft, ok := got.(agent_yaml.FunctionTool) + if !ok { + t.Fatalf("expected agent_yaml.FunctionTool, got %T", got) + } + if ft.Tool.Kind != agent_yaml.ToolKindFunction { + t.Errorf("Kind = %q, want %q", ft.Tool.Kind, agent_yaml.ToolKindFunction) + } + if ft.Tool.Name != "my_func" { + t.Errorf("Name = %q, want %q", ft.Tool.Name, "my_func") + } + if ft.Tool.Description == nil || *ft.Tool.Description != "a helper function" { + t.Errorf("Description mismatch") + } + if ft.Strict == nil || *ft.Strict != true { + t.Errorf("Strict = %v, want true", ft.Strict) + } + }, + }, + { + name: "WebSearchPreviewTool with options", + input: agent_api.WebSearchPreviewTool{ + Tool: agent_api.Tool{Type: "web_search_preview"}, + UserLocation: &agent_api.Location{Type: "approximate"}, + SearchContextSize: new("medium"), + }, + validate: func(t *testing.T, got any) { + ws, ok := got.(agent_yaml.WebSearchTool) + if !ok { + t.Fatalf("expected agent_yaml.WebSearchTool, got %T", got) + } + if ws.Tool.Kind != agent_yaml.ToolKindWebSearch { + t.Errorf("Kind = %q, want %q", ws.Tool.Kind, agent_yaml.ToolKindWebSearch) + } + if ws.Tool.Name != "web_search_preview" { + t.Errorf("Name = %q, want %q", ws.Tool.Name, "web_search_preview") + } + if ws.Options == nil { + t.Fatal("Options is nil") + } + if _, exists := ws.Options["userLocation"]; !exists { + t.Error("expected userLocation in Options") + } + if ws.Options["searchContextSize"] != "medium" { + t.Errorf("searchContextSize = %v, want %q", ws.Options["searchContextSize"], "medium") + } + }, + }, + { + name: "WebSearchPreviewTool without options", + input: agent_api.WebSearchPreviewTool{ + Tool: agent_api.Tool{Type: "web_search_preview"}, + }, + validate: func(t *testing.T, got any) { + ws, ok := got.(agent_yaml.WebSearchTool) + if !ok { + t.Fatalf("expected agent_yaml.WebSearchTool, got %T", got) + } + // Options map is always created but should be empty + if len(ws.Options) != 0 { + t.Errorf("expected empty Options, got %v", ws.Options) + } + }, + }, + { + name: "BingGroundingAgentTool", + input: agent_api.BingGroundingAgentTool{ + Tool: agent_api.Tool{Type: "bing_grounding"}, + BingGrounding: agent_api.BingGroundingSearchToolParameters{ + ProjectConnections: agent_api.ToolProjectConnectionList{ + ProjectConnections: []agent_api.ToolProjectConnection{ + {ID: "conn-1"}, + }, + }, + }, + }, + validate: func(t *testing.T, got any) { + bg, ok := got.(agent_yaml.BingGroundingTool) + if !ok { + t.Fatalf("expected agent_yaml.BingGroundingTool, got %T", got) + } + if bg.Tool.Kind != agent_yaml.ToolKindBingGrounding { + t.Errorf("Kind = %q, want %q", bg.Tool.Kind, agent_yaml.ToolKindBingGrounding) + } + if bg.Tool.Name != "bing_grounding" { + t.Errorf("Name = %q, want %q", bg.Tool.Name, "bing_grounding") + } + if bg.Options == nil { + t.Fatal("Options is nil") + } + if _, exists := bg.Options["bingGrounding"]; !exists { + t.Error("expected bingGrounding in Options") + } + }, + }, + { + name: "FileSearchTool with ranking options", + input: agent_api.FileSearchTool{ + Tool: agent_api.Tool{Type: "file_search"}, + VectorStoreIds: []string{"vs-1", "vs-2"}, + MaxNumResults: new(int32(10)), + RankingOptions: &agent_api.RankingOptions{ + Ranker: new("auto"), + ScoreThreshold: new(float32(0.5)), + }, + }, + validate: func(t *testing.T, got any) { + fs, ok := got.(agent_yaml.FileSearchTool) + if !ok { + t.Fatalf("expected agent_yaml.FileSearchTool, got %T", got) + } + if fs.Tool.Kind != agent_yaml.ToolKindFileSearch { + t.Errorf("Kind = %q, want %q", fs.Tool.Kind, agent_yaml.ToolKindFileSearch) + } + if len(fs.VectorStoreIds) != 2 || fs.VectorStoreIds[0] != "vs-1" { + t.Errorf("VectorStoreIds = %v, want [vs-1 vs-2]", fs.VectorStoreIds) + } + if fs.MaximumResultCount == nil || *fs.MaximumResultCount != 10 { + t.Errorf("MaximumResultCount = %v, want 10", fs.MaximumResultCount) + } + if fs.Ranker == nil || *fs.Ranker != "auto" { + t.Errorf("Ranker = %v, want auto", fs.Ranker) + } + if fs.ScoreThreshold == nil || *fs.ScoreThreshold != float64(float32(0.5)) { + t.Errorf("ScoreThreshold = %v, want 0.5", fs.ScoreThreshold) + } + }, + }, + { + name: "FileSearchTool without ranking options", + input: agent_api.FileSearchTool{ + Tool: agent_api.Tool{Type: "file_search"}, + VectorStoreIds: []string{"vs-1"}, + }, + validate: func(t *testing.T, got any) { + fs, ok := got.(agent_yaml.FileSearchTool) + if !ok { + t.Fatalf("expected agent_yaml.FileSearchTool, got %T", got) + } + if fs.Ranker != nil { + t.Errorf("Ranker = %v, want nil", fs.Ranker) + } + if fs.ScoreThreshold != nil { + t.Errorf("ScoreThreshold = %v, want nil", fs.ScoreThreshold) + } + if fs.MaximumResultCount != nil { + t.Errorf("MaximumResultCount = %v, want nil", fs.MaximumResultCount) + } + }, + }, + { + name: "MCPTool with all fields", + input: agent_api.MCPTool{ + Tool: agent_api.Tool{Type: "mcp"}, + ServerLabel: "my-server", + ServerURL: "https://example.com", + Headers: map[string]string{"x-key": "val"}, + ProjectConnectionID: new("conn-1"), + }, + validate: func(t *testing.T, got any) { + mcp, ok := got.(agent_yaml.McpTool) + if !ok { + t.Fatalf("expected agent_yaml.McpTool, got %T", got) + } + if mcp.Tool.Kind != agent_yaml.ToolKindMcp { + t.Errorf("Kind = %q, want %q", mcp.Tool.Kind, agent_yaml.ToolKindMcp) + } + if mcp.ServerName != "my-server" { + t.Errorf("ServerName = %q, want %q", mcp.ServerName, "my-server") + } + if mcp.Options["serverUrl"] != "https://example.com" { + t.Errorf("serverUrl = %v, want %q", mcp.Options["serverUrl"], "https://example.com") + } + if mcp.Options["projectConnectionId"] != "conn-1" { + t.Errorf("projectConnectionId = %v, want %q", mcp.Options["projectConnectionId"], "conn-1") + } + headers, ok := mcp.Options["headers"].(map[string]string) + if !ok { + t.Fatalf("expected headers map[string]string, got %T", mcp.Options["headers"]) + } + if headers["x-key"] != "val" { + t.Errorf("header x-key = %q, want %q", headers["x-key"], "val") + } + }, + }, + { + name: "MCPTool minimal", + input: agent_api.MCPTool{ + Tool: agent_api.Tool{Type: "mcp"}, + ServerLabel: "minimal-server", + }, + validate: func(t *testing.T, got any) { + mcp, ok := got.(agent_yaml.McpTool) + if !ok { + t.Fatalf("expected agent_yaml.McpTool, got %T", got) + } + if mcp.ServerName != "minimal-server" { + t.Errorf("ServerName = %q, want %q", mcp.ServerName, "minimal-server") + } + // serverUrl should not appear when empty string + if _, exists := mcp.Options["serverUrl"]; exists { + t.Error("expected serverUrl to be absent for empty ServerURL") + } + if _, exists := mcp.Options["headers"]; exists { + t.Error("expected headers to be absent when nil") + } + if _, exists := mcp.Options["projectConnectionId"]; exists { + t.Error("expected projectConnectionId to be absent when nil") + } + }, + }, + { + name: "OpenApiAgentTool", + input: agent_api.OpenApiAgentTool{ + Tool: agent_api.Tool{Type: "openapi"}, + OpenAPI: agent_api.OpenApiFunctionDefinition{ + Name: "weather-api", + Description: new("Weather lookup"), + }, + }, + validate: func(t *testing.T, got any) { + oa, ok := got.(agent_yaml.OpenApiTool) + if !ok { + t.Fatalf("expected agent_yaml.OpenApiTool, got %T", got) + } + if oa.Tool.Kind != agent_yaml.ToolKindOpenApi { + t.Errorf("Kind = %q, want %q", oa.Tool.Kind, agent_yaml.ToolKindOpenApi) + } + if oa.Tool.Name != "openapi" { + t.Errorf("Name = %q, want %q", oa.Tool.Name, "openapi") + } + if _, exists := oa.Options["openapi"]; !exists { + t.Error("expected openapi in Options") + } + }, + }, + { + name: "CodeInterpreterTool with container", + input: agent_api.CodeInterpreterTool{ + Tool: agent_api.Tool{Type: "code_interpreter"}, + Container: "container-id-123", + }, + validate: func(t *testing.T, got any) { + ci, ok := got.(agent_yaml.CodeInterpreterTool) + if !ok { + t.Fatalf("expected agent_yaml.CodeInterpreterTool, got %T", got) + } + if ci.Tool.Kind != agent_yaml.ToolKindCodeInterpreter { + t.Errorf("Kind = %q, want %q", ci.Tool.Kind, agent_yaml.ToolKindCodeInterpreter) + } + if ci.Tool.Name != "code_interpreter" { + t.Errorf("Name = %q, want %q", ci.Tool.Name, "code_interpreter") + } + if ci.Options["container"] != "container-id-123" { + t.Errorf("container = %v, want %q", ci.Options["container"], "container-id-123") + } + }, + }, + { + name: "CodeInterpreterTool without container", + input: agent_api.CodeInterpreterTool{ + Tool: agent_api.Tool{Type: "code_interpreter"}, + Container: nil, + }, + validate: func(t *testing.T, got any) { + ci, ok := got.(agent_yaml.CodeInterpreterTool) + if !ok { + t.Fatalf("expected agent_yaml.CodeInterpreterTool, got %T", got) + } + if _, exists := ci.Options["container"]; exists { + t.Error("expected container to be absent when nil") + } + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got, err := ConvertToolToYaml(tc.input) + if tc.wantErr { + if err == nil { + t.Fatal("expected error, got nil") + } + if tc.errSubstr != "" && !strings.Contains(err.Error(), tc.errSubstr) { + t.Errorf("error %q does not contain %q", err, tc.errSubstr) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tc.validate != nil { + tc.validate(t, got) + } + }) + } +} + +// --------------------------------------------------------------------------- +// ConvertAgentDefinition +// --------------------------------------------------------------------------- + +func TestConvertAgentDefinition(t *testing.T) { + t.Parallel() + + t.Run("empty tools", func(t *testing.T) { + t.Parallel() + def := agent_api.PromptAgentDefinition{ + Model: "gpt-4o", + Instructions: new("Be helpful"), + } + + got, err := ConvertAgentDefinition(def) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got.AgentDefinition.Kind != agent_yaml.AgentKindPrompt { + t.Errorf("Kind = %q, want %q", got.AgentDefinition.Kind, agent_yaml.AgentKindPrompt) + } + if got.Model.Id != "gpt-4o" { + t.Errorf("Model.Id = %q, want %q", got.Model.Id, "gpt-4o") + } + if got.Instructions == nil || *got.Instructions != "Be helpful" { + t.Errorf("Instructions mismatch") + } + if got.Tools == nil { + t.Fatal("Tools should not be nil") + } + if len(*got.Tools) != 0 { + t.Errorf("expected 0 tools, got %d", len(*got.Tools)) + } + }) + + t.Run("with tools", func(t *testing.T) { + t.Parallel() + def := agent_api.PromptAgentDefinition{ + Model: "gpt-4o-mini", + Instructions: new("Do things"), + Tools: []any{ + agent_api.FunctionTool{ + Tool: agent_api.Tool{Type: "function"}, + Name: "fn1", + }, + agent_api.CodeInterpreterTool{ + Tool: agent_api.Tool{Type: "code_interpreter"}, + }, + }, + } + + got, err := ConvertAgentDefinition(def) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(*got.Tools) != 2 { + t.Fatalf("expected 2 tools, got %d", len(*got.Tools)) + } + // Verify first tool is FunctionTool + if _, ok := (*got.Tools)[0].(agent_yaml.FunctionTool); !ok { + t.Errorf("expected first tool to be FunctionTool, got %T", (*got.Tools)[0]) + } + // Verify second tool is CodeInterpreterTool + if _, ok := (*got.Tools)[1].(agent_yaml.CodeInterpreterTool); !ok { + t.Errorf("expected second tool to be CodeInterpreterTool, got %T", (*got.Tools)[1]) + } + }) + + t.Run("unsupported tool propagates error", func(t *testing.T) { + t.Parallel() + def := agent_api.PromptAgentDefinition{ + Model: "gpt-4o", + Tools: []any{"bad-tool"}, + } + _, err := ConvertAgentDefinition(def) + if err == nil { + t.Fatal("expected error for unsupported tool") + } + if !strings.Contains(err.Error(), "unsupported tool type") { + t.Errorf("error %q does not mention unsupported tool type", err) + } + }) + + t.Run("nil instructions", func(t *testing.T) { + t.Parallel() + def := agent_api.PromptAgentDefinition{ + Model: "gpt-4o", + } + got, err := ConvertAgentDefinition(def) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got.Instructions != nil { + t.Errorf("Instructions = %v, want nil", got.Instructions) + } + }) +} + +// --------------------------------------------------------------------------- +// ConvertParameters +// --------------------------------------------------------------------------- + +func TestConvertParameters(t *testing.T) { + t.Parallel() + + t.Run("nil parameters returns nil", func(t *testing.T) { + t.Parallel() + got, err := ConvertParameters(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != nil { + t.Errorf("expected nil, got %+v", got) + } + }) + + t.Run("empty parameters returns nil", func(t *testing.T) { + t.Parallel() + got, err := ConvertParameters(map[string]OpenApiParameter{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != nil { + t.Errorf("expected nil, got %+v", got) + } + }) + + t.Run("single parameter with schema and enum", func(t *testing.T) { + t.Parallel() + params := map[string]OpenApiParameter{ + "region": { + Description: "Azure region", + Required: true, + Schema: &OpenApiSchema{ + Type: "string", + Enum: []any{"eastus", "westus"}, + }, + }, + } + + got, err := ConvertParameters(params) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got == nil { + t.Fatal("expected non-nil PropertySchema") + } + if len(got.Properties) != 1 { + t.Fatalf("expected 1 property, got %d", len(got.Properties)) + } + p := got.Properties[0] + if p.Name != "region" { + t.Errorf("Name = %q, want %q", p.Name, "region") + } + if p.Kind != "string" { + t.Errorf("Kind = %q, want %q", p.Kind, "string") + } + if p.Description == nil || *p.Description != "Azure region" { + t.Errorf("Description mismatch") + } + if p.Required == nil || *p.Required != true { + t.Errorf("Required = %v, want true", p.Required) + } + if p.EnumValues == nil || len(*p.EnumValues) != 2 { + t.Fatalf("expected 2 enum values, got %v", p.EnumValues) + } + if (*p.EnumValues)[0] != "eastus" || (*p.EnumValues)[1] != "westus" { + t.Errorf("EnumValues = %v, want [eastus westus]", *p.EnumValues) + } + }) + + t.Run("parameter without schema defaults to string kind", func(t *testing.T) { + t.Parallel() + params := map[string]OpenApiParameter{ + "name": { + Description: "Agent name", + Required: false, + }, + } + + got, err := ConvertParameters(params) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got == nil { + t.Fatal("expected non-nil PropertySchema") + } + p := got.Properties[0] + if p.Kind != "string" { + t.Errorf("Kind = %q, want %q (default)", p.Kind, "string") + } + if p.EnumValues != nil { + t.Errorf("EnumValues = %v, want nil", p.EnumValues) + } + }) + + t.Run("parameter with example sets default", func(t *testing.T) { + t.Parallel() + params := map[string]OpenApiParameter{ + "timeout": { + Description: "Timeout in seconds", + Example: 30, + Schema: &OpenApiSchema{Type: "integer"}, + }, + } + + got, err := ConvertParameters(params) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + p := got.Properties[0] + if p.Default == nil { + t.Fatal("expected Default to be set from Example") + } + if *p.Default != 30 { + t.Errorf("Default = %v, want 30", *p.Default) + } + }) +} + +// --------------------------------------------------------------------------- +// MergeManifestIntoAgentDefinition +// --------------------------------------------------------------------------- + +func TestMergeManifestIntoAgentDefinition(t *testing.T) { + t.Parallel() + + t.Run("fills empty name from manifest", func(t *testing.T) { + t.Parallel() + manifest := &Manifest{ + Name: "manifest-agent", + DisplayName: "Manifest Agent", + Description: "A description", + } + agentDef := &agent_yaml.AgentDefinition{ + Kind: agent_yaml.AgentKindPrompt, + } + + result := MergeManifestIntoAgentDefinition(manifest, agentDef) + + if result.Name != "manifest-agent" { + t.Errorf("Name = %q, want %q", result.Name, "manifest-agent") + } + }) + + t.Run("does not overwrite existing name", func(t *testing.T) { + t.Parallel() + manifest := &Manifest{ + Name: "manifest-name", + } + agentDef := &agent_yaml.AgentDefinition{ + Kind: agent_yaml.AgentKindPrompt, + Name: "existing-name", + } + + result := MergeManifestIntoAgentDefinition(manifest, agentDef) + + if result.Name != "existing-name" { + t.Errorf("Name = %q, want %q", result.Name, "existing-name") + } + }) + + t.Run("does not modify original agent definition", func(t *testing.T) { + t.Parallel() + manifest := &Manifest{ + Name: "new-name", + } + agentDef := &agent_yaml.AgentDefinition{ + Kind: agent_yaml.AgentKindPrompt, + } + + _ = MergeManifestIntoAgentDefinition(manifest, agentDef) + + if agentDef.Name != "" { + t.Errorf("original AgentDefinition.Name was modified to %q", agentDef.Name) + } + }) + + t.Run("preserves kind when already set", func(t *testing.T) { + t.Parallel() + manifest := &Manifest{ + Name: "test", + } + agentDef := &agent_yaml.AgentDefinition{ + Kind: agent_yaml.AgentKindPrompt, + Name: "keep", + } + + result := MergeManifestIntoAgentDefinition(manifest, agentDef) + + if result.Kind != agent_yaml.AgentKindPrompt { + t.Errorf("Kind = %q, want %q", result.Kind, agent_yaml.AgentKindPrompt) + } + }) +} + +// --------------------------------------------------------------------------- +// injectParameterValues +// --------------------------------------------------------------------------- + +func TestInjectParameterValues(t *testing.T) { + t.Parallel() + + t.Run("replaces {{param}} style", func(t *testing.T) { + t.Parallel() + template := "Hello {{name}}, welcome to {{place}}!" + values := ParameterValues{ + "name": "Alice", + "place": "Wonderland", + } + + got, err := injectParameterValues(template, values) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + want := "Hello Alice, welcome to Wonderland!" + if string(got) != want { + t.Errorf("got %q, want %q", string(got), want) + } + }) + + t.Run("replaces {{ param }} style with spaces", func(t *testing.T) { + t.Parallel() + template := "Value is {{ apiKey }}" + values := ParameterValues{ + "apiKey": "secret-123", + } + + got, err := injectParameterValues(template, values) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + want := "Value is secret-123" + if string(got) != want { + t.Errorf("got %q, want %q", string(got), want) + } + }) + + t.Run("replaces both styles in same template", func(t *testing.T) { + t.Parallel() + template := "{{key1}} and {{ key1 }}" + values := ParameterValues{ + "key1": "replaced", + } + + got, err := injectParameterValues(template, values) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + want := "replaced and replaced" + if string(got) != want { + t.Errorf("got %q, want %q", string(got), want) + } + }) + + t.Run("no placeholders returns unchanged", func(t *testing.T) { + t.Parallel() + template := "no placeholders here" + values := ParameterValues{"key": "val"} + + got, err := injectParameterValues(template, values) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(got) != template { + t.Errorf("got %q, want %q", string(got), template) + } + }) + + t.Run("empty parameter values returns unchanged", func(t *testing.T) { + t.Parallel() + template := "Hello {{name}}" + values := ParameterValues{} + + got, err := injectParameterValues(template, values) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // Unresolved placeholders remain but no error + if string(got) != template { + t.Errorf("got %q, want %q", string(got), template) + } + }) + + t.Run("non-string value is converted via Sprintf", func(t *testing.T) { + t.Parallel() + template := "count={{count}}" + values := ParameterValues{"count": 42} + + got, err := injectParameterValues(template, values) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + want := "count=42" + if string(got) != want { + t.Errorf("got %q, want %q", string(got), want) + } + }) + + t.Run("empty template returns empty", func(t *testing.T) { + t.Parallel() + got, err := injectParameterValues("", ParameterValues{"k": "v"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(got) != "" { + t.Errorf("got %q, want empty", string(got)) + } + }) +} + +// --------------------------------------------------------------------------- +// convertFloat32ToFloat64 +// --------------------------------------------------------------------------- + +func TestConvertFloat32ToFloat64(t *testing.T) { + t.Parallel() + + t.Run("nil returns nil", func(t *testing.T) { + t.Parallel() + got := convertFloat32ToFloat64(nil) + if got != nil { + t.Errorf("expected nil, got %v", *got) + } + }) + + t.Run("converts value", func(t *testing.T) { + t.Parallel() + f32 := float32(0.75) + got := convertFloat32ToFloat64(&f32) + if got == nil { + t.Fatal("expected non-nil") + } + if *got != float64(f32) { + t.Errorf("got %v, want %v", *got, float64(f32)) + } + }) + + t.Run("zero value", func(t *testing.T) { + t.Parallel() + f32 := float32(0) + got := convertFloat32ToFloat64(&f32) + if got == nil { + t.Fatal("expected non-nil") + } + if *got != 0 { + t.Errorf("got %v, want 0", *got) + } + }) +} + +// --------------------------------------------------------------------------- +// convertInt32ToInt +// --------------------------------------------------------------------------- + +func TestConvertInt32ToInt(t *testing.T) { + t.Parallel() + + t.Run("nil returns nil", func(t *testing.T) { + t.Parallel() + got := convertInt32ToInt(nil) + if got != nil { + t.Errorf("expected nil, got %v", *got) + } + }) + + t.Run("converts value", func(t *testing.T) { + t.Parallel() + i32 := int32(42) + got := convertInt32ToInt(&i32) + if got == nil { + t.Fatal("expected non-nil") + } + if *got != 42 { + t.Errorf("got %d, want 42", *got) + } + }) + + t.Run("zero value", func(t *testing.T) { + t.Parallel() + i32 := int32(0) + got := convertInt32ToInt(&i32) + if got == nil { + t.Fatal("expected non-nil") + } + if *got != 0 { + t.Errorf("got %d, want 0", *got) + } + }) +} + +// --------------------------------------------------------------------------- +// convertToPropertySchema +// --------------------------------------------------------------------------- + +func TestConvertToPropertySchema(t *testing.T) { + t.Parallel() + + t.Run("nil input returns empty properties", func(t *testing.T) { + t.Parallel() + got := convertToPropertySchema(nil) + if len(got.Properties) != 0 { + t.Errorf("expected 0 properties, got %d", len(got.Properties)) + } + }) + + t.Run("non-nil input returns empty properties", func(t *testing.T) { + t.Parallel() + // Current implementation is a placeholder that always returns empty properties + got := convertToPropertySchema(map[string]any{"key": "value"}) + if len(got.Properties) != 0 { + t.Errorf("expected 0 properties, got %d", len(got.Properties)) + } + }) +} From 2209279e4206e57c115b073e748ac041a40c3b5e Mon Sep 17 00:00:00 2001 From: trangevi Date: Fri, 10 Apr 2026 14:51:27 -0700 Subject: [PATCH 12/14] Revert "Add unit tests and testdata for azure.ai.agents extension (#7634)" This reverts commit af85a8fd218444b2b0521080483ef4af42bbcd61. --- .../internal/cmd/agent_context_test.go | 23 - .../internal/cmd/init_copy_test.go | 119 -- .../pkg/agents/agent_api/models_test.go | 1036 -------------- .../pkg/agents/agent_yaml/map_test.go | 1199 ----------------- .../agent_yaml/testdata/hosted-agent.yaml | 9 - .../testdata/invalid-empty-template.yaml | 1 - .../agent_yaml/testdata/invalid-no-kind.yaml | 4 - .../agent_yaml/testdata/invalid-no-model.yaml | 4 - .../agent_yaml/testdata/mcp-tools-agent.yaml | 18 - .../testdata/prompt-agent-full.yaml | 39 - .../testdata/prompt-agent-minimal.yaml | 5 - .../pkg/agents/agent_yaml/testdata_test.go | 286 ---- .../pkg/agents/registry_api/helpers_test.go | 863 ------------ 13 files changed, 3606 deletions(-) delete mode 100644 cli/azd/extensions/azure.ai.agents/internal/cmd/agent_context_test.go delete mode 100644 cli/azd/extensions/azure.ai.agents/internal/cmd/init_copy_test.go delete mode 100644 cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_api/models_test.go delete mode 100644 cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/map_test.go delete mode 100644 cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/hosted-agent.yaml delete mode 100644 cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/invalid-empty-template.yaml delete mode 100644 cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/invalid-no-kind.yaml delete mode 100644 cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/invalid-no-model.yaml delete mode 100644 cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/mcp-tools-agent.yaml delete mode 100644 cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/prompt-agent-full.yaml delete mode 100644 cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/prompt-agent-minimal.yaml delete mode 100644 cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata_test.go delete mode 100644 cli/azd/extensions/azure.ai.agents/internal/pkg/agents/registry_api/helpers_test.go diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/agent_context_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/agent_context_test.go deleted file mode 100644 index 3df51697171..00000000000 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/agent_context_test.go +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package cmd - -import "testing" - -func TestBuildAgentEndpoint_Cases(t *testing.T) { - t.Parallel() - tests := []struct{ account, project, want string }{ - {"myaccount", "myproject", "https://myaccount.services.ai.azure.com/api/projects/myproject"}, - {"a", "b", "https://a.services.ai.azure.com/api/projects/b"}, - } - for _, tt := range tests { - t.Run(tt.account+"/"+tt.project, func(t *testing.T) { - t.Parallel() - got := buildAgentEndpoint(tt.account, tt.project) - if got != tt.want { - t.Errorf("got %q, want %q", got, tt.want) - } - }) - } -} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_copy_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_copy_test.go deleted file mode 100644 index c6240acacf5..00000000000 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_copy_test.go +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package cmd - -import ( - "os" - "path/filepath" - "strings" - "testing" -) - -func TestCopyDirectory(t *testing.T) { - t.Parallel() - - t.Run("happy_path", func(t *testing.T) { - t.Parallel() - src := t.TempDir() - - // Create a small tree: file.txt, sub/nested.txt - if err := os.WriteFile(filepath.Join(src, "file.txt"), []byte("hello"), 0644); err != nil { - t.Fatal(err) - } - subDir := filepath.Join(src, "sub") - if err := os.MkdirAll(subDir, 0755); err != nil { - t.Fatal(err) - } - if err := os.WriteFile(filepath.Join(subDir, "nested.txt"), []byte("world"), 0644); err != nil { - t.Fatal(err) - } - - dst := filepath.Join(t.TempDir(), "out") - if err := copyDirectory(src, dst); err != nil { - t.Fatal(err) - } - - // Verify top-level file - assertFileContents(t, filepath.Join(dst, "file.txt"), "hello") - // Verify nested file - assertFileContents(t, filepath.Join(dst, "sub", "nested.txt"), "world") - }) - - t.Run("same_path_noop", func(t *testing.T) { - t.Parallel() - dir := t.TempDir() - if err := copyDirectory(dir, dir); err != nil { - t.Fatalf("expected nil error for same path, got %v", err) - } - }) - - t.Run("subpath_error", func(t *testing.T) { - t.Parallel() - src := t.TempDir() - dst := filepath.Join(src, "child") - if err := os.MkdirAll(dst, 0755); err != nil { - t.Fatal(err) - } - - err := copyDirectory(src, dst) - if err == nil { - t.Fatal("expected error when dst is subpath of src") - } - if !strings.Contains(err.Error(), "refusing to copy") { - t.Errorf("unexpected error message: %v", err) - } - }) - - t.Run("missing_source_error", func(t *testing.T) { - t.Parallel() - src := filepath.Join(t.TempDir(), "nonexistent") - dst := t.TempDir() - - err := copyDirectory(src, dst) - if err == nil { - t.Fatal("expected error for missing source") - } - }) -} - -func TestCopyFile(t *testing.T) { - t.Parallel() - - t.Run("happy_path", func(t *testing.T) { - t.Parallel() - src := filepath.Join(t.TempDir(), "src.txt") - if err := os.WriteFile(src, []byte("data"), 0644); err != nil { - t.Fatal(err) - } - - dst := filepath.Join(t.TempDir(), "dst.txt") - if err := copyFile(src, dst); err != nil { - t.Fatal(err) - } - assertFileContents(t, dst, "data") - }) - - t.Run("missing_source_error", func(t *testing.T) { - t.Parallel() - src := filepath.Join(t.TempDir(), "nope.txt") - dst := filepath.Join(t.TempDir(), "dst.txt") - - if err := copyFile(src, dst); err == nil { - t.Fatal("expected error for missing source file") - } - }) -} - -// assertFileContents is a test helper that reads a file and compares its contents. -func assertFileContents(t *testing.T, path, want string) { - t.Helper() - data, err := os.ReadFile(path) - if err != nil { - t.Fatalf("reading %s: %v", path, err) - } - if got := string(data); got != want { - t.Errorf("file %s: got %q, want %q", path, got, want) - } -} - diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_api/models_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_api/models_test.go deleted file mode 100644 index 487d46305ae..00000000000 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_api/models_test.go +++ /dev/null @@ -1,1036 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package agent_api - -import ( - "encoding/json" - "strings" - "testing" -) - -func TestCreateAgentRequest_RoundTrip(t *testing.T) { - t.Parallel() - - original := CreateAgentRequest{ - Name: "test-agent", - CreateAgentVersionRequest: CreateAgentVersionRequest{ - Description: new("A test agent"), - Metadata: map[string]string{"env": "test"}, - Definition: PromptAgentDefinition{ - AgentDefinition: AgentDefinition{ - Kind: AgentKindPrompt, - RaiConfig: &RaiConfig{RaiPolicyName: "default"}, - }, - Model: "gpt-4o", - Instructions: new("You are helpful"), - Temperature: new(float32(0.7)), - TopP: new(float32(0.9)), - }, - }, - } - - data, err := json.Marshal(original) - if err != nil { - t.Fatalf("marshal: %v", err) - } - - // Verify JSON tag names - s := string(data) - for _, field := range []string{`"name"`, `"description"`, `"metadata"`, `"definition"`} { - if !strings.Contains(s, field) { - t.Errorf("expected JSON to contain %s, got: %s", field, s) - } - } - - var got CreateAgentRequest - if err := json.Unmarshal(data, &got); err != nil { - t.Fatalf("unmarshal: %v", err) - } - - if got.Name != original.Name { - t.Errorf("Name = %q, want %q", got.Name, original.Name) - } - if got.Description == nil || *got.Description != *original.Description { - t.Errorf("Description mismatch") - } - if got.Metadata["env"] != "test" { - t.Errorf("Metadata[env] = %q, want %q", got.Metadata["env"], "test") - } -} - -func TestAgentObject_RoundTrip(t *testing.T) { - t.Parallel() - - original := AgentObject{ - Object: "agent", - ID: "agent-123", - Name: "my-agent", - Versions: struct { - Latest AgentVersionObject `json:"latest"` - }{ - Latest: AgentVersionObject{ - Object: "agent_version", - ID: "ver-1", - Name: "my-agent", - Version: "1", - Description: new("version one"), - Metadata: map[string]string{"release": "stable"}, - CreatedAt: 1700000000, - }, - }, - } - - data, err := json.Marshal(original) - if err != nil { - t.Fatalf("marshal: %v", err) - } - - s := string(data) - for _, field := range []string{`"object"`, `"id"`, `"name"`, `"versions"`, `"latest"`} { - if !strings.Contains(s, field) { - t.Errorf("expected JSON to contain %s", field) - } - } - - var got AgentObject - if err := json.Unmarshal(data, &got); err != nil { - t.Fatalf("unmarshal: %v", err) - } - - if got.ID != original.ID { - t.Errorf("ID = %q, want %q", got.ID, original.ID) - } - if got.Versions.Latest.Version != "1" { - t.Errorf("Latest.Version = %q, want %q", got.Versions.Latest.Version, "1") - } - if got.Versions.Latest.CreatedAt != 1700000000 { - t.Errorf("Latest.CreatedAt = %d, want %d", got.Versions.Latest.CreatedAt, int64(1700000000)) - } -} - -func TestAgentContainerObject_RoundTrip(t *testing.T) { - t.Parallel() - - original := AgentContainerObject{ - Object: "container", - ID: "ctr-1", - Status: AgentContainerStatusRunning, - MaxReplicas: new(int32(3)), - MinReplicas: new(int32(1)), - ErrorMessage: new("partial failure"), - CreatedAt: "2024-01-01T00:00:00Z", - UpdatedAt: "2024-06-01T00:00:00Z", - Container: &AgentContainerDetails{ - HealthState: "healthy", - ProvisioningState: "Succeeded", - State: "Running", - UpdatedOn: "2024-06-01T00:00:00Z", - Replicas: []AgentContainerReplicaState{ - {Name: "replica-0", State: "Running", ContainerState: "started"}, - }, - }, - } - - data, err := json.Marshal(original) - if err != nil { - t.Fatalf("marshal: %v", err) - } - - s := string(data) - for _, field := range []string{ - `"max_replicas"`, `"min_replicas"`, `"error_message"`, - `"created_at"`, `"updated_at"`, `"container"`, - `"health_state"`, `"provisioning_state"`, `"container_state"`, - } { - if !strings.Contains(s, field) { - t.Errorf("expected JSON to contain %s", field) - } - } - - var got AgentContainerObject - if err := json.Unmarshal(data, &got); err != nil { - t.Fatalf("unmarshal: %v", err) - } - - if got.Status != AgentContainerStatusRunning { - t.Errorf("Status = %q, want %q", got.Status, AgentContainerStatusRunning) - } - if got.MaxReplicas == nil || *got.MaxReplicas != 3 { - t.Error("MaxReplicas mismatch") - } - if got.MinReplicas == nil || *got.MinReplicas != 1 { - t.Error("MinReplicas mismatch") - } - if got.ErrorMessage == nil || *got.ErrorMessage != "partial failure" { - t.Error("ErrorMessage mismatch") - } - if got.Container == nil || len(got.Container.Replicas) != 1 { - t.Error("Container.Replicas mismatch") - } -} - -func TestPromptAgentDefinition_RoundTrip(t *testing.T) { - t.Parallel() - - original := PromptAgentDefinition{ - AgentDefinition: AgentDefinition{ - Kind: AgentKindPrompt, - RaiConfig: &RaiConfig{RaiPolicyName: "strict"}, - }, - Model: "gpt-4o", - Instructions: new("Be concise"), - Temperature: new(float32(0.5)), - TopP: new(float32(0.95)), - Reasoning: &Reasoning{Effort: "high"}, - Text: &ResponseTextFormatConfiguration{Type: "text"}, - StructuredInputs: map[string]StructuredInputDefinition{ - "query": { - Description: new("user query"), - Required: new(true), - }, - }, - } - - data, err := json.Marshal(original) - if err != nil { - t.Fatalf("marshal: %v", err) - } - - s := string(data) - for _, field := range []string{ - `"kind"`, `"model"`, `"instructions"`, `"temperature"`, - `"top_p"`, `"reasoning"`, `"text"`, `"structured_inputs"`, - `"rai_config"`, `"rai_policy_name"`, - } { - if !strings.Contains(s, field) { - t.Errorf("expected JSON to contain %s", field) - } - } - - var got PromptAgentDefinition - if err := json.Unmarshal(data, &got); err != nil { - t.Fatalf("unmarshal: %v", err) - } - - if got.Kind != AgentKindPrompt { - t.Errorf("Kind = %q, want %q", got.Kind, AgentKindPrompt) - } - if got.Model != "gpt-4o" { - t.Errorf("Model = %q, want %q", got.Model, "gpt-4o") - } - if got.Instructions == nil || *got.Instructions != "Be concise" { - t.Error("Instructions mismatch") - } - if got.Temperature == nil || *got.Temperature != 0.5 { - t.Error("Temperature mismatch") - } - if got.Reasoning == nil || got.Reasoning.Effort != "high" { - t.Error("Reasoning mismatch") - } - if si, ok := got.StructuredInputs["query"]; !ok || si.Description == nil || *si.Description != "user query" { - t.Error("StructuredInputs mismatch") - } -} - -func TestHostedAgentDefinition_RoundTrip(t *testing.T) { - t.Parallel() - - original := HostedAgentDefinition{ - AgentDefinition: AgentDefinition{Kind: AgentKindHosted}, - ContainerProtocolVersions: []ProtocolVersionRecord{ - {Protocol: AgentProtocolResponses, Version: "2024-07-01"}, - }, - CPU: "1.0", - Memory: "2Gi", - EnvironmentVariables: map[string]string{"LOG_LEVEL": "debug"}, - } - - data, err := json.Marshal(original) - if err != nil { - t.Fatalf("marshal: %v", err) - } - - s := string(data) - for _, field := range []string{ - `"container_protocol_versions"`, `"cpu"`, `"memory"`, `"environment_variables"`, - } { - if !strings.Contains(s, field) { - t.Errorf("expected JSON to contain %s", field) - } - } - - var got HostedAgentDefinition - if err := json.Unmarshal(data, &got); err != nil { - t.Fatalf("unmarshal: %v", err) - } - - if got.Kind != AgentKindHosted { - t.Errorf("Kind = %q, want %q", got.Kind, AgentKindHosted) - } - if len(got.ContainerProtocolVersions) != 1 || got.ContainerProtocolVersions[0].Version != "2024-07-01" { - t.Error("ContainerProtocolVersions mismatch") - } - if got.EnvironmentVariables["LOG_LEVEL"] != "debug" { - t.Error("EnvironmentVariables mismatch") - } -} - -func TestImageBasedHostedAgentDefinition_RoundTrip(t *testing.T) { - t.Parallel() - - original := ImageBasedHostedAgentDefinition{ - HostedAgentDefinition: HostedAgentDefinition{ - AgentDefinition: AgentDefinition{Kind: AgentKindHosted}, - ContainerProtocolVersions: []ProtocolVersionRecord{ - {Protocol: AgentProtocolActivityProtocol, Version: "1.0"}, - }, - CPU: "0.5", - Memory: "1Gi", - }, - Image: "myregistry.azurecr.io/agent:latest", - } - - data, err := json.Marshal(original) - if err != nil { - t.Fatalf("marshal: %v", err) - } - - s := string(data) - if !strings.Contains(s, `"image"`) { - t.Error("expected JSON to contain \"image\"") - } - - var got ImageBasedHostedAgentDefinition - if err := json.Unmarshal(data, &got); err != nil { - t.Fatalf("unmarshal: %v", err) - } - - if got.Image != original.Image { - t.Errorf("Image = %q, want %q", got.Image, original.Image) - } - if got.CPU != "0.5" { - t.Errorf("CPU = %q, want %q", got.CPU, "0.5") - } -} - -func TestAgentVersionObject_RoundTrip(t *testing.T) { - t.Parallel() - - original := AgentVersionObject{ - Object: "agent_version", - ID: "ver-abc", - Name: "my-agent", - Version: "3", - Description: new("third version"), - Metadata: map[string]string{"stage": "prod"}, - CreatedAt: 1710000000, - } - - data, err := json.Marshal(original) - if err != nil { - t.Fatalf("marshal: %v", err) - } - - s := string(data) - for _, field := range []string{`"object"`, `"id"`, `"version"`, `"created_at"`} { - if !strings.Contains(s, field) { - t.Errorf("expected JSON to contain %s", field) - } - } - - var got AgentVersionObject - if err := json.Unmarshal(data, &got); err != nil { - t.Fatalf("unmarshal: %v", err) - } - - if got.Version != "3" { - t.Errorf("Version = %q, want %q", got.Version, "3") - } - if got.CreatedAt != 1710000000 { - t.Errorf("CreatedAt = %d, want %d", got.CreatedAt, int64(1710000000)) - } - if got.Metadata["stage"] != "prod" { - t.Errorf("Metadata[stage] = %q, want %q", got.Metadata["stage"], "prod") - } -} - -func TestDeleteAgentResponse_RoundTrip(t *testing.T) { - t.Parallel() - - original := DeleteAgentResponse{ - Object: "agent", - Name: "old-agent", - Deleted: true, - } - - data, err := json.Marshal(original) - if err != nil { - t.Fatalf("marshal: %v", err) - } - - var got DeleteAgentResponse - if err := json.Unmarshal(data, &got); err != nil { - t.Fatalf("unmarshal: %v", err) - } - - if got.Name != "old-agent" { - t.Errorf("Name = %q, want %q", got.Name, "old-agent") - } - if !got.Deleted { - t.Error("Deleted = false, want true") - } -} - -func TestDeleteAgentVersionResponse_RoundTrip(t *testing.T) { - t.Parallel() - - original := DeleteAgentVersionResponse{ - Object: "agent_version", - Name: "my-agent", - Version: "2", - Deleted: true, - } - - data, err := json.Marshal(original) - if err != nil { - t.Fatalf("marshal: %v", err) - } - - s := string(data) - if !strings.Contains(s, `"version"`) { - t.Error("expected JSON to contain \"version\"") - } - - var got DeleteAgentVersionResponse - if err := json.Unmarshal(data, &got); err != nil { - t.Fatalf("unmarshal: %v", err) - } - - if got.Version != "2" { - t.Errorf("Version = %q, want %q", got.Version, "2") - } - if !got.Deleted { - t.Error("Deleted = false, want true") - } -} - -func TestAgentEventHandlerRequest_RoundTrip(t *testing.T) { - t.Parallel() - - original := AgentEventHandlerRequest{ - Name: "eval-handler", - Metadata: map[string]string{"purpose": "eval"}, - EventTypes: []AgentEventType{AgentEventTypeResponseCompleted}, - Filter: &AgentEventHandlerFilter{ - AgentVersions: []string{"v1", "v2"}, - }, - Destination: AgentEventHandlerDestination{ - Type: AgentEventHandlerDestinationTypeEvals, - }, - } - - data, err := json.Marshal(original) - if err != nil { - t.Fatalf("marshal: %v", err) - } - - s := string(data) - for _, field := range []string{`"event_types"`, `"filter"`, `"destination"`, `"agent_versions"`} { - if !strings.Contains(s, field) { - t.Errorf("expected JSON to contain %s", field) - } - } - - var got AgentEventHandlerRequest - if err := json.Unmarshal(data, &got); err != nil { - t.Fatalf("unmarshal: %v", err) - } - - if got.Name != "eval-handler" { - t.Errorf("Name = %q, want %q", got.Name, "eval-handler") - } - if len(got.EventTypes) != 1 || got.EventTypes[0] != AgentEventTypeResponseCompleted { - t.Error("EventTypes mismatch") - } - if got.Filter == nil || len(got.Filter.AgentVersions) != 2 { - t.Error("Filter.AgentVersions mismatch") - } -} - -func TestAgentEventHandlerObject_RoundTrip(t *testing.T) { - t.Parallel() - - original := AgentEventHandlerObject{ - Object: "event_handler", - ID: "eh-1", - Name: "my-handler", - Metadata: map[string]string{"team": "platform"}, - CreatedAt: 1720000000, - EventTypes: []AgentEventType{AgentEventTypeResponseCompleted}, - Destination: AgentEventHandlerDestination{ - Type: AgentEventHandlerDestinationTypeEvals, - }, - } - - data, err := json.Marshal(original) - if err != nil { - t.Fatalf("marshal: %v", err) - } - - var got AgentEventHandlerObject - if err := json.Unmarshal(data, &got); err != nil { - t.Fatalf("unmarshal: %v", err) - } - - if got.ID != "eh-1" { - t.Errorf("ID = %q, want %q", got.ID, "eh-1") - } - if got.CreatedAt != 1720000000 { - t.Errorf("CreatedAt = %d, want %d", got.CreatedAt, int64(1720000000)) - } -} - -func TestFunctionTool_RoundTrip(t *testing.T) { - t.Parallel() - - original := FunctionTool{ - Tool: Tool{Type: ToolTypeFunction}, - Name: "get_weather", - Description: new("Gets weather data"), - Parameters: map[string]any{"type": "object"}, - Strict: new(true), - } - - data, err := json.Marshal(original) - if err != nil { - t.Fatalf("marshal: %v", err) - } - - s := string(data) - for _, field := range []string{`"type"`, `"name"`, `"description"`, `"parameters"`, `"strict"`} { - if !strings.Contains(s, field) { - t.Errorf("expected JSON to contain %s", field) - } - } - - var got FunctionTool - if err := json.Unmarshal(data, &got); err != nil { - t.Fatalf("unmarshal: %v", err) - } - - if got.Type != ToolTypeFunction { - t.Errorf("Type = %q, want %q", got.Type, ToolTypeFunction) - } - if got.Name != "get_weather" { - t.Errorf("Name = %q, want %q", got.Name, "get_weather") - } - if got.Strict == nil || !*got.Strict { - t.Error("Strict mismatch") - } -} - -func TestMCPTool_RoundTrip(t *testing.T) { - t.Parallel() - - original := MCPTool{ - Tool: Tool{Type: ToolTypeMCP}, - ServerLabel: "my-server", - ServerURL: "https://mcp.example.com", - Headers: map[string]string{"Authorization": "Bearer tok"}, - ProjectConnectionID: new("conn-abc"), - } - - data, err := json.Marshal(original) - if err != nil { - t.Fatalf("marshal: %v", err) - } - - s := string(data) - for _, field := range []string{`"server_label"`, `"server_url"`, `"project_connection_id"`} { - if !strings.Contains(s, field) { - t.Errorf("expected JSON to contain %s", field) - } - } - - var got MCPTool - if err := json.Unmarshal(data, &got); err != nil { - t.Fatalf("unmarshal: %v", err) - } - - if got.ServerLabel != "my-server" { - t.Errorf("ServerLabel = %q, want %q", got.ServerLabel, "my-server") - } - if got.ServerURL != "https://mcp.example.com" { - t.Errorf("ServerURL = %q, want %q", got.ServerURL, "https://mcp.example.com") - } - if got.ProjectConnectionID == nil || *got.ProjectConnectionID != "conn-abc" { - t.Error("ProjectConnectionID mismatch") - } -} - -func TestFileSearchTool_RoundTrip(t *testing.T) { - t.Parallel() - - original := FileSearchTool{ - Tool: Tool{Type: ToolTypeFileSearch}, - VectorStoreIds: []string{"vs-1", "vs-2"}, - MaxNumResults: new(int32(10)), - RankingOptions: &RankingOptions{ - Ranker: new("auto"), - ScoreThreshold: new(float32(0.8)), - }, - } - - data, err := json.Marshal(original) - if err != nil { - t.Fatalf("marshal: %v", err) - } - - s := string(data) - for _, field := range []string{ - `"vector_store_ids"`, `"max_num_results"`, `"ranking_options"`, - `"ranker"`, `"score_threshold"`, - } { - if !strings.Contains(s, field) { - t.Errorf("expected JSON to contain %s", field) - } - } - - var got FileSearchTool - if err := json.Unmarshal(data, &got); err != nil { - t.Fatalf("unmarshal: %v", err) - } - - if len(got.VectorStoreIds) != 2 { - t.Errorf("VectorStoreIds length = %d, want 2", len(got.VectorStoreIds)) - } - if got.MaxNumResults == nil || *got.MaxNumResults != 10 { - t.Error("MaxNumResults mismatch") - } - if got.RankingOptions == nil || got.RankingOptions.Ranker == nil || *got.RankingOptions.Ranker != "auto" { - t.Error("RankingOptions.Ranker mismatch") - } -} - -func TestWebSearchPreviewTool_RoundTrip(t *testing.T) { - t.Parallel() - - original := WebSearchPreviewTool{ - Tool: Tool{Type: ToolTypeWebSearchPreview}, - SearchContextSize: new("medium"), - } - - data, err := json.Marshal(original) - if err != nil { - t.Fatalf("marshal: %v", err) - } - - s := string(data) - if !strings.Contains(s, `"search_context_size"`) { - t.Error("expected JSON to contain \"search_context_size\"") - } - - var got WebSearchPreviewTool - if err := json.Unmarshal(data, &got); err != nil { - t.Fatalf("unmarshal: %v", err) - } - - if got.Type != ToolTypeWebSearchPreview { - t.Errorf("Type = %q, want %q", got.Type, ToolTypeWebSearchPreview) - } - if got.SearchContextSize == nil || *got.SearchContextSize != "medium" { - t.Error("SearchContextSize mismatch") - } -} - -func TestCodeInterpreterTool_RoundTrip(t *testing.T) { - t.Parallel() - - original := CodeInterpreterTool{ - Tool: Tool{Type: ToolTypeCodeInterpreter}, - Container: "container-id-123", - } - - data, err := json.Marshal(original) - if err != nil { - t.Fatalf("marshal: %v", err) - } - - s := string(data) - if !strings.Contains(s, `"container"`) { - t.Error("expected JSON to contain \"container\"") - } - - var got CodeInterpreterTool - if err := json.Unmarshal(data, &got); err != nil { - t.Fatalf("unmarshal: %v", err) - } - - if got.Type != ToolTypeCodeInterpreter { - t.Errorf("Type = %q, want %q", got.Type, ToolTypeCodeInterpreter) - } - // Container is `any`, so after round-trip it comes back as string - if got.Container != "container-id-123" { - t.Errorf("Container = %v, want %q", got.Container, "container-id-123") - } -} - -func TestBingGroundingAgentTool_RoundTrip(t *testing.T) { - t.Parallel() - - original := BingGroundingAgentTool{ - Tool: Tool{Type: ToolTypeBingGrounding}, - BingGrounding: BingGroundingSearchToolParameters{ - ProjectConnections: ToolProjectConnectionList{ - ProjectConnections: []ToolProjectConnection{{ID: "conn-1"}}, - }, - SearchConfigurations: []BingGroundingSearchConfiguration{ - { - ProjectConnectionID: "conn-1", - Market: new("en-US"), - }, - }, - }, - } - - data, err := json.Marshal(original) - if err != nil { - t.Fatalf("marshal: %v", err) - } - - s := string(data) - if !strings.Contains(s, `"bing_grounding"`) { - t.Error("expected JSON to contain \"bing_grounding\"") - } - if !strings.Contains(s, `"project_connections"`) { - t.Error("expected JSON to contain \"project_connections\"") - } - - var got BingGroundingAgentTool - if err := json.Unmarshal(data, &got); err != nil { - t.Fatalf("unmarshal: %v", err) - } - - if len(got.BingGrounding.ProjectConnections.ProjectConnections) != 1 { - t.Error("ProjectConnections length mismatch") - } - if len(got.BingGrounding.SearchConfigurations) != 1 { - t.Error("SearchConfigurations length mismatch") - } -} - -func TestOpenApiAgentTool_RoundTrip(t *testing.T) { - t.Parallel() - - original := OpenApiAgentTool{ - Tool: Tool{Type: ToolTypeOpenAPI}, - OpenAPI: OpenApiFunctionDefinition{ - Name: "petstore", - Description: new("Pet store API"), - Spec: map[string]any{"openapi": "3.0.0"}, - Auth: OpenApiAuthDetails{ - Type: OpenApiAuthTypeAnonymous, - }, - DefaultParams: []string{"api_version=v1"}, - Functions: []OpenApiFunction{ - { - Name: "listPets", - Description: new("List all pets"), - Parameters: map[string]any{"type": "object"}, - }, - }, - }, - } - - data, err := json.Marshal(original) - if err != nil { - t.Fatalf("marshal: %v", err) - } - - s := string(data) - if !strings.Contains(s, `"openapi"`) { - t.Error("expected JSON to contain \"openapi\"") - } - - var got OpenApiAgentTool - if err := json.Unmarshal(data, &got); err != nil { - t.Fatalf("unmarshal: %v", err) - } - - if got.OpenAPI.Name != "petstore" { - t.Errorf("OpenAPI.Name = %q, want %q", got.OpenAPI.Name, "petstore") - } - if got.OpenAPI.Auth.Type != OpenApiAuthTypeAnonymous { - t.Errorf("Auth.Type = %q, want %q", got.OpenAPI.Auth.Type, OpenApiAuthTypeAnonymous) - } - if len(got.OpenAPI.Functions) != 1 { - t.Errorf("Functions length = %d, want 1", len(got.OpenAPI.Functions)) - } -} - -func TestSessionFileInfo_RoundTrip(t *testing.T) { - t.Parallel() - - original := SessionFileInfo{ - Name: "data.csv", - Path: "/workspace/data.csv", - IsDirectory: false, - Size: 2048, - Mode: 0644, - LastModified: new("2024-06-15T10:30:00Z"), - } - - data, err := json.Marshal(original) - if err != nil { - t.Fatalf("marshal: %v", err) - } - - s := string(data) - for _, field := range []string{`"name"`, `"path"`, `"is_dir"`, `"size"`, `"mode"`, `"modified_time"`} { - if !strings.Contains(s, field) { - t.Errorf("expected JSON to contain %s", field) - } - } - - var got SessionFileInfo - if err := json.Unmarshal(data, &got); err != nil { - t.Fatalf("unmarshal: %v", err) - } - - if got.Name != "data.csv" { - t.Errorf("Name = %q, want %q", got.Name, "data.csv") - } - if got.IsDirectory { - t.Error("IsDirectory = true, want false") - } - if got.Size != 2048 { - t.Errorf("Size = %d, want %d", got.Size, int64(2048)) - } - if got.LastModified == nil || *got.LastModified != "2024-06-15T10:30:00Z" { - t.Error("LastModified mismatch") - } -} - -func TestSessionFileList_RoundTrip(t *testing.T) { - t.Parallel() - - original := SessionFileList{ - Path: "/workspace", - Entries: []SessionFileInfo{ - {Name: "file1.txt", Path: "/workspace/file1.txt", IsDirectory: false, Size: 100}, - {Name: "subdir", Path: "/workspace/subdir", IsDirectory: true}, - }, - } - - data, err := json.Marshal(original) - if err != nil { - t.Fatalf("marshal: %v", err) - } - - s := string(data) - if !strings.Contains(s, `"entries"`) { - t.Error("expected JSON to contain \"entries\"") - } - - var got SessionFileList - if err := json.Unmarshal(data, &got); err != nil { - t.Fatalf("unmarshal: %v", err) - } - - if got.Path != "/workspace" { - t.Errorf("Path = %q, want %q", got.Path, "/workspace") - } - if len(got.Entries) != 2 { - t.Fatalf("Entries length = %d, want 2", len(got.Entries)) - } - if !got.Entries[1].IsDirectory { - t.Error("Entries[1].IsDirectory = false, want true") - } -} - -func TestEvalsDestination_RoundTrip(t *testing.T) { - t.Parallel() - - original := EvalsDestination{ - AgentEventHandlerDestination: AgentEventHandlerDestination{ - Type: AgentEventHandlerDestinationTypeEvals, - }, - EvalID: "eval-123", - MaxHourlyRuns: new(int32(10)), - } - - data, err := json.Marshal(original) - if err != nil { - t.Fatalf("marshal: %v", err) - } - - s := string(data) - for _, field := range []string{`"eval_id"`, `"max_hourly_runs"`} { - if !strings.Contains(s, field) { - t.Errorf("expected JSON to contain %s", field) - } - } - - var got EvalsDestination - if err := json.Unmarshal(data, &got); err != nil { - t.Fatalf("unmarshal: %v", err) - } - - if got.EvalID != "eval-123" { - t.Errorf("EvalID = %q, want %q", got.EvalID, "eval-123") - } - if got.MaxHourlyRuns == nil || *got.MaxHourlyRuns != 10 { - t.Error("MaxHourlyRuns mismatch") - } -} - -func TestContainerAppAgentDefinition_RoundTrip(t *testing.T) { - t.Parallel() - - original := ContainerAppAgentDefinition{ - AgentDefinition: AgentDefinition{Kind: AgentKindContainerApp}, - ContainerProtocolVersions: []ProtocolVersionRecord{ - {Protocol: AgentProtocolInvocations, Version: "2024-01-01"}, - }, - ContainerAppResourceID: "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.App/containerApps/app", - IngressSubdomainSuffix: "myapp", - } - - data, err := json.Marshal(original) - if err != nil { - t.Fatalf("marshal: %v", err) - } - - s := string(data) - for _, field := range []string{ - `"container_app_resource_id"`, `"ingress_subdomain_suffix"`, `"container_protocol_versions"`, - } { - if !strings.Contains(s, field) { - t.Errorf("expected JSON to contain %s", field) - } - } - - var got ContainerAppAgentDefinition - if err := json.Unmarshal(data, &got); err != nil { - t.Fatalf("unmarshal: %v", err) - } - - if got.Kind != AgentKindContainerApp { - t.Errorf("Kind = %q, want %q", got.Kind, AgentKindContainerApp) - } - if got.ContainerAppResourceID != original.ContainerAppResourceID { - t.Errorf("ContainerAppResourceID mismatch") - } -} - -func TestWorkflowDefinition_RoundTrip(t *testing.T) { - t.Parallel() - - original := WorkflowDefinition{ - AgentDefinition: AgentDefinition{Kind: AgentKindWorkflow}, - Trigger: map[string]any{"type": "schedule", "cron": "0 * * * *"}, - } - - data, err := json.Marshal(original) - if err != nil { - t.Fatalf("marshal: %v", err) - } - - s := string(data) - if !strings.Contains(s, `"trigger"`) { - t.Error("expected JSON to contain \"trigger\"") - } - - var got WorkflowDefinition - if err := json.Unmarshal(data, &got); err != nil { - t.Fatalf("unmarshal: %v", err) - } - - if got.Kind != AgentKindWorkflow { - t.Errorf("Kind = %q, want %q", got.Kind, AgentKindWorkflow) - } - if got.Trigger["type"] != "schedule" { - t.Errorf("Trigger[type] = %v, want %q", got.Trigger["type"], "schedule") - } -} - -func TestAgentContainerOperationObject_RoundTrip(t *testing.T) { - t.Parallel() - - original := AgentContainerOperationObject{ - ID: "op-1", - AgentID: "agent-1", - AgentVersionID: "ver-1", - Status: AgentContainerOperationStatusSucceeded, - Error: &AgentContainerOperationError{ - Code: "E001", - Type: "runtime", - Message: "something went wrong", - }, - } - - data, err := json.Marshal(original) - if err != nil { - t.Fatalf("marshal: %v", err) - } - - s := string(data) - for _, field := range []string{`"agent_id"`, `"agent_version_id"`, `"status"`} { - if !strings.Contains(s, field) { - t.Errorf("expected JSON to contain %s", field) - } - } - - var got AgentContainerOperationObject - if err := json.Unmarshal(data, &got); err != nil { - t.Fatalf("unmarshal: %v", err) - } - - if got.Status != AgentContainerOperationStatusSucceeded { - t.Errorf("Status = %q, want %q", got.Status, AgentContainerOperationStatusSucceeded) - } - if got.Error == nil || got.Error.Message != "something went wrong" { - t.Error("Error.Message mismatch") - } -} - -func TestCommonListObjectProperties_RoundTrip(t *testing.T) { - t.Parallel() - - original := AgentList{ - Data: []AgentObject{ - {Object: "agent", ID: "a1", Name: "agent-one"}, - }, - CommonListObjectProperties: CommonListObjectProperties{ - Object: "list", - FirstID: "a1", - LastID: "a1", - HasMore: false, - }, - } - - data, err := json.Marshal(original) - if err != nil { - t.Fatalf("marshal: %v", err) - } - - s := string(data) - for _, field := range []string{`"first_id"`, `"last_id"`, `"has_more"`} { - if !strings.Contains(s, field) { - t.Errorf("expected JSON to contain %s", field) - } - } - - var got AgentList - if err := json.Unmarshal(data, &got); err != nil { - t.Fatalf("unmarshal: %v", err) - } - - if len(got.Data) != 1 || got.Data[0].ID != "a1" { - t.Error("Data mismatch") - } - if got.HasMore { - t.Error("HasMore = true, want false") - } -} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/map_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/map_test.go deleted file mode 100644 index f53452355c9..00000000000 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/map_test.go +++ /dev/null @@ -1,1199 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package agent_yaml - -import ( - "math" - "strings" - "testing" - - "azureaiagent/internal/pkg/agents/agent_api" -) - -// --------------------------------------------------------------------------- -// constructBuildConfig -// --------------------------------------------------------------------------- - -func TestConstructBuildConfig_NoOptions(t *testing.T) { - t.Parallel() - cfg := constructBuildConfig() - if cfg == nil { - t.Fatal("expected non-nil config") - } - if cfg.ImageURL != "" { - t.Errorf("expected empty ImageURL, got %q", cfg.ImageURL) - } - if cfg.CPU != "" { - t.Errorf("expected empty CPU, got %q", cfg.CPU) - } - if cfg.Memory != "" { - t.Errorf("expected empty Memory, got %q", cfg.Memory) - } - if cfg.EnvironmentVariables != nil { - t.Errorf("expected nil EnvironmentVariables, got %v", cfg.EnvironmentVariables) - } -} - -func TestConstructBuildConfig_AllOptions(t *testing.T) { - t.Parallel() - cfg := constructBuildConfig( - WithImageURL("myregistry.azurecr.io/myimage:latest"), - WithCPU("2"), - WithMemory("4Gi"), - WithEnvironmentVariable("KEY1", "val1"), - WithEnvironmentVariables(map[string]string{"KEY2": "val2", "KEY3": "val3"}), - ) - if cfg.ImageURL != "myregistry.azurecr.io/myimage:latest" { - t.Errorf("ImageURL = %q", cfg.ImageURL) - } - if cfg.CPU != "2" { - t.Errorf("CPU = %q", cfg.CPU) - } - if cfg.Memory != "4Gi" { - t.Errorf("Memory = %q", cfg.Memory) - } - if len(cfg.EnvironmentVariables) != 3 { - t.Fatalf("expected 3 env vars, got %d", len(cfg.EnvironmentVariables)) - } - for _, k := range []string{"KEY1", "KEY2", "KEY3"} { - if _, ok := cfg.EnvironmentVariables[k]; !ok { - t.Errorf("missing env var %q", k) - } - } -} - -// --------------------------------------------------------------------------- -// WithEnvironmentVariable / WithEnvironmentVariables -// --------------------------------------------------------------------------- - -func TestWithEnvironmentVariable_InitializesMap(t *testing.T) { - t.Parallel() - cfg := &AgentBuildConfig{} - WithEnvironmentVariable("A", "1")(cfg) - if cfg.EnvironmentVariables["A"] != "1" { - t.Errorf("expected A=1, got %q", cfg.EnvironmentVariables["A"]) - } -} - -func TestWithEnvironmentVariables_MergesIntoExisting(t *testing.T) { - t.Parallel() - cfg := &AgentBuildConfig{EnvironmentVariables: map[string]string{"EXISTING": "x"}} - WithEnvironmentVariables(map[string]string{"NEW": "y"})(cfg) - if cfg.EnvironmentVariables["EXISTING"] != "x" { - t.Error("existing env var was lost") - } - if cfg.EnvironmentVariables["NEW"] != "y" { - t.Error("new env var not set") - } -} - -func TestWithEnvironmentVariables_InitializesNilMap(t *testing.T) { - t.Parallel() - cfg := &AgentBuildConfig{} - WithEnvironmentVariables(map[string]string{"K": "V"})(cfg) - if cfg.EnvironmentVariables["K"] != "V" { - t.Errorf("expected K=V") - } -} - -// --------------------------------------------------------------------------- -// convertIntToInt32 -// --------------------------------------------------------------------------- - -func TestConvertIntToInt32(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - input *int - want *int32 - wantErr bool - }{ - { - name: "nil input", - input: nil, - want: nil, - }, - { - name: "zero", - input: new(0), - want: new(int32(0)), - }, - { - name: "positive value", - input: new(42), - want: new(int32(42)), - }, - { - name: "negative value", - input: new(-10), - want: new(int32(-10)), - }, - { - name: "max int32", - input: new(math.MaxInt32), - want: new(int32(math.MaxInt32)), - }, - { - name: "min int32", - input: new(math.MinInt32), - want: new(int32(math.MinInt32)), - }, - { - name: "overflow positive", - input: new(math.MaxInt32 + 1), - wantErr: true, - }, - { - name: "overflow negative", - input: new(math.MinInt32 - 1), - wantErr: true, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - got, err := convertIntToInt32(tc.input) - if tc.wantErr { - if err == nil { - t.Fatal("expected error, got nil") - } - return - } - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if tc.want == nil { - if got != nil { - t.Fatalf("expected nil, got %v", *got) - } - return - } - if got == nil { - t.Fatal("expected non-nil result") - } - if *got != *tc.want { - t.Errorf("got %d, want %d", *got, *tc.want) - } - }) - } -} - -// --------------------------------------------------------------------------- -// convertFloat64ToFloat32 -// --------------------------------------------------------------------------- - -func TestConvertFloat64ToFloat32(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - input *float64 - isNil bool - }{ - {name: "nil input", input: nil, isNil: true}, - {name: "zero", input: new(0.0)}, - {name: "typical temperature", input: new(0.7)}, - {name: "one", input: new(1.0)}, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - got := convertFloat64ToFloat32(tc.input) - if tc.isNil { - if got != nil { - t.Fatalf("expected nil, got %v", *got) - } - return - } - if got == nil { - t.Fatal("expected non-nil result") - } - expected := float32(*tc.input) - if *got != expected { - t.Errorf("got %v, want %v", *got, expected) - } - }) - } -} - -// --------------------------------------------------------------------------- -// convertYamlToolToApiTool -// --------------------------------------------------------------------------- - -func TestConvertYamlToolToApiTool_Nil(t *testing.T) { - t.Parallel() - _, err := convertYamlToolToApiTool(nil) - if err == nil { - t.Fatal("expected error for nil tool") - } - if !strings.Contains(err.Error(), "nil") { - t.Errorf("error should mention nil, got: %s", err.Error()) - } -} - -func TestConvertYamlToolToApiTool_UnknownType(t *testing.T) { - t.Parallel() - _, err := convertYamlToolToApiTool("not-a-tool") - if err == nil { - t.Fatal("expected error for unknown tool type") - } - if !strings.Contains(err.Error(), "unsupported") { - t.Errorf("error should mention unsupported, got: %s", err.Error()) - } -} - -func TestConvertYamlToolToApiTool_Function(t *testing.T) { - t.Parallel() - desc := "adds two numbers" - yamlTool := FunctionTool{ - Tool: Tool{ - Name: "add", - Kind: ToolKindFunction, - Description: &desc, - }, - Parameters: PropertySchema{}, - Strict: new(true), - } - - result, err := convertYamlToolToApiTool(yamlTool) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - ft, ok := result.(agent_api.FunctionTool) - if !ok { - t.Fatalf("expected agent_api.FunctionTool, got %T", result) - } - if ft.Tool.Type != agent_api.ToolTypeFunction { - t.Errorf("type = %q, want %q", ft.Tool.Type, agent_api.ToolTypeFunction) - } - if ft.Name != "add" { - t.Errorf("name = %q, want %q", ft.Name, "add") - } - if ft.Description == nil || *ft.Description != desc { - t.Errorf("description mismatch") - } - if ft.Strict == nil || !*ft.Strict { - t.Error("strict should be true") - } -} - -func TestConvertYamlToolToApiTool_FunctionNilDescription(t *testing.T) { - t.Parallel() - yamlTool := FunctionTool{ - Tool: Tool{ - Name: "noop", - Kind: ToolKindFunction, - }, - Parameters: PropertySchema{}, - } - - result, err := convertYamlToolToApiTool(yamlTool) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - ft := result.(agent_api.FunctionTool) - if ft.Description != nil { - t.Errorf("expected nil description, got %v", ft.Description) - } - if ft.Strict != nil { - t.Errorf("expected nil strict, got %v", ft.Strict) - } -} - -func TestConvertYamlToolToApiTool_WebSearch(t *testing.T) { - t.Parallel() - yamlTool := WebSearchTool{ - Tool: Tool{Name: "websearch", Kind: ToolKindWebSearch}, - Options: map[string]any{ - "searchContextSize": "high", - }, - } - - result, err := convertYamlToolToApiTool(yamlTool) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - ws, ok := result.(agent_api.WebSearchPreviewTool) - if !ok { - t.Fatalf("expected agent_api.WebSearchPreviewTool, got %T", result) - } - if ws.Tool.Type != agent_api.ToolTypeWebSearchPreview { - t.Errorf("type = %q, want %q", ws.Tool.Type, agent_api.ToolTypeWebSearchPreview) - } - if ws.SearchContextSize == nil || *ws.SearchContextSize != "high" { - t.Errorf("searchContextSize mismatch") - } -} - -func TestConvertYamlToolToApiTool_WebSearchNoOptions(t *testing.T) { - t.Parallel() - yamlTool := WebSearchTool{ - Tool: Tool{Name: "ws", Kind: ToolKindWebSearch}, - } - - result, err := convertYamlToolToApiTool(yamlTool) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - ws := result.(agent_api.WebSearchPreviewTool) - if ws.UserLocation != nil { - t.Error("expected nil UserLocation") - } - if ws.SearchContextSize != nil { - t.Error("expected nil SearchContextSize") - } -} - -func TestConvertYamlToolToApiTool_BingGrounding(t *testing.T) { - t.Parallel() - bgParams := agent_api.BingGroundingSearchToolParameters{ - ProjectConnections: agent_api.ToolProjectConnectionList{ - ProjectConnections: []agent_api.ToolProjectConnection{{ID: "conn-1"}}, - }, - } - yamlTool := BingGroundingTool{ - Tool: Tool{Name: "bing", Kind: ToolKindBingGrounding}, - Options: map[string]any{ - "bingGrounding": bgParams, - }, - } - - result, err := convertYamlToolToApiTool(yamlTool) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - bg, ok := result.(agent_api.BingGroundingAgentTool) - if !ok { - t.Fatalf("expected agent_api.BingGroundingAgentTool, got %T", result) - } - if bg.Tool.Type != agent_api.ToolTypeBingGrounding { - t.Errorf("type = %q, want %q", bg.Tool.Type, agent_api.ToolTypeBingGrounding) - } - if len(bg.BingGrounding.ProjectConnections.ProjectConnections) != 1 { - t.Errorf("expected 1 project connection") - } -} - -func TestConvertYamlToolToApiTool_BingGroundingNoOptions(t *testing.T) { - t.Parallel() - yamlTool := BingGroundingTool{ - Tool: Tool{Name: "bing", Kind: ToolKindBingGrounding}, - } - - result, err := convertYamlToolToApiTool(yamlTool) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - bg := result.(agent_api.BingGroundingAgentTool) - if bg.Tool.Type != agent_api.ToolTypeBingGrounding { - t.Errorf("type = %q", bg.Tool.Type) - } -} - -func TestConvertYamlToolToApiTool_FileSearch(t *testing.T) { - t.Parallel() - ranker := "default-2024-11-15" - threshold := 0.8 - maxResults := 10 - yamlTool := FileSearchTool{ - Tool: Tool{Name: "fs", Kind: ToolKindFileSearch}, - VectorStoreIds: []string{"vs-1", "vs-2"}, - MaximumResultCount: &maxResults, - Ranker: &ranker, - ScoreThreshold: &threshold, - Options: map[string]any{ - "filters": map[string]any{"type": "eq", "key": "status", "value": "active"}, - }, - } - - result, err := convertYamlToolToApiTool(yamlTool) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - fs, ok := result.(agent_api.FileSearchTool) - if !ok { - t.Fatalf("expected agent_api.FileSearchTool, got %T", result) - } - if fs.Tool.Type != agent_api.ToolTypeFileSearch { - t.Errorf("type = %q", fs.Tool.Type) - } - if len(fs.VectorStoreIds) != 2 { - t.Errorf("expected 2 vector store ids, got %d", len(fs.VectorStoreIds)) - } - if fs.MaxNumResults == nil || *fs.MaxNumResults != 10 { - t.Errorf("MaxNumResults mismatch") - } - if fs.RankingOptions == nil { - t.Fatal("expected non-nil RankingOptions") - } - if fs.RankingOptions.Ranker == nil || *fs.RankingOptions.Ranker != ranker { - t.Errorf("ranker mismatch") - } - if fs.RankingOptions.ScoreThreshold == nil || *fs.RankingOptions.ScoreThreshold != float32(threshold) { - t.Errorf("score threshold mismatch") - } - if fs.Filters == nil { - t.Error("expected filters to be set") - } -} - -func TestConvertYamlToolToApiTool_FileSearchMinimal(t *testing.T) { - t.Parallel() - yamlTool := FileSearchTool{ - Tool: Tool{Name: "fs", Kind: ToolKindFileSearch}, - VectorStoreIds: []string{"vs-1"}, - } - - result, err := convertYamlToolToApiTool(yamlTool) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - fs := result.(agent_api.FileSearchTool) - if fs.MaxNumResults != nil { - t.Error("expected nil MaxNumResults") - } - if fs.RankingOptions != nil { - t.Error("expected nil RankingOptions when ranker and threshold are nil") - } -} - -func TestConvertYamlToolToApiTool_FileSearchOverflow(t *testing.T) { - t.Parallel() - overflow := math.MaxInt32 + 1 - yamlTool := FileSearchTool{ - Tool: Tool{Name: "fs", Kind: ToolKindFileSearch}, - VectorStoreIds: []string{"vs-1"}, - MaximumResultCount: &overflow, - } - - _, err := convertYamlToolToApiTool(yamlTool) - if err == nil { - t.Fatal("expected error for int32 overflow") - } - if !strings.Contains(err.Error(), "overflow") { - t.Errorf("error should mention overflow, got: %s", err.Error()) - } -} - -func TestConvertYamlToolToApiTool_MCP(t *testing.T) { - t.Parallel() - yamlTool := McpTool{ - Tool: Tool{Name: "mcp-server", Kind: ToolKindMcp}, - ServerName: "my-mcp-server", - Options: map[string]any{ - "serverUrl": "https://mcp.example.com", - "headers": map[string]string{"Authorization": "Bearer tok"}, - "allowedTools": []string{"tool_a", "tool_b"}, - "requireApproval": "always", - "projectConnectionId": "conn-123", - }, - } - - result, err := convertYamlToolToApiTool(yamlTool) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - mcp, ok := result.(agent_api.MCPTool) - if !ok { - t.Fatalf("expected agent_api.MCPTool, got %T", result) - } - if mcp.Tool.Type != agent_api.ToolTypeMCP { - t.Errorf("type = %q", mcp.Tool.Type) - } - if mcp.ServerLabel != "my-mcp-server" { - t.Errorf("ServerLabel = %q", mcp.ServerLabel) - } - if mcp.ServerURL != "https://mcp.example.com" { - t.Errorf("ServerURL = %q", mcp.ServerURL) - } - if mcp.Headers["Authorization"] != "Bearer tok" { - t.Errorf("headers mismatch") - } - if mcp.ProjectConnectionID == nil || *mcp.ProjectConnectionID != "conn-123" { - t.Errorf("ProjectConnectionID mismatch") - } -} - -func TestConvertYamlToolToApiTool_MCPNoOptions(t *testing.T) { - t.Parallel() - yamlTool := McpTool{ - Tool: Tool{Name: "mcp", Kind: ToolKindMcp}, - ServerName: "srv", - } - - result, err := convertYamlToolToApiTool(yamlTool) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - mcp := result.(agent_api.MCPTool) - if mcp.ServerURL != "" { - t.Errorf("expected empty ServerURL, got %q", mcp.ServerURL) - } - if mcp.ProjectConnectionID != nil { - t.Error("expected nil ProjectConnectionID") - } -} - -func TestConvertYamlToolToApiTool_OpenApi(t *testing.T) { - t.Parallel() - openApiDef := agent_api.OpenApiFunctionDefinition{ - Name: "petstore", - Auth: agent_api.OpenApiAuthDetails{Type: agent_api.OpenApiAuthTypeAnonymous}, - } - yamlTool := OpenApiTool{ - Tool: Tool{Name: "petstore", Kind: ToolKindOpenApi}, - Options: map[string]any{ - "openapi": openApiDef, - }, - } - - result, err := convertYamlToolToApiTool(yamlTool) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - oa, ok := result.(agent_api.OpenApiAgentTool) - if !ok { - t.Fatalf("expected agent_api.OpenApiAgentTool, got %T", result) - } - if oa.Tool.Type != agent_api.ToolTypeOpenAPI { - t.Errorf("type = %q", oa.Tool.Type) - } - if oa.OpenAPI.Name != "petstore" { - t.Errorf("OpenAPI.Name = %q", oa.OpenAPI.Name) - } -} - -func TestConvertYamlToolToApiTool_OpenApiNoOptions(t *testing.T) { - t.Parallel() - yamlTool := OpenApiTool{ - Tool: Tool{Name: "api", Kind: ToolKindOpenApi}, - } - - result, err := convertYamlToolToApiTool(yamlTool) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - oa := result.(agent_api.OpenApiAgentTool) - if oa.Tool.Type != agent_api.ToolTypeOpenAPI { - t.Errorf("type = %q", oa.Tool.Type) - } -} - -func TestConvertYamlToolToApiTool_CodeInterpreter(t *testing.T) { - t.Parallel() - yamlTool := CodeInterpreterTool{ - Tool: Tool{Name: "ci", Kind: ToolKindCodeInterpreter}, - Options: map[string]any{ - "container": "auto", - }, - } - - result, err := convertYamlToolToApiTool(yamlTool) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - ci, ok := result.(agent_api.CodeInterpreterTool) - if !ok { - t.Fatalf("expected agent_api.CodeInterpreterTool, got %T", result) - } - if ci.Tool.Type != agent_api.ToolTypeCodeInterpreter { - t.Errorf("type = %q", ci.Tool.Type) - } - if ci.Container != "auto" { - t.Errorf("Container = %v", ci.Container) - } -} - -func TestConvertYamlToolToApiTool_CodeInterpreterNoOptions(t *testing.T) { - t.Parallel() - yamlTool := CodeInterpreterTool{ - Tool: Tool{Name: "ci", Kind: ToolKindCodeInterpreter}, - } - - result, err := convertYamlToolToApiTool(yamlTool) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - ci := result.(agent_api.CodeInterpreterTool) - if ci.Container != nil { - t.Errorf("expected nil Container, got %v", ci.Container) - } -} - -// --------------------------------------------------------------------------- -// convertYamlToolsToApiTools -// --------------------------------------------------------------------------- - -func TestConvertYamlToolsToApiTools_MixedTools(t *testing.T) { - t.Parallel() - yamlTools := []any{ - FunctionTool{Tool: Tool{Name: "fn1", Kind: ToolKindFunction}, Parameters: PropertySchema{}}, - WebSearchTool{Tool: Tool{Name: "ws", Kind: ToolKindWebSearch}}, - CodeInterpreterTool{Tool: Tool{Name: "ci", Kind: ToolKindCodeInterpreter}}, - } - - result := convertYamlToolsToApiTools(yamlTools) - if len(result) != 3 { - t.Fatalf("expected 3 tools, got %d", len(result)) - } - - if _, ok := result[0].(agent_api.FunctionTool); !ok { - t.Errorf("tool[0] should be FunctionTool, got %T", result[0]) - } - if _, ok := result[1].(agent_api.WebSearchPreviewTool); !ok { - t.Errorf("tool[1] should be WebSearchPreviewTool, got %T", result[1]) - } - if _, ok := result[2].(agent_api.CodeInterpreterTool); !ok { - t.Errorf("tool[2] should be CodeInterpreterTool, got %T", result[2]) - } -} - -func TestConvertYamlToolsToApiTools_SkipsUnsupported(t *testing.T) { - t.Parallel() - yamlTools := []any{ - FunctionTool{Tool: Tool{Name: "fn1", Kind: ToolKindFunction}, Parameters: PropertySchema{}}, - "unsupported-string-tool", - WebSearchTool{Tool: Tool{Name: "ws", Kind: ToolKindWebSearch}}, - } - - result := convertYamlToolsToApiTools(yamlTools) - if len(result) != 2 { - t.Fatalf("expected 2 tools (unsupported skipped), got %d", len(result)) - } -} - -func TestConvertYamlToolsToApiTools_Empty(t *testing.T) { - t.Parallel() - result := convertYamlToolsToApiTools([]any{}) - if result != nil { - t.Errorf("expected nil for empty input, got %v", result) - } -} - -// --------------------------------------------------------------------------- -// createAgentAPIRequest (common fields) -// --------------------------------------------------------------------------- - -func TestCreateAgentAPIRequest_AllFields(t *testing.T) { - t.Parallel() - desc := "A helpful agent" - meta := map[string]any{ - "authors": []any{"Alice", "Bob"}, - "version": "1.0", - } - agentDef := AgentDefinition{ - Kind: AgentKindPrompt, - Name: "my-agent", - Description: &desc, - Metadata: &meta, - } - - req, err := createAgentAPIRequest(agentDef, "placeholder-definition") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if req.Name != "my-agent" { - t.Errorf("Name = %q, want %q", req.Name, "my-agent") - } - if req.Description == nil || *req.Description != desc { - t.Errorf("Description mismatch") - } - if req.Metadata["authors"] != "Alice,Bob" { - t.Errorf("authors = %q, want %q", req.Metadata["authors"], "Alice,Bob") - } - if req.Metadata["version"] != "1.0" { - t.Errorf("version metadata = %q", req.Metadata["version"]) - } - if req.Definition != "placeholder-definition" { - t.Errorf("Definition mismatch") - } -} - -func TestCreateAgentAPIRequest_DefaultName(t *testing.T) { - t.Parallel() - agentDef := AgentDefinition{Kind: AgentKindPrompt} - - req, err := createAgentAPIRequest(agentDef, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if req.Name != "unspecified-agent-name" { - t.Errorf("Name = %q, want %q", req.Name, "unspecified-agent-name") - } -} - -func TestCreateAgentAPIRequest_NilMetadata(t *testing.T) { - t.Parallel() - agentDef := AgentDefinition{Kind: AgentKindPrompt, Name: "test"} - - req, err := createAgentAPIRequest(agentDef, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if req.Metadata != nil { - t.Errorf("expected nil Metadata, got %v", req.Metadata) - } -} - -func TestCreateAgentAPIRequest_EmptyDescription(t *testing.T) { - t.Parallel() - empty := "" - agentDef := AgentDefinition{ - Kind: AgentKindPrompt, - Name: "test", - Description: &empty, - } - - req, err := createAgentAPIRequest(agentDef, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if req.Description != nil { - t.Errorf("expected nil Description for empty string, got %v", req.Description) - } -} - -func TestCreateAgentAPIRequest_MetadataWithNonStringValues(t *testing.T) { - t.Parallel() - meta := map[string]any{ - "name": "test", - "numeric": 42, // non-string value should be skipped - } - agentDef := AgentDefinition{ - Kind: AgentKindPrompt, - Name: "test", - Metadata: &meta, - } - - req, err := createAgentAPIRequest(agentDef, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if req.Metadata["name"] != "test" { - t.Errorf("string metadata missing") - } - if _, exists := req.Metadata["numeric"]; exists { - t.Errorf("non-string metadata should be skipped") - } -} - -func TestCreateAgentAPIRequest_AuthorsSingleAuthor(t *testing.T) { - t.Parallel() - meta := map[string]any{ - "authors": []any{"Solo"}, - } - agentDef := AgentDefinition{ - Kind: AgentKindPrompt, - Name: "test", - Metadata: &meta, - } - - req, err := createAgentAPIRequest(agentDef, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if req.Metadata["authors"] != "Solo" { - t.Errorf("authors = %q, want %q", req.Metadata["authors"], "Solo") - } -} - -// --------------------------------------------------------------------------- -// CreatePromptAgentAPIRequest -// --------------------------------------------------------------------------- - -func TestCreatePromptAgentAPIRequest_FullConfig(t *testing.T) { - t.Parallel() - desc := "prompt agent" - instructions := "You are a helpful assistant." - temp := 0.7 - topP := 0.9 - - agent := PromptAgent{ - AgentDefinition: AgentDefinition{ - Kind: AgentKindPrompt, - Name: "my-prompt-agent", - Description: &desc, - }, - Model: Model{ - Id: "gpt-4o", - Options: &ModelOptions{ - Temperature: &temp, - TopP: &topP, - }, - }, - Instructions: &instructions, - Tools: &[]any{ - FunctionTool{ - Tool: Tool{Name: "calc", Kind: ToolKindFunction}, - Parameters: PropertySchema{}, - }, - }, - } - - req, err := CreatePromptAgentAPIRequest(agent, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if req.Name != "my-prompt-agent" { - t.Errorf("Name = %q", req.Name) - } - if req.Description == nil || *req.Description != desc { - t.Errorf("Description mismatch") - } - - promptDef, ok := req.Definition.(agent_api.PromptAgentDefinition) - if !ok { - t.Fatalf("Definition should be PromptAgentDefinition, got %T", req.Definition) - } - if promptDef.Kind != agent_api.AgentKindPrompt { - t.Errorf("Kind = %q", promptDef.Kind) - } - if promptDef.Model != "gpt-4o" { - t.Errorf("Model = %q", promptDef.Model) - } - if promptDef.Instructions == nil || *promptDef.Instructions != instructions { - t.Errorf("Instructions mismatch") - } - if promptDef.Temperature == nil || *promptDef.Temperature != float32(0.7) { - t.Errorf("Temperature mismatch") - } - if promptDef.TopP == nil || *promptDef.TopP != float32(0.9) { - t.Errorf("TopP mismatch") - } - if len(promptDef.Tools) != 1 { - t.Fatalf("expected 1 tool, got %d", len(promptDef.Tools)) - } -} - -func TestCreatePromptAgentAPIRequest_MissingModelId(t *testing.T) { - t.Parallel() - agent := PromptAgent{ - AgentDefinition: AgentDefinition{ - Kind: AgentKindPrompt, - Name: "bad-agent", - }, - Model: Model{Id: ""}, - Tools: &[]any{}, - } - - _, err := CreatePromptAgentAPIRequest(agent, nil) - if err == nil { - t.Fatal("expected error for missing model.id") - } - if !strings.Contains(err.Error(), "model.id") { - t.Errorf("error should mention model.id, got: %s", err.Error()) - } -} - -func TestCreatePromptAgentAPIRequest_NoOptions(t *testing.T) { - t.Parallel() - agent := PromptAgent{ - AgentDefinition: AgentDefinition{ - Kind: AgentKindPrompt, - Name: "simple-agent", - }, - Model: Model{Id: "gpt-4o-mini"}, - Tools: &[]any{}, - } - - req, err := CreatePromptAgentAPIRequest(agent, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - promptDef := req.Definition.(agent_api.PromptAgentDefinition) - if promptDef.Temperature != nil { - t.Errorf("expected nil Temperature, got %v", *promptDef.Temperature) - } - if promptDef.TopP != nil { - t.Errorf("expected nil TopP, got %v", *promptDef.TopP) - } - if promptDef.Instructions != nil { - t.Errorf("expected nil Instructions") - } -} - -func TestCreatePromptAgentAPIRequest_NilToolsSlice(t *testing.T) { - t.Parallel() - emptyTools := []any{} - agent := PromptAgent{ - AgentDefinition: AgentDefinition{ - Kind: AgentKindPrompt, - Name: "no-tools", - }, - Model: Model{Id: "gpt-4o"}, - Tools: &emptyTools, - } - - req, err := CreatePromptAgentAPIRequest(agent, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - promptDef := req.Definition.(agent_api.PromptAgentDefinition) - if promptDef.Tools != nil { - t.Errorf("expected nil Tools for empty input, got %v", promptDef.Tools) - } -} - -// --------------------------------------------------------------------------- -// CreateHostedAgentAPIRequest -// --------------------------------------------------------------------------- - -func TestCreateHostedAgentAPIRequest_FullConfig(t *testing.T) { - t.Parallel() - desc := "hosted agent" - agent := ContainerAgent{ - AgentDefinition: AgentDefinition{ - Kind: AgentKindHosted, - Name: "my-hosted-agent", - Description: &desc, - }, - Protocols: []ProtocolVersionRecord{ - {Protocol: "responses", Version: "2.0.0"}, - {Protocol: "invocations", Version: "1.0.0"}, - }, - } - - buildConfig := &AgentBuildConfig{ - ImageURL: "myregistry.azurecr.io/agent:v1", - CPU: "4", - Memory: "8Gi", - EnvironmentVariables: map[string]string{"ENV1": "val1"}, - } - - req, err := CreateHostedAgentAPIRequest(agent, buildConfig) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if req.Name != "my-hosted-agent" { - t.Errorf("Name = %q", req.Name) - } - if req.Description == nil || *req.Description != desc { - t.Errorf("Description mismatch") - } - - imgDef, ok := req.Definition.(agent_api.ImageBasedHostedAgentDefinition) - if !ok { - t.Fatalf("expected ImageBasedHostedAgentDefinition, got %T", req.Definition) - } - if imgDef.Kind != agent_api.AgentKindHosted { - t.Errorf("Kind = %q", imgDef.Kind) - } - if imgDef.Image != "myregistry.azurecr.io/agent:v1" { - t.Errorf("Image = %q", imgDef.Image) - } - if imgDef.CPU != "4" { - t.Errorf("CPU = %q", imgDef.CPU) - } - if imgDef.Memory != "8Gi" { - t.Errorf("Memory = %q", imgDef.Memory) - } - if imgDef.EnvironmentVariables["ENV1"] != "val1" { - t.Error("env var missing") - } - - // Verify protocol versions - if len(imgDef.ContainerProtocolVersions) != 2 { - t.Fatalf("expected 2 protocol versions, got %d", len(imgDef.ContainerProtocolVersions)) - } - if imgDef.ContainerProtocolVersions[0].Protocol != "responses" { - t.Errorf("protocol[0] = %q", imgDef.ContainerProtocolVersions[0].Protocol) - } - if imgDef.ContainerProtocolVersions[0].Version != "2.0.0" { - t.Errorf("version[0] = %q", imgDef.ContainerProtocolVersions[0].Version) - } -} - -func TestCreateHostedAgentAPIRequest_DefaultProtocols(t *testing.T) { - t.Parallel() - agent := ContainerAgent{ - AgentDefinition: AgentDefinition{ - Kind: AgentKindHosted, - Name: "default-protocols", - }, - } - buildConfig := &AgentBuildConfig{ImageURL: "img:latest"} - - req, err := CreateHostedAgentAPIRequest(agent, buildConfig) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - imgDef := req.Definition.(agent_api.ImageBasedHostedAgentDefinition) - if len(imgDef.ContainerProtocolVersions) != 1 { - t.Fatalf("expected 1 default protocol, got %d", len(imgDef.ContainerProtocolVersions)) - } - if imgDef.ContainerProtocolVersions[0].Protocol != agent_api.AgentProtocolResponses { - t.Errorf("default protocol = %q", imgDef.ContainerProtocolVersions[0].Protocol) - } - if imgDef.ContainerProtocolVersions[0].Version != "1.0.0" { - t.Errorf("default version = %q", imgDef.ContainerProtocolVersions[0].Version) - } -} - -func TestCreateHostedAgentAPIRequest_DefaultCPUAndMemory(t *testing.T) { - t.Parallel() - agent := ContainerAgent{ - AgentDefinition: AgentDefinition{ - Kind: AgentKindHosted, - Name: "defaults", - }, - } - buildConfig := &AgentBuildConfig{ImageURL: "img:latest"} - - req, err := CreateHostedAgentAPIRequest(agent, buildConfig) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - imgDef := req.Definition.(agent_api.ImageBasedHostedAgentDefinition) - if imgDef.CPU != "1" { - t.Errorf("default CPU = %q, want %q", imgDef.CPU, "1") - } - if imgDef.Memory != "2Gi" { - t.Errorf("default Memory = %q, want %q", imgDef.Memory, "2Gi") - } -} - -func TestCreateHostedAgentAPIRequest_MissingImageURL(t *testing.T) { - t.Parallel() - agent := ContainerAgent{ - AgentDefinition: AgentDefinition{ - Kind: AgentKindHosted, - Name: "no-image", - }, - } - - _, err := CreateHostedAgentAPIRequest(agent, &AgentBuildConfig{}) - if err == nil { - t.Fatal("expected error for missing image URL") - } - if !strings.Contains(err.Error(), "image URL") { - t.Errorf("error should mention image URL, got: %s", err.Error()) - } -} - -func TestCreateHostedAgentAPIRequest_NilBuildConfig(t *testing.T) { - t.Parallel() - agent := ContainerAgent{ - AgentDefinition: AgentDefinition{ - Kind: AgentKindHosted, - Name: "nil-config", - }, - } - - _, err := CreateHostedAgentAPIRequest(agent, nil) - if err == nil { - t.Fatal("expected error for nil build config (no image)") - } -} - -// --------------------------------------------------------------------------- -// CreateAgentAPIRequestFromDefinition (routing) -// --------------------------------------------------------------------------- - -func TestCreateAgentAPIRequestFromDefinition_PromptAgent(t *testing.T) { - t.Parallel() - agent := PromptAgent{ - AgentDefinition: AgentDefinition{ - Kind: AgentKindPrompt, - Name: "prompt-routed", - }, - Model: Model{Id: "gpt-4o"}, - Tools: &[]any{}, - } - - req, err := CreateAgentAPIRequestFromDefinition(agent) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if req.Name != "prompt-routed" { - t.Errorf("Name = %q", req.Name) - } - - _, ok := req.Definition.(agent_api.PromptAgentDefinition) - if !ok { - t.Fatalf("expected PromptAgentDefinition, got %T", req.Definition) - } -} - -func TestCreateAgentAPIRequestFromDefinition_HostedAgent(t *testing.T) { - t.Parallel() - agent := ContainerAgent{ - AgentDefinition: AgentDefinition{ - Kind: AgentKindHosted, - Name: "hosted-routed", - }, - } - - req, err := CreateAgentAPIRequestFromDefinition(agent, WithImageURL("img:latest")) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if req.Name != "hosted-routed" { - t.Errorf("Name = %q", req.Name) - } - - _, ok := req.Definition.(agent_api.ImageBasedHostedAgentDefinition) - if !ok { - t.Fatalf("expected ImageBasedHostedAgentDefinition, got %T", req.Definition) - } -} - -func TestCreateAgentAPIRequestFromDefinition_UnsupportedKind(t *testing.T) { - t.Parallel() - agent := AgentDefinition{ - Kind: "unknown", - Name: "bad-kind", - } - - _, err := CreateAgentAPIRequestFromDefinition(agent) - if err == nil { - t.Fatal("expected error for unsupported kind") - } - if !strings.Contains(err.Error(), "unsupported agent kind") { - t.Errorf("error should mention unsupported agent kind, got: %s", err.Error()) - } -} - -func TestCreateAgentAPIRequestFromDefinition_HostedWithBuildOptions(t *testing.T) { - t.Parallel() - agent := ContainerAgent{ - AgentDefinition: AgentDefinition{ - Kind: AgentKindHosted, - Name: "hosted-opts", - }, - Protocols: []ProtocolVersionRecord{ - {Protocol: "responses", Version: "1.0.0"}, - }, - } - - req, err := CreateAgentAPIRequestFromDefinition(agent, - WithImageURL("myimg:v2"), - WithCPU("2"), - WithMemory("4Gi"), - WithEnvironmentVariable("FOO", "bar"), - ) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - imgDef := req.Definition.(agent_api.ImageBasedHostedAgentDefinition) - if imgDef.Image != "myimg:v2" { - t.Errorf("Image = %q", imgDef.Image) - } - if imgDef.CPU != "2" { - t.Errorf("CPU = %q", imgDef.CPU) - } - if imgDef.Memory != "4Gi" { - t.Errorf("Memory = %q", imgDef.Memory) - } - if imgDef.EnvironmentVariables["FOO"] != "bar" { - t.Errorf("env var FOO missing or wrong") - } -} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/hosted-agent.yaml b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/hosted-agent.yaml deleted file mode 100644 index a6a5314c78b..00000000000 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/hosted-agent.yaml +++ /dev/null @@ -1,9 +0,0 @@ -template: - kind: hosted - name: hosted-test-agent - description: A hosted container agent for testing - protocols: - - protocol: responses - version: "1.0.0" - - protocol: invocations - version: "1.0.0" diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/invalid-empty-template.yaml b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/invalid-empty-template.yaml deleted file mode 100644 index 2676f0d4033..00000000000 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/invalid-empty-template.yaml +++ /dev/null @@ -1 +0,0 @@ -template: {} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/invalid-no-kind.yaml b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/invalid-no-kind.yaml deleted file mode 100644 index 5ba3da54847..00000000000 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/invalid-no-kind.yaml +++ /dev/null @@ -1,4 +0,0 @@ -template: - name: no-kind-agent - model: - id: gpt-4o diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/invalid-no-model.yaml b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/invalid-no-model.yaml deleted file mode 100644 index 1f5f9536a65..00000000000 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/invalid-no-model.yaml +++ /dev/null @@ -1,4 +0,0 @@ -template: - kind: prompt - name: no-model-agent - instructions: Some instructions diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/mcp-tools-agent.yaml b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/mcp-tools-agent.yaml deleted file mode 100644 index 9e2a70ed8fa..00000000000 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/mcp-tools-agent.yaml +++ /dev/null @@ -1,18 +0,0 @@ -template: - kind: prompt - name: mcp-tools-agent - description: Agent with MCP tool connections - model: - id: gpt-4o - tools: - - kind: mcp - name: github-mcp - connection: - kind: foundry - endpoint: https://api.githubcopilot.com/mcp/ - name: github-mcp-conn - url: https://api.githubcopilot.com/mcp/ - - kind: code_interpreter - name: code-runner - instructions: | - You have access to GitHub via MCP and can run code. diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/prompt-agent-full.yaml b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/prompt-agent-full.yaml deleted file mode 100644 index b5643ea337a..00000000000 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/prompt-agent-full.yaml +++ /dev/null @@ -1,39 +0,0 @@ -template: - kind: prompt - name: full-prompt-agent - description: A fully configured prompt agent for testing - metadata: - authors: - - testauthor - tags: - - testing - - full - model: - id: gpt-4o - publisher: azure - options: - temperature: 0.8 - maxTokens: 4000 - topP: 0.95 - instructions: | - You are a helpful testing assistant. - Always respond in a structured format. - tools: - - kind: web_search - name: web-search - - kind: function - name: get_weather - description: Get weather for a location - parameters: - properties: - - name: location - kind: string - description: The city name - required: true - - name: unit - kind: string - description: Temperature unit - enumValues: - - celsius - - fahrenheit - default: celsius diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/prompt-agent-minimal.yaml b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/prompt-agent-minimal.yaml deleted file mode 100644 index 7f598558883..00000000000 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata/prompt-agent-minimal.yaml +++ /dev/null @@ -1,5 +0,0 @@ -template: - kind: prompt - name: minimal-agent - model: - id: gpt-4o-mini diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata_test.go deleted file mode 100644 index 0482af53abf..00000000000 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/testdata_test.go +++ /dev/null @@ -1,286 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package agent_yaml - -import ( - "os" - "path/filepath" - "strings" - "testing" - - "go.yaml.in/yaml/v3" -) - -// TestFixtures_ValidYAML verifies that valid YAML fixtures parse successfully -// and produce the expected agent kind and name via ExtractAgentDefinition. -func TestFixtures_ValidYAML(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - file string - wantKind AgentKind - wantName string - wantErrSubst string // if non-empty, expect this error instead of success - }{ - { - name: "hosted agent", - file: filepath.Join("testdata", "hosted-agent.yaml"), - wantKind: AgentKindHosted, - wantName: "hosted-test-agent", - }, - { - // Prompt agents are not currently supported by ExtractAgentDefinition. - // This test documents the current expected behavior. - name: "prompt agent minimal", - file: filepath.Join("testdata", "prompt-agent-minimal.yaml"), - wantErrSubst: "prompt agents not currently supported", - }, - { - name: "prompt agent full", - file: filepath.Join("testdata", "prompt-agent-full.yaml"), - wantErrSubst: "prompt agents not currently supported", - }, - { - name: "mcp tools agent", - file: filepath.Join("testdata", "mcp-tools-agent.yaml"), - wantErrSubst: "prompt agents not currently supported", - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - data, err := os.ReadFile(tc.file) - if err != nil { - t.Fatalf("failed to read fixture %s: %v", tc.file, err) - } - - agent, err := ExtractAgentDefinition(data) - - if tc.wantErrSubst != "" { - if err == nil { - t.Fatalf("expected error containing %q, got nil", tc.wantErrSubst) - } - if !strings.Contains(err.Error(), tc.wantErrSubst) { - t.Fatalf("expected error containing %q, got %q", tc.wantErrSubst, err.Error()) - } - return - } - - if err != nil { - t.Fatalf("ExtractAgentDefinition failed: %v", err) - } - - containerAgent, ok := agent.(ContainerAgent) - if !ok { - t.Fatalf("expected ContainerAgent, got %T", agent) - } - - if containerAgent.Kind != tc.wantKind { - t.Errorf("kind: got %q, want %q", containerAgent.Kind, tc.wantKind) - } - if containerAgent.Name != tc.wantName { - t.Errorf("name: got %q, want %q", containerAgent.Name, tc.wantName) - } - }) - } -} - -// TestFixtures_ValidatePromptAgents uses ValidateAgentDefinition to confirm -// that prompt agent fixtures have a structurally valid YAML schema, even though -// ExtractAgentDefinition does not yet support prompt agents. -func TestFixtures_ValidatePromptAgents(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - file string - }{ - {name: "prompt agent minimal", file: filepath.Join("testdata", "prompt-agent-minimal.yaml")}, - {name: "prompt agent full", file: filepath.Join("testdata", "prompt-agent-full.yaml")}, - {name: "mcp tools agent", file: filepath.Join("testdata", "mcp-tools-agent.yaml")}, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - data, err := os.ReadFile(tc.file) - if err != nil { - t.Fatalf("failed to read fixture %s: %v", tc.file, err) - } - - // Extract the template section to pass to ValidateAgentDefinition, - // which operates on template bytes rather than the full manifest. - templateBytes, err := extractTemplateBytes(data) - if err != nil { - t.Fatalf("failed to extract template bytes: %v", err) - } - - if err := ValidateAgentDefinition(templateBytes); err != nil { - t.Fatalf("ValidateAgentDefinition failed for valid fixture: %v", err) - } - }) - } -} - -// TestFixtures_InvalidYAML verifies that invalid YAML fixtures return appropriate errors. -func TestFixtures_InvalidYAML(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - file string - wantErrSubst string - }{ - { - name: "missing kind field", - file: filepath.Join("testdata", "invalid-no-kind.yaml"), - wantErrSubst: "template.kind must be one of", - }, - { - name: "prompt agent missing model", - file: filepath.Join("testdata", "invalid-no-model.yaml"), - wantErrSubst: "template.model.id is required", - }, - { - name: "empty template", - file: filepath.Join("testdata", "invalid-empty-template.yaml"), - wantErrSubst: "template field is empty", - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - data, err := os.ReadFile(tc.file) - if err != nil { - t.Fatalf("failed to read fixture %s: %v", tc.file, err) - } - - // For "empty template", ExtractAgentDefinition catches the error. - // For schema validation errors, ValidateAgentDefinition is used on the template bytes. - _, extractErr := ExtractAgentDefinition(data) - - if extractErr != nil && strings.Contains(extractErr.Error(), tc.wantErrSubst) { - return // error caught at extraction level - } - - // Try validation-level check for schema errors (no-kind, no-model). - templateBytes, err := extractTemplateBytes(data) - if err != nil { - // If we can't even extract template bytes but got an extraction error, that's fine. - if extractErr != nil { - t.Logf("ExtractAgentDefinition error: %v", extractErr) - return - } - t.Fatalf("failed to extract template bytes and no extraction error: %v", err) - } - - validateErr := ValidateAgentDefinition(templateBytes) - if validateErr == nil { - t.Fatalf("expected validation error containing %q, got nil (extractErr=%v)", - tc.wantErrSubst, extractErr) - } - if !strings.Contains(validateErr.Error(), tc.wantErrSubst) { - t.Fatalf("expected error containing %q, got %q", tc.wantErrSubst, validateErr.Error()) - } - }) - } -} - -// TestFixtures_SampleAgents is a regression test that ensures the sample agent -// YAML files in tests/samples/ continue to parse correctly. -func TestFixtures_SampleAgents(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - file string - wantName string - }{ - { - name: "declarativeNoTools sample", - file: filepath.Join("..", "..", "..", "..", "tests", "samples", "declarativeNoTools", "agent.yaml"), - wantName: "Learn French Agent", - }, - { - name: "githubMcpAgent sample", - file: filepath.Join("..", "..", "..", "..", "tests", "samples", "githubMcpAgent", "agent.yaml"), - wantName: "github-agent", - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - data, err := os.ReadFile(tc.file) - if err != nil { - t.Fatalf("failed to read sample %s: %v", tc.file, err) - } - - // Both samples are prompt agents, so ExtractAgentDefinition returns - // "prompt agents not currently supported". Validate structure instead. - _, extractErr := ExtractAgentDefinition(data) - if extractErr == nil { - t.Fatal("expected error for prompt agent sample, got nil") - } - if !strings.Contains(extractErr.Error(), "prompt agents not currently supported") { - t.Fatalf("unexpected error: %v", extractErr) - } - - // Validate that the YAML structure is well-formed by unmarshaling - // the template section into the typed structs. - templateBytes, err := extractTemplateBytes(data) - if err != nil { - t.Fatalf("failed to extract template bytes: %v", err) - } - - var agentDef AgentDefinition - if err := yaml.Unmarshal(templateBytes, &agentDef); err != nil { - t.Fatalf("failed to unmarshal AgentDefinition: %v", err) - } - if agentDef.Name != tc.wantName { - t.Errorf("name: got %q, want %q", agentDef.Name, tc.wantName) - } - if agentDef.Kind != AgentKindPrompt { - t.Errorf("kind: got %q, want %q", agentDef.Kind, AgentKindPrompt) - } - - // Also confirm the model is present for these prompt agents. - var promptAgent PromptAgent - if err := yaml.Unmarshal(templateBytes, &promptAgent); err != nil { - t.Fatalf("failed to unmarshal PromptAgent: %v", err) - } - if promptAgent.Model.Id == "" { - t.Error("expected non-empty model.id in sample agent") - } - }) - } -} - -// extractTemplateBytes reads YAML content with a top-level "template" field -// and returns the marshaled bytes of just the template section. -func extractTemplateBytes(manifestYaml []byte) ([]byte, error) { - var generic map[string]any - if err := yaml.Unmarshal(manifestYaml, &generic); err != nil { - return nil, err - } - - templateVal, ok := generic["template"] - if !ok || templateVal == nil { - return nil, os.ErrNotExist - } - - templateMap, ok := templateVal.(map[string]any) - if !ok { - return nil, os.ErrInvalid - } - - return yaml.Marshal(templateMap) -} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/registry_api/helpers_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/registry_api/helpers_test.go deleted file mode 100644 index 000e6083839..00000000000 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/registry_api/helpers_test.go +++ /dev/null @@ -1,863 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package registry_api - -import ( - "strings" - "testing" - - "azureaiagent/internal/pkg/agents/agent_api" - "azureaiagent/internal/pkg/agents/agent_yaml" -) - -// ptr is a generic helper that returns a pointer to the given value. -// --------------------------------------------------------------------------- -// ConvertToolToYaml -// --------------------------------------------------------------------------- - -func TestConvertToolToYaml(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - input any - wantErr bool - errSubstr string - validate func(t *testing.T, got any) - }{ - { - name: "nil tool returns error", - input: nil, - wantErr: true, - errSubstr: "tool cannot be nil", - }, - { - name: "unsupported type returns error", - input: "not-a-tool", - wantErr: true, - errSubstr: "unsupported tool type", - }, - { - name: "FunctionTool", - input: agent_api.FunctionTool{ - Tool: agent_api.Tool{Type: "function"}, - Name: "my_func", - Description: new("a helper function"), - Parameters: nil, - Strict: new(true), - }, - validate: func(t *testing.T, got any) { - ft, ok := got.(agent_yaml.FunctionTool) - if !ok { - t.Fatalf("expected agent_yaml.FunctionTool, got %T", got) - } - if ft.Tool.Kind != agent_yaml.ToolKindFunction { - t.Errorf("Kind = %q, want %q", ft.Tool.Kind, agent_yaml.ToolKindFunction) - } - if ft.Tool.Name != "my_func" { - t.Errorf("Name = %q, want %q", ft.Tool.Name, "my_func") - } - if ft.Tool.Description == nil || *ft.Tool.Description != "a helper function" { - t.Errorf("Description mismatch") - } - if ft.Strict == nil || *ft.Strict != true { - t.Errorf("Strict = %v, want true", ft.Strict) - } - }, - }, - { - name: "WebSearchPreviewTool with options", - input: agent_api.WebSearchPreviewTool{ - Tool: agent_api.Tool{Type: "web_search_preview"}, - UserLocation: &agent_api.Location{Type: "approximate"}, - SearchContextSize: new("medium"), - }, - validate: func(t *testing.T, got any) { - ws, ok := got.(agent_yaml.WebSearchTool) - if !ok { - t.Fatalf("expected agent_yaml.WebSearchTool, got %T", got) - } - if ws.Tool.Kind != agent_yaml.ToolKindWebSearch { - t.Errorf("Kind = %q, want %q", ws.Tool.Kind, agent_yaml.ToolKindWebSearch) - } - if ws.Tool.Name != "web_search_preview" { - t.Errorf("Name = %q, want %q", ws.Tool.Name, "web_search_preview") - } - if ws.Options == nil { - t.Fatal("Options is nil") - } - if _, exists := ws.Options["userLocation"]; !exists { - t.Error("expected userLocation in Options") - } - if ws.Options["searchContextSize"] != "medium" { - t.Errorf("searchContextSize = %v, want %q", ws.Options["searchContextSize"], "medium") - } - }, - }, - { - name: "WebSearchPreviewTool without options", - input: agent_api.WebSearchPreviewTool{ - Tool: agent_api.Tool{Type: "web_search_preview"}, - }, - validate: func(t *testing.T, got any) { - ws, ok := got.(agent_yaml.WebSearchTool) - if !ok { - t.Fatalf("expected agent_yaml.WebSearchTool, got %T", got) - } - // Options map is always created but should be empty - if len(ws.Options) != 0 { - t.Errorf("expected empty Options, got %v", ws.Options) - } - }, - }, - { - name: "BingGroundingAgentTool", - input: agent_api.BingGroundingAgentTool{ - Tool: agent_api.Tool{Type: "bing_grounding"}, - BingGrounding: agent_api.BingGroundingSearchToolParameters{ - ProjectConnections: agent_api.ToolProjectConnectionList{ - ProjectConnections: []agent_api.ToolProjectConnection{ - {ID: "conn-1"}, - }, - }, - }, - }, - validate: func(t *testing.T, got any) { - bg, ok := got.(agent_yaml.BingGroundingTool) - if !ok { - t.Fatalf("expected agent_yaml.BingGroundingTool, got %T", got) - } - if bg.Tool.Kind != agent_yaml.ToolKindBingGrounding { - t.Errorf("Kind = %q, want %q", bg.Tool.Kind, agent_yaml.ToolKindBingGrounding) - } - if bg.Tool.Name != "bing_grounding" { - t.Errorf("Name = %q, want %q", bg.Tool.Name, "bing_grounding") - } - if bg.Options == nil { - t.Fatal("Options is nil") - } - if _, exists := bg.Options["bingGrounding"]; !exists { - t.Error("expected bingGrounding in Options") - } - }, - }, - { - name: "FileSearchTool with ranking options", - input: agent_api.FileSearchTool{ - Tool: agent_api.Tool{Type: "file_search"}, - VectorStoreIds: []string{"vs-1", "vs-2"}, - MaxNumResults: new(int32(10)), - RankingOptions: &agent_api.RankingOptions{ - Ranker: new("auto"), - ScoreThreshold: new(float32(0.5)), - }, - }, - validate: func(t *testing.T, got any) { - fs, ok := got.(agent_yaml.FileSearchTool) - if !ok { - t.Fatalf("expected agent_yaml.FileSearchTool, got %T", got) - } - if fs.Tool.Kind != agent_yaml.ToolKindFileSearch { - t.Errorf("Kind = %q, want %q", fs.Tool.Kind, agent_yaml.ToolKindFileSearch) - } - if len(fs.VectorStoreIds) != 2 || fs.VectorStoreIds[0] != "vs-1" { - t.Errorf("VectorStoreIds = %v, want [vs-1 vs-2]", fs.VectorStoreIds) - } - if fs.MaximumResultCount == nil || *fs.MaximumResultCount != 10 { - t.Errorf("MaximumResultCount = %v, want 10", fs.MaximumResultCount) - } - if fs.Ranker == nil || *fs.Ranker != "auto" { - t.Errorf("Ranker = %v, want auto", fs.Ranker) - } - if fs.ScoreThreshold == nil || *fs.ScoreThreshold != float64(float32(0.5)) { - t.Errorf("ScoreThreshold = %v, want 0.5", fs.ScoreThreshold) - } - }, - }, - { - name: "FileSearchTool without ranking options", - input: agent_api.FileSearchTool{ - Tool: agent_api.Tool{Type: "file_search"}, - VectorStoreIds: []string{"vs-1"}, - }, - validate: func(t *testing.T, got any) { - fs, ok := got.(agent_yaml.FileSearchTool) - if !ok { - t.Fatalf("expected agent_yaml.FileSearchTool, got %T", got) - } - if fs.Ranker != nil { - t.Errorf("Ranker = %v, want nil", fs.Ranker) - } - if fs.ScoreThreshold != nil { - t.Errorf("ScoreThreshold = %v, want nil", fs.ScoreThreshold) - } - if fs.MaximumResultCount != nil { - t.Errorf("MaximumResultCount = %v, want nil", fs.MaximumResultCount) - } - }, - }, - { - name: "MCPTool with all fields", - input: agent_api.MCPTool{ - Tool: agent_api.Tool{Type: "mcp"}, - ServerLabel: "my-server", - ServerURL: "https://example.com", - Headers: map[string]string{"x-key": "val"}, - ProjectConnectionID: new("conn-1"), - }, - validate: func(t *testing.T, got any) { - mcp, ok := got.(agent_yaml.McpTool) - if !ok { - t.Fatalf("expected agent_yaml.McpTool, got %T", got) - } - if mcp.Tool.Kind != agent_yaml.ToolKindMcp { - t.Errorf("Kind = %q, want %q", mcp.Tool.Kind, agent_yaml.ToolKindMcp) - } - if mcp.ServerName != "my-server" { - t.Errorf("ServerName = %q, want %q", mcp.ServerName, "my-server") - } - if mcp.Options["serverUrl"] != "https://example.com" { - t.Errorf("serverUrl = %v, want %q", mcp.Options["serverUrl"], "https://example.com") - } - if mcp.Options["projectConnectionId"] != "conn-1" { - t.Errorf("projectConnectionId = %v, want %q", mcp.Options["projectConnectionId"], "conn-1") - } - headers, ok := mcp.Options["headers"].(map[string]string) - if !ok { - t.Fatalf("expected headers map[string]string, got %T", mcp.Options["headers"]) - } - if headers["x-key"] != "val" { - t.Errorf("header x-key = %q, want %q", headers["x-key"], "val") - } - }, - }, - { - name: "MCPTool minimal", - input: agent_api.MCPTool{ - Tool: agent_api.Tool{Type: "mcp"}, - ServerLabel: "minimal-server", - }, - validate: func(t *testing.T, got any) { - mcp, ok := got.(agent_yaml.McpTool) - if !ok { - t.Fatalf("expected agent_yaml.McpTool, got %T", got) - } - if mcp.ServerName != "minimal-server" { - t.Errorf("ServerName = %q, want %q", mcp.ServerName, "minimal-server") - } - // serverUrl should not appear when empty string - if _, exists := mcp.Options["serverUrl"]; exists { - t.Error("expected serverUrl to be absent for empty ServerURL") - } - if _, exists := mcp.Options["headers"]; exists { - t.Error("expected headers to be absent when nil") - } - if _, exists := mcp.Options["projectConnectionId"]; exists { - t.Error("expected projectConnectionId to be absent when nil") - } - }, - }, - { - name: "OpenApiAgentTool", - input: agent_api.OpenApiAgentTool{ - Tool: agent_api.Tool{Type: "openapi"}, - OpenAPI: agent_api.OpenApiFunctionDefinition{ - Name: "weather-api", - Description: new("Weather lookup"), - }, - }, - validate: func(t *testing.T, got any) { - oa, ok := got.(agent_yaml.OpenApiTool) - if !ok { - t.Fatalf("expected agent_yaml.OpenApiTool, got %T", got) - } - if oa.Tool.Kind != agent_yaml.ToolKindOpenApi { - t.Errorf("Kind = %q, want %q", oa.Tool.Kind, agent_yaml.ToolKindOpenApi) - } - if oa.Tool.Name != "openapi" { - t.Errorf("Name = %q, want %q", oa.Tool.Name, "openapi") - } - if _, exists := oa.Options["openapi"]; !exists { - t.Error("expected openapi in Options") - } - }, - }, - { - name: "CodeInterpreterTool with container", - input: agent_api.CodeInterpreterTool{ - Tool: agent_api.Tool{Type: "code_interpreter"}, - Container: "container-id-123", - }, - validate: func(t *testing.T, got any) { - ci, ok := got.(agent_yaml.CodeInterpreterTool) - if !ok { - t.Fatalf("expected agent_yaml.CodeInterpreterTool, got %T", got) - } - if ci.Tool.Kind != agent_yaml.ToolKindCodeInterpreter { - t.Errorf("Kind = %q, want %q", ci.Tool.Kind, agent_yaml.ToolKindCodeInterpreter) - } - if ci.Tool.Name != "code_interpreter" { - t.Errorf("Name = %q, want %q", ci.Tool.Name, "code_interpreter") - } - if ci.Options["container"] != "container-id-123" { - t.Errorf("container = %v, want %q", ci.Options["container"], "container-id-123") - } - }, - }, - { - name: "CodeInterpreterTool without container", - input: agent_api.CodeInterpreterTool{ - Tool: agent_api.Tool{Type: "code_interpreter"}, - Container: nil, - }, - validate: func(t *testing.T, got any) { - ci, ok := got.(agent_yaml.CodeInterpreterTool) - if !ok { - t.Fatalf("expected agent_yaml.CodeInterpreterTool, got %T", got) - } - if _, exists := ci.Options["container"]; exists { - t.Error("expected container to be absent when nil") - } - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - got, err := ConvertToolToYaml(tc.input) - if tc.wantErr { - if err == nil { - t.Fatal("expected error, got nil") - } - if tc.errSubstr != "" && !strings.Contains(err.Error(), tc.errSubstr) { - t.Errorf("error %q does not contain %q", err, tc.errSubstr) - } - return - } - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if tc.validate != nil { - tc.validate(t, got) - } - }) - } -} - -// --------------------------------------------------------------------------- -// ConvertAgentDefinition -// --------------------------------------------------------------------------- - -func TestConvertAgentDefinition(t *testing.T) { - t.Parallel() - - t.Run("empty tools", func(t *testing.T) { - t.Parallel() - def := agent_api.PromptAgentDefinition{ - Model: "gpt-4o", - Instructions: new("Be helpful"), - } - - got, err := ConvertAgentDefinition(def) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if got.AgentDefinition.Kind != agent_yaml.AgentKindPrompt { - t.Errorf("Kind = %q, want %q", got.AgentDefinition.Kind, agent_yaml.AgentKindPrompt) - } - if got.Model.Id != "gpt-4o" { - t.Errorf("Model.Id = %q, want %q", got.Model.Id, "gpt-4o") - } - if got.Instructions == nil || *got.Instructions != "Be helpful" { - t.Errorf("Instructions mismatch") - } - if got.Tools == nil { - t.Fatal("Tools should not be nil") - } - if len(*got.Tools) != 0 { - t.Errorf("expected 0 tools, got %d", len(*got.Tools)) - } - }) - - t.Run("with tools", func(t *testing.T) { - t.Parallel() - def := agent_api.PromptAgentDefinition{ - Model: "gpt-4o-mini", - Instructions: new("Do things"), - Tools: []any{ - agent_api.FunctionTool{ - Tool: agent_api.Tool{Type: "function"}, - Name: "fn1", - }, - agent_api.CodeInterpreterTool{ - Tool: agent_api.Tool{Type: "code_interpreter"}, - }, - }, - } - - got, err := ConvertAgentDefinition(def) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if len(*got.Tools) != 2 { - t.Fatalf("expected 2 tools, got %d", len(*got.Tools)) - } - // Verify first tool is FunctionTool - if _, ok := (*got.Tools)[0].(agent_yaml.FunctionTool); !ok { - t.Errorf("expected first tool to be FunctionTool, got %T", (*got.Tools)[0]) - } - // Verify second tool is CodeInterpreterTool - if _, ok := (*got.Tools)[1].(agent_yaml.CodeInterpreterTool); !ok { - t.Errorf("expected second tool to be CodeInterpreterTool, got %T", (*got.Tools)[1]) - } - }) - - t.Run("unsupported tool propagates error", func(t *testing.T) { - t.Parallel() - def := agent_api.PromptAgentDefinition{ - Model: "gpt-4o", - Tools: []any{"bad-tool"}, - } - _, err := ConvertAgentDefinition(def) - if err == nil { - t.Fatal("expected error for unsupported tool") - } - if !strings.Contains(err.Error(), "unsupported tool type") { - t.Errorf("error %q does not mention unsupported tool type", err) - } - }) - - t.Run("nil instructions", func(t *testing.T) { - t.Parallel() - def := agent_api.PromptAgentDefinition{ - Model: "gpt-4o", - } - got, err := ConvertAgentDefinition(def) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if got.Instructions != nil { - t.Errorf("Instructions = %v, want nil", got.Instructions) - } - }) -} - -// --------------------------------------------------------------------------- -// ConvertParameters -// --------------------------------------------------------------------------- - -func TestConvertParameters(t *testing.T) { - t.Parallel() - - t.Run("nil parameters returns nil", func(t *testing.T) { - t.Parallel() - got, err := ConvertParameters(nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if got != nil { - t.Errorf("expected nil, got %+v", got) - } - }) - - t.Run("empty parameters returns nil", func(t *testing.T) { - t.Parallel() - got, err := ConvertParameters(map[string]OpenApiParameter{}) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if got != nil { - t.Errorf("expected nil, got %+v", got) - } - }) - - t.Run("single parameter with schema and enum", func(t *testing.T) { - t.Parallel() - params := map[string]OpenApiParameter{ - "region": { - Description: "Azure region", - Required: true, - Schema: &OpenApiSchema{ - Type: "string", - Enum: []any{"eastus", "westus"}, - }, - }, - } - - got, err := ConvertParameters(params) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if got == nil { - t.Fatal("expected non-nil PropertySchema") - } - if len(got.Properties) != 1 { - t.Fatalf("expected 1 property, got %d", len(got.Properties)) - } - p := got.Properties[0] - if p.Name != "region" { - t.Errorf("Name = %q, want %q", p.Name, "region") - } - if p.Kind != "string" { - t.Errorf("Kind = %q, want %q", p.Kind, "string") - } - if p.Description == nil || *p.Description != "Azure region" { - t.Errorf("Description mismatch") - } - if p.Required == nil || *p.Required != true { - t.Errorf("Required = %v, want true", p.Required) - } - if p.EnumValues == nil || len(*p.EnumValues) != 2 { - t.Fatalf("expected 2 enum values, got %v", p.EnumValues) - } - if (*p.EnumValues)[0] != "eastus" || (*p.EnumValues)[1] != "westus" { - t.Errorf("EnumValues = %v, want [eastus westus]", *p.EnumValues) - } - }) - - t.Run("parameter without schema defaults to string kind", func(t *testing.T) { - t.Parallel() - params := map[string]OpenApiParameter{ - "name": { - Description: "Agent name", - Required: false, - }, - } - - got, err := ConvertParameters(params) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if got == nil { - t.Fatal("expected non-nil PropertySchema") - } - p := got.Properties[0] - if p.Kind != "string" { - t.Errorf("Kind = %q, want %q (default)", p.Kind, "string") - } - if p.EnumValues != nil { - t.Errorf("EnumValues = %v, want nil", p.EnumValues) - } - }) - - t.Run("parameter with example sets default", func(t *testing.T) { - t.Parallel() - params := map[string]OpenApiParameter{ - "timeout": { - Description: "Timeout in seconds", - Example: 30, - Schema: &OpenApiSchema{Type: "integer"}, - }, - } - - got, err := ConvertParameters(params) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - p := got.Properties[0] - if p.Default == nil { - t.Fatal("expected Default to be set from Example") - } - if *p.Default != 30 { - t.Errorf("Default = %v, want 30", *p.Default) - } - }) -} - -// --------------------------------------------------------------------------- -// MergeManifestIntoAgentDefinition -// --------------------------------------------------------------------------- - -func TestMergeManifestIntoAgentDefinition(t *testing.T) { - t.Parallel() - - t.Run("fills empty name from manifest", func(t *testing.T) { - t.Parallel() - manifest := &Manifest{ - Name: "manifest-agent", - DisplayName: "Manifest Agent", - Description: "A description", - } - agentDef := &agent_yaml.AgentDefinition{ - Kind: agent_yaml.AgentKindPrompt, - } - - result := MergeManifestIntoAgentDefinition(manifest, agentDef) - - if result.Name != "manifest-agent" { - t.Errorf("Name = %q, want %q", result.Name, "manifest-agent") - } - }) - - t.Run("does not overwrite existing name", func(t *testing.T) { - t.Parallel() - manifest := &Manifest{ - Name: "manifest-name", - } - agentDef := &agent_yaml.AgentDefinition{ - Kind: agent_yaml.AgentKindPrompt, - Name: "existing-name", - } - - result := MergeManifestIntoAgentDefinition(manifest, agentDef) - - if result.Name != "existing-name" { - t.Errorf("Name = %q, want %q", result.Name, "existing-name") - } - }) - - t.Run("does not modify original agent definition", func(t *testing.T) { - t.Parallel() - manifest := &Manifest{ - Name: "new-name", - } - agentDef := &agent_yaml.AgentDefinition{ - Kind: agent_yaml.AgentKindPrompt, - } - - _ = MergeManifestIntoAgentDefinition(manifest, agentDef) - - if agentDef.Name != "" { - t.Errorf("original AgentDefinition.Name was modified to %q", agentDef.Name) - } - }) - - t.Run("preserves kind when already set", func(t *testing.T) { - t.Parallel() - manifest := &Manifest{ - Name: "test", - } - agentDef := &agent_yaml.AgentDefinition{ - Kind: agent_yaml.AgentKindPrompt, - Name: "keep", - } - - result := MergeManifestIntoAgentDefinition(manifest, agentDef) - - if result.Kind != agent_yaml.AgentKindPrompt { - t.Errorf("Kind = %q, want %q", result.Kind, agent_yaml.AgentKindPrompt) - } - }) -} - -// --------------------------------------------------------------------------- -// injectParameterValues -// --------------------------------------------------------------------------- - -func TestInjectParameterValues(t *testing.T) { - t.Parallel() - - t.Run("replaces {{param}} style", func(t *testing.T) { - t.Parallel() - template := "Hello {{name}}, welcome to {{place}}!" - values := ParameterValues{ - "name": "Alice", - "place": "Wonderland", - } - - got, err := injectParameterValues(template, values) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - want := "Hello Alice, welcome to Wonderland!" - if string(got) != want { - t.Errorf("got %q, want %q", string(got), want) - } - }) - - t.Run("replaces {{ param }} style with spaces", func(t *testing.T) { - t.Parallel() - template := "Value is {{ apiKey }}" - values := ParameterValues{ - "apiKey": "secret-123", - } - - got, err := injectParameterValues(template, values) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - want := "Value is secret-123" - if string(got) != want { - t.Errorf("got %q, want %q", string(got), want) - } - }) - - t.Run("replaces both styles in same template", func(t *testing.T) { - t.Parallel() - template := "{{key1}} and {{ key1 }}" - values := ParameterValues{ - "key1": "replaced", - } - - got, err := injectParameterValues(template, values) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - want := "replaced and replaced" - if string(got) != want { - t.Errorf("got %q, want %q", string(got), want) - } - }) - - t.Run("no placeholders returns unchanged", func(t *testing.T) { - t.Parallel() - template := "no placeholders here" - values := ParameterValues{"key": "val"} - - got, err := injectParameterValues(template, values) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if string(got) != template { - t.Errorf("got %q, want %q", string(got), template) - } - }) - - t.Run("empty parameter values returns unchanged", func(t *testing.T) { - t.Parallel() - template := "Hello {{name}}" - values := ParameterValues{} - - got, err := injectParameterValues(template, values) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - // Unresolved placeholders remain but no error - if string(got) != template { - t.Errorf("got %q, want %q", string(got), template) - } - }) - - t.Run("non-string value is converted via Sprintf", func(t *testing.T) { - t.Parallel() - template := "count={{count}}" - values := ParameterValues{"count": 42} - - got, err := injectParameterValues(template, values) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - want := "count=42" - if string(got) != want { - t.Errorf("got %q, want %q", string(got), want) - } - }) - - t.Run("empty template returns empty", func(t *testing.T) { - t.Parallel() - got, err := injectParameterValues("", ParameterValues{"k": "v"}) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if string(got) != "" { - t.Errorf("got %q, want empty", string(got)) - } - }) -} - -// --------------------------------------------------------------------------- -// convertFloat32ToFloat64 -// --------------------------------------------------------------------------- - -func TestConvertFloat32ToFloat64(t *testing.T) { - t.Parallel() - - t.Run("nil returns nil", func(t *testing.T) { - t.Parallel() - got := convertFloat32ToFloat64(nil) - if got != nil { - t.Errorf("expected nil, got %v", *got) - } - }) - - t.Run("converts value", func(t *testing.T) { - t.Parallel() - f32 := float32(0.75) - got := convertFloat32ToFloat64(&f32) - if got == nil { - t.Fatal("expected non-nil") - } - if *got != float64(f32) { - t.Errorf("got %v, want %v", *got, float64(f32)) - } - }) - - t.Run("zero value", func(t *testing.T) { - t.Parallel() - f32 := float32(0) - got := convertFloat32ToFloat64(&f32) - if got == nil { - t.Fatal("expected non-nil") - } - if *got != 0 { - t.Errorf("got %v, want 0", *got) - } - }) -} - -// --------------------------------------------------------------------------- -// convertInt32ToInt -// --------------------------------------------------------------------------- - -func TestConvertInt32ToInt(t *testing.T) { - t.Parallel() - - t.Run("nil returns nil", func(t *testing.T) { - t.Parallel() - got := convertInt32ToInt(nil) - if got != nil { - t.Errorf("expected nil, got %v", *got) - } - }) - - t.Run("converts value", func(t *testing.T) { - t.Parallel() - i32 := int32(42) - got := convertInt32ToInt(&i32) - if got == nil { - t.Fatal("expected non-nil") - } - if *got != 42 { - t.Errorf("got %d, want 42", *got) - } - }) - - t.Run("zero value", func(t *testing.T) { - t.Parallel() - i32 := int32(0) - got := convertInt32ToInt(&i32) - if got == nil { - t.Fatal("expected non-nil") - } - if *got != 0 { - t.Errorf("got %d, want 0", *got) - } - }) -} - -// --------------------------------------------------------------------------- -// convertToPropertySchema -// --------------------------------------------------------------------------- - -func TestConvertToPropertySchema(t *testing.T) { - t.Parallel() - - t.Run("nil input returns empty properties", func(t *testing.T) { - t.Parallel() - got := convertToPropertySchema(nil) - if len(got.Properties) != 0 { - t.Errorf("expected 0 properties, got %d", len(got.Properties)) - } - }) - - t.Run("non-nil input returns empty properties", func(t *testing.T) { - t.Parallel() - // Current implementation is a placeholder that always returns empty properties - got := convertToPropertySchema(map[string]any{"key": "value"}) - if len(got.Properties) != 0 { - t.Errorf("expected 0 properties, got %d", len(got.Properties)) - } - }) -} From a93e25b325bc23761038e93bcaab00dfafc6dc69 Mon Sep 17 00:00:00 2001 From: trangevi Date: Fri, 10 Apr 2026 16:38:29 -0700 Subject: [PATCH 13/14] PR comments part 1 Signed-off-by: trangevi --- .../azure.ai.agents/internal/cmd/init.go | 22 +- .../azure.ai.agents/internal/cmd/init_test.go | 20 +- .../pkg/agents/agent_yaml/parse_test.go | 24 +- .../pkg/azure/foundry_toolsets_client_test.go | 288 ++++++++++++++++++ .../internal/project/service_target_agent.go | 15 - 5 files changed, 321 insertions(+), 48 deletions(-) create mode 100644 cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_toolsets_client_test.go diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/init.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/init.go index b0a99a75c84..100cdf5d60c 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/init.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/init.go @@ -14,6 +14,7 @@ import ( "net/url" "os" "path/filepath" + "regexp" "strings" "time" @@ -1722,22 +1723,17 @@ func extractToolboxAndConnectionConfigs( ) } - // Manifest uses "id" for tool kind; API uses "type" - toolId, _ := toolMap["id"].(string) + // Manifest and API both use "type" for tool kind + toolType, _ := toolMap["type"].(string) target, _ := toolMap["target"].(string) if target == "" { // No target — either a built-in tool or a pre-configured tool - // that already has project_connection_id. Convert "id" → "type" - // and pass through all existing fields. + // that already has project_connection_id. Pass through as-is. result := make(map[string]any, len(toolMap)) for k, v := range toolMap { result[k] = v } - if _, hasType := result["type"]; !hasType { - result["type"] = toolId - delete(result, "id") - } tools = append(tools, result) continue } @@ -1749,7 +1745,7 @@ func extractToolboxAndConnectionConfigs( connName := toolName if connName == "" { - connName = tbResource.Name + "-" + toolId + connName = tbResource.Name + "-" + toolType } conn := project.ToolConnection{ @@ -1781,7 +1777,7 @@ func extractToolboxAndConnectionConfigs( // Toolbox tool entry is minimal — deploy enriches from connection tool := map[string]any{ - "type": toolId, + "type": toolType, "project_connection_id": connName, } tools = append(tools, tool) @@ -1799,9 +1795,13 @@ func extractToolboxAndConnectionConfigs( // credentialEnvVarName builds a deterministic env var name for a connection // credential key, e.g. ("github-copilot", "clientSecret") → "TOOL_GITHUB_COPILOT_CLIENTSECRET". +// All non-alphanumeric characters are replaced with underscores and consecutive +// underscores are collapsed to produce a valid [A-Z0-9_]+ environment variable name. +var nonAlphanumRe = regexp.MustCompile(`[^A-Z0-9]+`) + func credentialEnvVarName(connName, key string) string { s := "TOOL_" + strings.ToUpper(connName) + "_" + strings.ToUpper(key) - return strings.ReplaceAll(s, "-", "_") + return nonAlphanumRe.ReplaceAllString(s, "_") } // injectToolboxEnvVarsIntoDefinition adds TOOLBOX_{NAME}_MCP_ENDPOINT entries diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_test.go index d7e723dc323..6f78c67070d 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_test.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_test.go @@ -411,11 +411,11 @@ func TestExtractToolboxAndConnectionConfigs_TypedTools(t *testing.T) { Tools: []any{ map[string]any{ // Built-in tool — no connection - "id": "bing_grounding", + "type": "bing_grounding", }, map[string]any{ // External tool with name — connection name from Name field - "id": "mcp", + "type": "mcp", "name": "github-copilot", "target": "https://api.githubcopilot.com/mcp", "authType": "OAuth2", @@ -523,7 +523,7 @@ func TestExtractToolboxAndConnectionConfigs_RawToolsFallback(t *testing.T) { Description: "Raw tools", Tools: []any{ map[string]any{ - "id": "mcp", + "type": "mcp", "name": "existing", "project_connection_id": "existing-conn", }, @@ -583,14 +583,14 @@ func TestExtractToolboxAndConnectionConfigs_CustomKeysCredentials(t *testing.T) }, Tools: []any{ map[string]any{ - "id": "mcp", + "type": "mcp", "name": "custom-api", "target": "https://example.com/mcp", "authType": "CustomKeys", "credentials": map[string]any{"key": "my-api-key"}, }, map[string]any{ - "id": "mcp", + "type": "mcp", "name": "oauth-tool", "target": "https://example.com/oauth", "authType": "OAuth2", @@ -657,7 +657,7 @@ func TestInjectToolboxEnvVarsIntoDefinition_AddsEnvVars(t *testing.T) { Kind: agent_yaml.ResourceKindToolbox, }, Tools: []any{ - map[string]any{"id": "bing_grounding"}, + map[string]any{"type": "bing_grounding"}, }, }, }, @@ -709,7 +709,7 @@ func TestInjectToolboxEnvVarsIntoDefinition_SkipsExisting(t *testing.T) { Kind: agent_yaml.ResourceKindToolbox, }, Tools: []any{ - map[string]any{"id": "bing_grounding"}, + map[string]any{"type": "bing_grounding"}, }, }, }, @@ -745,11 +745,11 @@ func TestInjectToolboxEnvVarsIntoDefinition_MultipleToolboxes(t *testing.T) { Resources: []any{ agent_yaml.ToolboxResource{ Resource: agent_yaml.Resource{Name: "search-tools", Kind: agent_yaml.ResourceKindToolbox}, - Tools: []any{map[string]any{"id": "bing_grounding"}}, + Tools: []any{map[string]any{"type": "bing_grounding"}}, }, agent_yaml.ToolboxResource{ Resource: agent_yaml.Resource{Name: "github-tools", Kind: agent_yaml.ResourceKindToolbox}, - Tools: []any{map[string]any{"id": "mcp", "target": "https://example.com"}}, + Tools: []any{map[string]any{"type": "mcp", "target": "https://example.com"}}, }, }, } @@ -790,7 +790,7 @@ func TestInjectToolboxEnvVarsIntoDefinition_NoopForPromptAgent(t *testing.T) { Resources: []any{ agent_yaml.ToolboxResource{ Resource: agent_yaml.Resource{Name: "tools", Kind: agent_yaml.ResourceKindToolbox}, - Tools: []any{map[string]any{"id": "bing_grounding"}}, + Tools: []any{map[string]any{"type": "bing_grounding"}}, }, }, } diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse_test.go index e666a1e83d3..544402252b4 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse_test.go +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse_test.go @@ -399,8 +399,8 @@ resources: name: agent-tools description: MCP tools for documentation search tools: - - id: web_search - - id: mcp + - type: web_search + - type: mcp project_connection_id: context7 `) @@ -610,15 +610,15 @@ resources: name: platform-tools description: Platform tools with typed definitions tools: - - id: bing_grounding - - id: mcp + - type: bing_grounding + - type: mcp name: github-copilot target: https://api.githubcopilot.com/mcp authType: OAuth2 credentials: clientId: my-client-id clientSecret: my-client-secret - - id: mcp + - type: mcp name: custom-api target: https://my-api.example.com/sse authType: CustomKeys @@ -662,16 +662,16 @@ resources: } // Check built-in tool (no target/authType/name) - if tool(0)["id"] != "bing_grounding" { - t.Errorf("Expected first tool id 'bing_grounding', got '%v'", tool(0)["id"]) + if tool(0)["type"] != "bing_grounding" { + t.Errorf("Expected first tool type 'bing_grounding', got '%v'", tool(0)["type"]) } if tool(0)["target"] != nil { t.Errorf("Expected no target for built-in tool, got '%v'", tool(0)["target"]) } // Check MCP tool with name and OAuth2 - if tool(1)["id"] != "mcp" { - t.Errorf("Expected second tool id 'mcp', got '%v'", tool(1)["id"]) + if tool(1)["type"] != "mcp" { + t.Errorf("Expected second tool type 'mcp', got '%v'", tool(1)["type"]) } if tool(1)["name"] != "github-copilot" { t.Errorf("Expected second tool name 'github-copilot', got '%v'", tool(1)["name"]) @@ -688,8 +688,8 @@ resources: } // Check MCP tool with CustomKeys - if tool(2)["id"] != "mcp" { - t.Errorf("Expected third tool id 'mcp', got '%v'", tool(2)["id"]) + if tool(2)["type"] != "mcp" { + t.Errorf("Expected third tool type 'mcp', got '%v'", tool(2)["type"]) } if tool(2)["name"] != "custom-api" { t.Errorf("Expected third tool name 'custom-api', got '%v'", tool(2)["name"]) @@ -717,7 +717,7 @@ resources: - kind: toolbox name: tools tools: - - id: mcp + - type: mcp name: github target: https://api.githubcopilot.com/mcp authType: OAuth2 diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_toolsets_client_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_toolsets_client_test.go new file mode 100644 index 00000000000..d2818625f70 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_toolsets_client_test.go @@ -0,0 +1,288 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azure + +import ( + "fmt" + "io" + "net/http" + "net/url" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/stretchr/testify/require" +) + +// roundTripFunc is a test helper that captures HTTP requests. +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +// newTestPipeline creates an Azure SDK pipeline backed by a custom round-tripper. +func newTestPipeline(fn roundTripFunc) runtime.Pipeline { + return runtime.NewPipeline( + "test", + "v0.0.0", + runtime.PipelineOptions{}, + &policy.ClientOptions{ + Transport: &http.Client{Transport: fn}, + }, + ) +} + +// newTestToolboxClient creates a FoundryToolboxClient backed by a custom +// HTTP round-tripper so we can inspect requests and control responses +// without touching the network. +func newTestToolboxClient( + endpoint string, + fn roundTripFunc, +) *FoundryToolboxClient { + return &FoundryToolboxClient{ + endpoint: endpoint, + pipeline: newTestPipeline(fn), + } +} + +func TestCreateToolboxVersion_URLConstruction(t *testing.T) { + tests := []struct { + name string + endpoint string + toolboxName string + wantPath string + wantQuery string + }{ + { + name: "simple name", + endpoint: "https://example.com", + toolboxName: "my-toolbox", + wantPath: "/toolboxes/my-toolbox/versions", + wantQuery: "api-version=" + toolboxesApiVersion, + }, + { + name: "name with special chars is escaped", + endpoint: "https://example.com", + toolboxName: "my toolbox/v2", + wantPath: "/toolboxes/my%20toolbox%2Fv2/versions", + wantQuery: "api-version=" + toolboxesApiVersion, + }, + { + name: "endpoint with trailing slash", + endpoint: "https://example.com/", + toolboxName: "tools", + wantPath: "//toolboxes/tools/versions", + wantQuery: "api-version=" + toolboxesApiVersion, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var capturedReq *http.Request + + client := newTestToolboxClient(tt.endpoint, func(req *http.Request) (*http.Response, error) { + capturedReq = req + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"id":"1","name":"tb","version":"v1","tools":[]}`)), + Header: make(http.Header), + }, nil + }) + + _, err := client.CreateToolboxVersion(t.Context(), tt.toolboxName, &CreateToolboxVersionRequest{ + Tools: []map[string]any{}, + }) + require.NoError(t, err) + require.NotNil(t, capturedReq) + + require.Equal(t, http.MethodPost, capturedReq.Method) + require.Equal(t, tt.wantPath, capturedReq.URL.EscapedPath()) + require.Equal(t, tt.wantQuery, capturedReq.URL.RawQuery) + }) + } +} + +func TestCreateToolboxVersion_RequiredHeaders(t *testing.T) { + var capturedReq *http.Request + + client := newTestToolboxClient("https://example.com", func(req *http.Request) (*http.Response, error) { + capturedReq = req + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"id":"1","name":"tb","version":"v1","tools":[]}`)), + Header: make(http.Header), + }, nil + }) + + _, err := client.CreateToolboxVersion(t.Context(), "test-toolbox", &CreateToolboxVersionRequest{ + Tools: []map[string]any{}, + }) + require.NoError(t, err) + require.NotNil(t, capturedReq) + + require.Equal(t, toolboxesFeatureHeader, capturedReq.Header.Get("Foundry-Features")) + require.Equal(t, "application/json", capturedReq.Header.Get("Content-Type")) +} + +func TestCreateToolboxVersion_ErrorStatusCodes(t *testing.T) { + tests := []struct { + name string + statusCode int + wantErr bool + }{ + {"200 OK", http.StatusOK, false}, + {"400 Bad Request", http.StatusBadRequest, true}, + {"404 Not Found", http.StatusNotFound, true}, + {"409 Conflict", http.StatusConflict, true}, + {"500 Internal Server Error", http.StatusInternalServerError, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := newTestToolboxClient("https://example.com", func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: tt.statusCode, + Body: io.NopCloser(strings.NewReader(`{"id":"1","name":"tb","version":"v1","tools":[]}`)), + Header: make(http.Header), + }, nil + }) + + _, err := client.CreateToolboxVersion(t.Context(), "test", &CreateToolboxVersionRequest{ + Tools: []map[string]any{}, + }) + + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestGetToolbox_URLConstruction(t *testing.T) { + tests := []struct { + name string + toolboxName string + wantPath string + }{ + { + name: "simple name", + toolboxName: "my-toolbox", + wantPath: "/toolboxes/my-toolbox", + }, + { + name: "name needing escape", + toolboxName: "test/box", + wantPath: "/toolboxes/" + url.PathEscape("test/box"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var capturedReq *http.Request + + client := newTestToolboxClient("https://example.com", func(req *http.Request) (*http.Response, error) { + capturedReq = req + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"id":"1","name":"tb","default_version":"v1"}`)), + Header: make(http.Header), + }, nil + }) + + _, err := client.GetToolbox(t.Context(), tt.toolboxName) + require.NoError(t, err) + require.NotNil(t, capturedReq) + + require.Equal(t, http.MethodGet, capturedReq.Method) + require.Equal(t, tt.wantPath, capturedReq.URL.EscapedPath()) + require.Equal(t, toolboxesFeatureHeader, capturedReq.Header.Get("Foundry-Features")) + }) + } +} + +func TestDeleteToolbox_URLAndStatusCodes(t *testing.T) { + tests := []struct { + name string + toolboxName string + statusCode int + wantErr bool + }{ + {"200 OK", "my-toolbox", http.StatusOK, false}, + {"204 No Content", "my-toolbox", http.StatusNoContent, false}, + {"404 Not Found", "my-toolbox", http.StatusNotFound, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var capturedReq *http.Request + + client := newTestToolboxClient("https://example.com", func(req *http.Request) (*http.Response, error) { + capturedReq = req + return &http.Response{ + StatusCode: tt.statusCode, + Body: io.NopCloser(strings.NewReader("")), + Header: make(http.Header), + }, nil + }) + + err := client.DeleteToolbox(t.Context(), tt.toolboxName) + require.NotNil(t, capturedReq) + require.Equal(t, http.MethodDelete, capturedReq.Method) + + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestToolboxClient_PathEscaping_Adversarial(t *testing.T) { + tests := []struct { + name string + toolboxName string + }{ + {"path traversal", "../../../etc/passwd"}, + {"slashes in name", "name/with/slashes"}, + {"backslash traversal", `..\..\evil`}, + {"URL-encoded slash", "..%2F..%2Fevil"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var capturedReq *http.Request + + client := newTestToolboxClient("https://example.com", func(req *http.Request) (*http.Response, error) { + capturedReq = req + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"id":"1","name":"tb","default_version":"v1"}`)), + Header: make(http.Header), + }, nil + }) + + _, _ = client.GetToolbox(t.Context(), tt.toolboxName) + require.NotNil(t, capturedReq) + + // The escaped name must not introduce extra path segments + escaped := url.PathEscape(tt.toolboxName) + expectedPath := fmt.Sprintf("/toolboxes/%s", escaped) + require.Equal(t, expectedPath, capturedReq.URL.EscapedPath()) + + // No raw slashes in the toolbox name segment + escapedPath := capturedReq.URL.EscapedPath() + segments := strings.Split(strings.Trim(escapedPath, "/"), "/") + // Should be exactly: ["toolboxes", ""] + // (plus query params handled separately) + require.GreaterOrEqual(t, len(segments), 2, + "path should have at least 2 segments") + require.Equal(t, "toolboxes", segments[0]) + }) + } +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/project/service_target_agent.go b/cli/azd/extensions/azure.ai.agents/internal/project/service_target_agent.go index a218204710f..367e7fd003d 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/project/service_target_agent.go +++ b/cli/azd/extensions/azure.ai.agents/internal/project/service_target_agent.go @@ -651,21 +651,6 @@ func (p *AgentServiceTargetProvider) deployHostedAgent( memory = foundryAgentConfig.Container.Resources.Memory } - // Auto-inject toolbox MCP endpoint env vars so hosted agents can reach their toolboxes - // without requiring users to manually add them to agent.yaml's environment_variables. - if foundryAgentConfig != nil { - projectEndpoint := strings.TrimRight(azdEnv["AZURE_AI_PROJECT_ENDPOINT"], "/") - for _, toolbox := range foundryAgentConfig.Toolboxes { - toolboxKey := p.getServiceKey(toolbox.Name) - envKey := fmt.Sprintf("TOOLBOX_%s_MCP_ENDPOINT", toolboxKey) - if _, exists := resolvedEnvVars[envKey]; !exists { - resolvedEnvVars[envKey] = fmt.Sprintf( - "%s/toolsets/%s/mcp", projectEndpoint, toolbox.Name, - ) - } - } - } - // Build options list starting with required options options := []agent_yaml.AgentBuildOption{ agent_yaml.WithImageURL(fullImageURL), From 1d3ee25c6021a9f7dee95d80265071d85cd586f0 Mon Sep 17 00:00:00 2001 From: trangevi Date: Fri, 10 Apr 2026 17:34:01 -0700 Subject: [PATCH 14/14] Update feature header Signed-off-by: trangevi --- .../internal/pkg/azure/foundry_toolsets_client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_toolsets_client.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_toolsets_client.go index b18ea91356c..537e09d1a69 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_toolsets_client.go +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_toolsets_client.go @@ -22,7 +22,7 @@ import ( const ( toolboxesApiVersion = "v1" - toolboxesFeatureHeader = "Toolsets=V1Preview" + toolboxesFeatureHeader = "Toolboxes=V1Preview" ) // FoundryToolboxClient provides methods for interacting with the Foundry Toolboxes API.