diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/files.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/files.go index 3daa92504d6..afd2471f90c 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/files.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/files.go @@ -70,13 +70,13 @@ func newFilesCommand() *cobra.Command { Hidden: !isVNextEnabled(context.Background()), Long: `Manage files in a hosted agent session. -Upload, download, list, and remove files in the session-scoped filesystem +Upload, download, list, and delete files in the session-scoped filesystem of a hosted agent. This is useful for debugging, seeding data, and agent setup. Agent details (name, endpoint) are automatically resolved from the azd environment. Use --agent-name to select a specific agent when the project has multiple azure.ai.agent services. The session ID is automatically resolved -from the last invoke session, or can be overridden with --session.`, +from the last invoke session, or can be overridden with --session-id.`, PersistentPreRunE: func(cmd *cobra.Command, args []string) error { // Chain with root's PersistentPreRunE (root sets NoPrompt). // Note: cmd.Parent() would return the "files" command itself when @@ -111,7 +111,7 @@ from the last invoke session, or can be overridden with --session.`, // addFilesFlags registers the common flags on a cobra command. func addFilesFlags(cmd *cobra.Command, flags *filesFlags) { cmd.Flags().StringVarP(&flags.agentName, "agent-name", "n", "", "Agent name (matches azure.yaml service name; auto-detected when only one exists)") - cmd.Flags().StringVarP(&flags.session, "session", "s", "", "Session ID override (defaults to last invoke session)") + cmd.Flags().StringVarP(&flags.session, "session-id", "s", "", "Session ID override (defaults to last invoke session)") } // filesContext holds the resolved agent context and session for file operations. @@ -179,27 +179,38 @@ func newFilesUploadCommand() *cobra.Command { flags := &filesUploadFlags{} cmd := &cobra.Command{ - Use: "upload", + Use: "upload [file]", Short: "Upload a file to a hosted agent session.", Long: `Upload a file to a hosted agent session. Reads a local file and uploads it to the specified remote path in the session's filesystem. If --target-path is not provided, -the remote path defaults to the local file path. +the remote path defaults to the local filename. Agent details are automatically resolved from the azd environment.`, - Example: ` # Upload a file (remote path defaults to local path) - azd ai agent files upload --file ./data/input.csv + Example: ` # Upload a file (remote path defaults to filename) + azd ai agent files upload ./data/input.csv # Upload to a specific remote path - azd ai agent files upload --file ./input.csv --target-path /data/input.csv + azd ai agent files upload ./input.csv --target-path /data/input.csv - # Upload with explicit agent name and session - azd ai agent files upload --file ./input.csv --agent-name my-agent --session `, + # Upload with flags + azd ai agent files upload --file ./input.csv --agent-name my-agent`, + Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { ctx := azdext.WithAccessToken(cmd.Context()) setupDebugLogging(cmd.Flags()) + if len(args) > 0 && flags.file == "" { + flags.file = args[0] + } + if flags.file == "" { + return fmt.Errorf( + "file path is required as a positional argument " + + "or via --file", + ) + } + fc, err := resolveFilesContext(ctx, &flags.filesFlags) if err != nil { return err @@ -216,9 +227,8 @@ Agent details are automatically resolved from the azd environment.`, } addFilesFlags(cmd, &flags.filesFlags) - cmd.Flags().StringVarP(&flags.file, "file", "f", "", "Local file path to upload (required)") - cmd.Flags().StringVarP(&flags.targetPath, "target-path", "t", "", "Remote destination path (defaults to local file path)") - _ = cmd.MarkFlagRequired("file") + cmd.Flags().StringVarP(&flags.file, "file", "f", "", "Local file path to upload") + cmd.Flags().StringVarP(&flags.targetPath, "target-path", "t", "", "Remote destination path (defaults to local filename)") return cmd } @@ -227,7 +237,7 @@ Agent details are automatically resolved from the azd environment.`, func (a *FilesUploadAction) Run(ctx context.Context) error { remotePath := a.flags.targetPath if remotePath == "" { - remotePath = a.flags.file + remotePath = filepath.Base(a.flags.file) } //nolint:gosec // G304: file path is provided by the user via CLI flag @@ -254,7 +264,7 @@ func (a *FilesUploadAction) Run(ctx context.Context) error { return fmt.Errorf("failed to upload file: %w", err) } - fmt.Printf("Uploaded %s → %s\n", a.flags.file, remotePath) + fmt.Printf("Uploaded %s -> %s\n", a.flags.file, remotePath) return nil } @@ -277,7 +287,7 @@ func newFilesDownloadCommand() *cobra.Command { flags := &filesDownloadFlags{} cmd := &cobra.Command{ - Use: "download", + Use: "download [file]", Short: "Download a file from a hosted agent session.", Long: `Download a file from a hosted agent session. @@ -287,17 +297,28 @@ the local path defaults to the basename of the remote file. Agent details are automatically resolved from the azd environment.`, Example: ` # Download a file (local path defaults to remote filename) - azd ai agent files download --file /data/output.csv + azd ai agent files download /data/output.csv # Download to a specific local path - azd ai agent files download --file /data/output.csv --target-path ./output.csv + azd ai agent files download /data/output.csv --target-path ./output.csv - # Download with explicit session - azd ai agent files download --file /data/output.csv --session `, + # Download with flags + azd ai agent files download --file /data/output.csv --session-id `, + Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { ctx := azdext.WithAccessToken(cmd.Context()) setupDebugLogging(cmd.Flags()) + if len(args) > 0 && flags.file == "" { + flags.file = args[0] + } + if flags.file == "" { + return fmt.Errorf( + "file path is required as a positional argument " + + "or via --file", + ) + } + fc, err := resolveFilesContext(ctx, &flags.filesFlags) if err != nil { return err @@ -314,9 +335,8 @@ Agent details are automatically resolved from the azd environment.`, } addFilesFlags(cmd, &flags.filesFlags) - cmd.Flags().StringVarP(&flags.file, "file", "f", "", "Remote file path to download (required)") + cmd.Flags().StringVarP(&flags.file, "file", "f", "", "Remote file path to download") cmd.Flags().StringVarP(&flags.targetPath, "target-path", "t", "", "Local destination path (defaults to remote filename)") - _ = cmd.MarkFlagRequired("file") return cmd } @@ -356,7 +376,7 @@ func (a *FilesDownloadAction) Run(ctx context.Context) error { return fmt.Errorf("failed to write file: %w", err) } - fmt.Printf("Downloaded %s → %s\n", a.flags.file, targetPath) + fmt.Printf("Downloaded %s -> %s\n", a.flags.file, targetPath) return nil } @@ -397,7 +417,7 @@ Agent details are automatically resolved from the azd environment.`, azd ai agent files list /data --output table # List with explicit session - azd ai agent files list --session `, + azd ai agent files list --session-id `, Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { ctx := azdext.WithAccessToken(cmd.Context()) @@ -505,26 +525,38 @@ func newFilesRemoveCommand() *cobra.Command { var filePath string cmd := &cobra.Command{ - Use: "remove", - Short: "Remove a file or directory from a hosted agent session.", - Long: `Remove a file or directory from a hosted agent session. + Use: "delete [file]", + Aliases: []string{"remove"}, + Short: "Delete a file or directory from a hosted agent session.", + Long: `Delete a file or directory from a hosted agent session. -Removes the specified file or directory from the session's filesystem. -Use --recursive to remove directories and their contents. +Deletes the specified file or directory from the session's filesystem. +Use --recursive to delete directories and their contents. Agent details are automatically resolved from the azd environment.`, - Example: ` # Remove a file (agent auto-detected) - azd ai agent files remove --file /data/old-file.csv + Example: ` # Delete a file (agent auto-detected) + azd ai agent files delete /data/old-file.csv - # Remove a directory recursively - azd ai agent files remove --file /data/temp --recursive + # Delete a directory recursively + azd ai agent files delete /data/temp --recursive - # Remove with explicit session - azd ai agent files remove --file /data/old-file.csv --session `, + # Delete with flags + azd ai agent files delete --file /data/old-file.csv --session-id `, + Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { ctx := azdext.WithAccessToken(cmd.Context()) setupDebugLogging(cmd.Flags()) + if len(args) > 0 && filePath == "" { + filePath = args[0] + } + if filePath == "" { + return fmt.Errorf( + "file path is required as a positional argument " + + "or via --file", + ) + } + fc, err := resolveFilesContext(ctx, &flags.filesFlags) if err != nil { return err @@ -542,9 +574,8 @@ Agent details are automatically resolved from the azd environment.`, } addFilesFlags(cmd, &flags.filesFlags) - cmd.Flags().StringVarP(&filePath, "file", "f", "", "Remote file or directory path to remove") - _ = cmd.MarkFlagRequired("file") - cmd.Flags().BoolVar(&flags.recursive, "recursive", false, "Recursively remove directories and their contents") + cmd.Flags().StringVarP(&filePath, "file", "f", "", "Remote file or directory path to delete") + cmd.Flags().BoolVar(&flags.recursive, "recursive", false, "Recursively delete directories and their contents") return cmd } @@ -586,7 +617,7 @@ func newFilesMkdirCommand() *cobra.Command { var dirPath string cmd := &cobra.Command{ - Use: "mkdir", + Use: "mkdir [dir]", Short: "Create a directory in a hosted agent session.", Long: `Create a directory in a hosted agent session. @@ -595,14 +626,25 @@ Parent directories are created as needed. Agent details are automatically resolved from the azd environment.`, Example: ` # Create a directory (agent auto-detected) - azd ai agent files mkdir --dir /data/output + azd ai agent files mkdir /data/output - # Create with explicit session - azd ai agent files mkdir --dir /data/output --session `, + # Create with flags + azd ai agent files mkdir --dir /data/output --session-id `, + Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { ctx := azdext.WithAccessToken(cmd.Context()) setupDebugLogging(cmd.Flags()) + if len(args) > 0 && dirPath == "" { + dirPath = args[0] + } + if dirPath == "" { + return fmt.Errorf( + "directory path is required as a positional " + + "argument or via --dir", + ) + } + fc, err := resolveFilesContext(ctx, flags) if err != nil { return err @@ -620,7 +662,6 @@ Agent details are automatically resolved from the azd environment.`, addFilesFlags(cmd, flags) cmd.Flags().StringVarP(&dirPath, "dir", "d", "", "Remote directory path to create") - _ = cmd.MarkFlagRequired("dir") return cmd } @@ -680,7 +721,7 @@ Agent details are automatically resolved from the azd environment.`, azd ai agent files stat /data/output.csv --output table # Get metadata with explicit session - azd ai agent files stat /data/output.csv --session `, + azd ai agent files stat /data/output.csv --session-id `, Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { ctx := azdext.WithAccessToken(cmd.Context()) diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/files_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/files_test.go index 02291665c52..bb5c75d728c 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/files_test.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/files_test.go @@ -43,7 +43,7 @@ func TestFilesCommand_HasSubcommands(t *testing.T) { assert.Contains(t, names, "upload") assert.Contains(t, names, "download") assert.Contains(t, names, "list") - assert.Contains(t, names, "remove") + assert.Contains(t, names, "delete") } func TestFilesUploadCommand_MissingFile(t *testing.T) { @@ -59,7 +59,7 @@ func TestFilesUploadCommand_MissingFile(t *testing.T) { func TestFilesUploadCommand_HasFlags(t *testing.T) { cmd := newFilesUploadCommand() - for _, name := range []string{"file", "target-path", "agent-name", "session"} { + for _, name := range []string{"file", "target-path", "agent-name", "session-id"} { f := cmd.Flags().Lookup(name) require.NotNil(t, f, "expected flag %q", name) assert.Equal(t, "", f.DefValue) @@ -79,7 +79,7 @@ func TestFilesDownloadCommand_MissingFile(t *testing.T) { func TestFilesDownloadCommand_HasFlags(t *testing.T) { cmd := newFilesDownloadCommand() - for _, name := range []string{"file", "target-path", "agent-name", "session"} { + for _, name := range []string{"file", "target-path", "agent-name", "session-id"} { f := cmd.Flags().Lookup(name) require.NotNil(t, f, "expected flag %q", name) assert.Equal(t, "", f.DefValue) @@ -100,7 +100,7 @@ func TestFilesListCommand_OptionalRemotePath(t *testing.T) { assert.NotNil(t, cmd.Args) } -func TestFilesRemoveCommand_MissingFile(t *testing.T) { +func TestFilesDeleteCommand_MissingFile(t *testing.T) { cmd := newFilesRemoveCommand() // Missing required --file flag @@ -110,10 +110,10 @@ func TestFilesRemoveCommand_MissingFile(t *testing.T) { assert.Contains(t, err.Error(), "file") } -func TestFilesRemoveCommand_HasFlags(t *testing.T) { +func TestFilesDeleteCommand_HasFlags(t *testing.T) { cmd := newFilesRemoveCommand() - for _, name := range []string{"file", "recursive", "agent-name", "session"} { + for _, name := range []string{"file", "recursive", "agent-name", "session-id"} { f := cmd.Flags().Lookup(name) require.NotNil(t, f, "expected flag %q", name) } @@ -135,7 +135,7 @@ func TestFilesMkdirCommand_MissingDir(t *testing.T) { func TestFilesMkdirCommand_HasFlags(t *testing.T) { cmd := newFilesMkdirCommand() - for _, name := range []string{"dir", "agent-name", "session"} { + for _, name := range []string{"dir", "agent-name", "session-id"} { f := cmd.Flags().Lookup(name) require.NotNil(t, f, "expected flag %q", name) assert.Equal(t, "", f.DefValue) diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/monitor.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/monitor.go index 69dc206c924..7698891d082 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/monitor.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/monitor.go @@ -41,7 +41,7 @@ func newMonitorCommand() *cobra.Command { Long: `Monitor logs from a hosted agent. Streams console output (stdout/stderr) or system events from an agent session or container. -Use --session to stream logs for a specific session, or omit it to use the container logstream. +Use --session-id to stream logs for a specific session, or omit it to use the container logstream. Use --follow to stream logs in real-time, or omit it to fetch recent logs and exit. This is useful for troubleshooting agent startup issues or monitoring agent behavior. @@ -55,10 +55,10 @@ configuration and the current azd environment. Optionally specify the service na azd ai agent monitor my-agent # Stream session logs - azd ai agent monitor --session + azd ai agent monitor --session-id # Stream session logs in real-time - azd ai agent monitor --session --follow + azd ai agent monitor --session-id --follow # Fetch system event logs from container azd ai agent monitor --type system`, @@ -107,7 +107,7 @@ configuration and the current azd environment. Optionally specify the service na return exterrors.Validation( exterrors.CodeInvalidSessionId, "VNext agents are currently enabled and require a session ID for log streaming.", - "Specify the session ID using --session, or run `azd ai agent invoke` first to create one", + "Specify the session ID using --session-id, or run `azd ai agent invoke` first to create one", ) } flags.sessionID = sessionID @@ -123,7 +123,7 @@ configuration and the current azd environment. Optionally specify the service na }, } - cmd.Flags().StringVarP(&flags.sessionID, "session", "s", "", "Session ID to stream logs for") + cmd.Flags().StringVarP(&flags.sessionID, "session-id", "s", "", "Session ID to stream logs for") cmd.Flags().BoolVarP(&flags.follow, "follow", "f", false, "Stream logs in real-time") cmd.Flags().IntVarP(&flags.tail, "tail", "l", 50, "Number of trailing log lines to fetch (1-300)") cmd.Flags().StringVarP(&flags.logType, "type", "t", "console", diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/monitor_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/monitor_test.go index 93e7d1d4418..f12c7a22f59 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/monitor_test.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/monitor_test.go @@ -100,16 +100,16 @@ func TestMonitorCommand_DefaultValues(t *testing.T) { follow, _ := cmd.Flags().GetBool("follow") assert.Equal(t, false, follow) - session, _ := cmd.Flags().GetString("session") + session, _ := cmd.Flags().GetString("session-id") assert.Equal(t, "", session) } func TestMonitorCommand_SessionFlagRegistered(t *testing.T) { cmd := newMonitorCommand() - // The --session / -s flag must be defined - f := cmd.Flags().Lookup("session") - require.NotNil(t, f, "--session flag should be registered") + // The --session-id / -s flag must be defined + f := cmd.Flags().Lookup("session-id") + require.NotNil(t, f, "--session-id flag should be registered") assert.Equal(t, "s", f.Shorthand) } diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/root.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/root.go index ecd9aee52c9..176312790a1 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/root.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/root.go @@ -78,6 +78,7 @@ func NewRootCommand() *cobra.Command { rootCmd.AddCommand(newShowCommand()) rootCmd.AddCommand(newMonitorCommand()) rootCmd.AddCommand(newFilesCommand()) + rootCmd.AddCommand(newSessionCommand()) return rootCmd } diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/session.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/session.go new file mode 100644 index 00000000000..83ba06863f8 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/session.go @@ -0,0 +1,674 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "os" + "text/tabwriter" + "time" + + "azureaiagent/internal/exterrors" + "azureaiagent/internal/pkg/agents/agent_api" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/spf13/cobra" +) + +// sessionFlags holds common flags shared by all session subcommands. +type sessionFlags struct { + agentName string + output string +} + +func newSessionCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "sessions", + Short: "Manage sessions for a hosted agent endpoint.", + Hidden: !isVNextEnabled(context.Background()), + Long: `Manage sessions for a hosted agent endpoint. + +Create, show, list, and delete hosted agent sessions. +Sessions provide persistent compute and filesystem state for +hosted agent invocations. + +Agent details are automatically resolved from the azd environment. +Use --agent-name to select a specific agent when the project has +multiple azure.ai.agent services.`, + } + + // PersistentPreRunE is set outside the struct literal so the closure + // captures the outer cmd variable. When a subcommand runs (e.g. + // "sessions create"), Cobra passes the leaf command as the function + // parameter. Using cmd.Parent() here reaches the root command; + // using the parameter's Parent() would return this session command + // itself, causing infinite recursion. + cmd.PersistentPreRunE = func(childCmd *cobra.Command, args []string) error { + if parent := cmd.Parent(); parent != nil && + parent.PersistentPreRunE != nil { + if err := parent.PersistentPreRunE(childCmd, args); err != nil { + return err + } + } + + ctx := azdext.WithAccessToken(childCmd.Context()) + if !isVNextEnabled(ctx) { + return fmt.Errorf( + "session commands require hosted agent vnext to be enabled\n\n" + + "Set 'enableHostedAgentVNext' to 'true' in your azd " + + "environment or as an OS environment variable.", + ) + } + return nil + } + + cmd.AddCommand(newSessionCreateCommand()) + cmd.AddCommand(newSessionShowCommand()) + cmd.AddCommand(newSessionDeleteCommand()) + cmd.AddCommand(newSessionListCommand()) + + return cmd +} + +// addSessionFlags registers the common flags on a cobra command. +func addSessionFlags(cmd *cobra.Command, flags *sessionFlags) { + cmd.Flags().StringVarP( + &flags.agentName, "agent-name", "n", "", + "Agent name (matches azure.yaml service name; "+ + "auto-detected when only one exists)", + ) + cmd.Flags().StringVarP( + &flags.output, "output", "o", "json", + "Output format (json or table)", + ) +} + +// sessionContext holds the resolved agent context for session operations. +type sessionContext struct { + endpoint string + agentName string + version string // from AGENT_{SERVICE}_VERSION env var +} + +// resolveSessionContext resolves the agent name, version, and project endpoint. +func resolveSessionContext( + ctx context.Context, agentName string, +) (*sessionContext, error) { + azdClient, err := azdext.NewAzdClient() + if err != nil { + return nil, fmt.Errorf("failed to create azd client: %w", err) + } + defer azdClient.Close() + + name := agentName + var version string + + if info, err := resolveAgentServiceFromProject( + ctx, azdClient, name, rootFlags.NoPrompt, + ); err == nil { + if name == "" && info.AgentName != "" { + name = info.AgentName + } + if info.Version != "" { + version = info.Version + } + } + + if name == "" { + return nil, exterrors.Validation( + exterrors.CodeInvalidAgentName, + "agent name is required but could not be resolved", + "provide --agent-name or define an azure.ai.agent "+ + "service in azure.yaml and run 'azd up'", + ) + } + + endpoint, err := resolveAgentEndpoint(ctx, "", "") + if err != nil { + return nil, err + } + + return &sessionContext{ + endpoint: endpoint, + agentName: name, + version: version, + }, nil +} + +// --------------------------------------------------------------------------- +// session create +// --------------------------------------------------------------------------- + +type sessionCreateFlags struct { + sessionFlags + sessionID string + version string + isolationKey string +} + +func newSessionCreateCommand() *cobra.Command { + flags := &sessionCreateFlags{} + + cmd := &cobra.Command{ + Use: "create [agent-name] [version] [isolation-key]", + Short: "Create a new session for a hosted agent.", + Long: `Create a new session for a hosted agent endpoint. + +Provisions a session with a persistent filesystem. The session +is ready for invocations once the command completes. + +The agent name is auto-detected when only one azure.ai.agent service exists +in azure.yaml. The version defaults to the deployed agent version from the +azd environment (AGENT_{SERVICE}_VERSION) when omitted. +The isolation key is derived from the Entra token by default. + +Positional arguments can be used instead of flags: + azd ai agent sessions create [agent-name] [version] [isolation-key]`, + Example: ` # Create a session (auto-detect agent, latest version) + azd ai agent sessions create + + # Create a session for a specific agent + azd ai agent sessions create my-agent + + # Create a session backed by agent version 3 + azd ai agent sessions create my-agent 3 + + # Create with flags + azd ai agent sessions create --agent-name my-agent --version 3 + + # Create with a specific session ID + azd ai agent sessions create --session-id my-session`, + Args: cobra.MaximumNArgs(3), + RunE: func(cmd *cobra.Command, args []string) error { + ctx := azdext.WithAccessToken(cmd.Context()) + setupDebugLogging(cmd.Flags()) + + // Positional args fill in missing flags: [agent-name] [version] [isolation-key] + switch len(args) { + case 3: + if flags.isolationKey == "" { + flags.isolationKey = args[2] + } + fallthrough + case 2: + if flags.version == "" { + flags.version = args[1] + } + fallthrough + case 1: + if flags.agentName == "" { + flags.agentName = args[0] + } + } + + sc, err := resolveSessionContext(ctx, flags.agentName) + if err != nil { + return err + } + + // Resolve version: flag > env var > error + version := flags.version + if version == "" { + version = sc.version + } + if version == "" { + return exterrors.Validation( + exterrors.CodeInvalidAgentVersion, + "agent version is required to create a session "+ + "but could not be resolved", + "provide --version, pass the version as a "+ + "positional argument, or deploy the agent "+ + "with 'azd up' to set it automatically", + ) + } + + credential, err := newAgentCredential() + if err != nil { + return err + } + + client := agent_api.NewAgentClient( + sc.endpoint, credential, + ) + + request := &agent_api.CreateAgentSessionRequest{ + VersionIndicator: &agent_api.VersionIndicator{ + Type: "version_ref", + AgentVersion: version, + }, + } + if flags.sessionID != "" { + request.AgentSessionID = &flags.sessionID + } + + session, err := client.CreateSession( + ctx, + sc.agentName, + flags.isolationKey, + request, + DefaultVNextAgentAPIVersion, + ) + if err != nil { + return exterrors.ServiceFromAzure( + err, exterrors.OpCreateSession, + ) + } + + // Persist session ID for reuse by invoke + persistSessionID(ctx, sc.agentName, session.AgentSessionID) + + return printSession(session, flags.output) + }, + } + + addSessionFlags(cmd, &flags.sessionFlags) + cmd.Flags().StringVar( + &flags.sessionID, "session-id", "", + "Optional caller-provided session ID "+ + "(auto-generated if omitted)", + ) + cmd.Flags().StringVar( + &flags.version, "version", "", + "Agent version to back the session "+ + "(auto-resolved from azd environment if omitted)", + ) + cmd.Flags().StringVar( + &flags.isolationKey, "isolation-key", "", + "Isolation key for session ownership "+ + "(derived from Entra token by default)", + ) + + return cmd +} + +// --------------------------------------------------------------------------- +// session show +// --------------------------------------------------------------------------- + +type sessionShowFlags struct { + sessionFlags +} + +func newSessionShowCommand() *cobra.Command { + flags := &sessionShowFlags{} + + cmd := &cobra.Command{ + Use: "show ", + Short: "Show details of a session.", + Long: `Show details of a hosted agent session. + +Retrieves the current status, version indicator, and timestamps for the +specified session.`, + Example: ` # Show session details + azd ai agent sessions show my-session + + # Show in table format + azd ai agent sessions show my-session --output table`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + ctx := azdext.WithAccessToken(cmd.Context()) + setupDebugLogging(cmd.Flags()) + + sessionID := args[0] + + sc, err := resolveSessionContext(ctx, flags.agentName) + if err != nil { + return err + } + + credential, err := newAgentCredential() + if err != nil { + return err + } + + client := agent_api.NewAgentClient( + sc.endpoint, credential, + ) + + session, err := client.GetSession( + ctx, + sc.agentName, + sessionID, + DefaultVNextAgentAPIVersion, + ) + if err != nil { + if respErr, ok := errors.AsType[*azcore.ResponseError](err); ok && + respErr.StatusCode == http.StatusNotFound { + return exterrors.Validation( + exterrors.CodeSessionNotFound, + fmt.Sprintf( + "session %q not found or has been deleted", + sessionID, + ), + "use 'azd ai agent sessions list' to see "+ + "available sessions", + ) + } + return exterrors.ServiceFromAzure( + err, exterrors.OpGetSession, + ) + } + + return printSession(session, flags.output) + }, + } + + addSessionFlags(cmd, &flags.sessionFlags) + + return cmd +} + +// --------------------------------------------------------------------------- +// session delete +// --------------------------------------------------------------------------- + +type sessionDeleteFlags struct { + sessionFlags + isolationKey string +} + +func newSessionDeleteCommand() *cobra.Command { + flags := &sessionDeleteFlags{} + + cmd := &cobra.Command{ + Use: "delete ", + Short: "Delete a session.", + Long: `Delete a hosted agent session synchronously. + +Terminates the hosted agent session and deletes the persistent filesystem +volume. Returns once cleanup is complete. + +The isolation key is derived from the Entra token by default.`, + Example: ` # Delete a session + azd ai agent sessions delete my-session + + # Delete with an explicit isolation key + azd ai agent sessions delete my-session --isolation-key sk-abc123`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + ctx := azdext.WithAccessToken(cmd.Context()) + setupDebugLogging(cmd.Flags()) + + sessionID := args[0] + + sc, err := resolveSessionContext(ctx, flags.agentName) + if err != nil { + return err + } + + credential, err := newAgentCredential() + if err != nil { + return err + } + + client := agent_api.NewAgentClient( + sc.endpoint, credential, + ) + + err = client.DeleteSession( + ctx, + sc.agentName, + sessionID, + flags.isolationKey, + DefaultVNextAgentAPIVersion, + ) + if err != nil { + if respErr, ok := errors.AsType[*azcore.ResponseError](err); ok && + respErr.StatusCode == http.StatusNotFound { + return exterrors.Validation( + exterrors.CodeSessionNotFound, + fmt.Sprintf( + "session %q not found or has already been deleted", + sessionID, + ), + "use 'azd ai agent sessions list' to see "+ + "available sessions", + ) + } + return exterrors.ServiceFromAzure( + err, exterrors.OpDeleteSession, + ) + } + + fmt.Printf( + "Session %q deleted from agent %q.\n", + sessionID, sc.agentName, + ) + return nil + }, + } + + addSessionFlags(cmd, &flags.sessionFlags) + cmd.Flags().StringVar( + &flags.isolationKey, "isolation-key", "", + "Isolation key for session ownership "+ + "(derived from Entra token by default)", + ) + + return cmd +} + +// --------------------------------------------------------------------------- +// session list +// --------------------------------------------------------------------------- + +type sessionListFlags struct { + sessionFlags + limit int32 + paginationToken string +} + +func newSessionListCommand() *cobra.Command { + flags := &sessionListFlags{} + + cmd := &cobra.Command{ + Use: "list", + Short: "List sessions for a hosted agent.", + Long: `List sessions for a hosted agent endpoint. + +Returns a paged list of sessions with their status, version, and timestamps.`, + Example: ` # List all sessions + azd ai agent sessions list + + # List with a page size limit + azd ai agent sessions list --limit 10 + + # List in table format + azd ai agent sessions list --output table`, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := azdext.WithAccessToken(cmd.Context()) + setupDebugLogging(cmd.Flags()) + + sc, err := resolveSessionContext(ctx, flags.agentName) + if err != nil { + return err + } + + credential, err := newAgentCredential() + if err != nil { + return err + } + + client := agent_api.NewAgentClient( + sc.endpoint, credential, + ) + + var limit *int32 + if cmd.Flags().Changed("limit") { + limit = &flags.limit + } + + var token *string + if flags.paginationToken != "" { + token = &flags.paginationToken + } + + result, err := client.ListSessions( + ctx, + sc.agentName, + limit, + token, + DefaultVNextAgentAPIVersion, + ) + if err != nil { + return exterrors.ServiceFromAzure( + err, exterrors.OpListSessions, + ) + } + + return printSessionList(result, flags.output) + }, + } + + addSessionFlags(cmd, &flags.sessionFlags) + cmd.Flags().Int32Var( + &flags.limit, "limit", 0, + "Maximum number of sessions to return", + ) + cmd.Flags().StringVar( + &flags.paginationToken, "pagination-token", "", + "Continuation token from a previous list response", + ) + + return cmd +} + +// --------------------------------------------------------------------------- +// Output formatting +// --------------------------------------------------------------------------- + +func printSession( + session *agent_api.AgentSessionResource, format string, +) error { + switch format { + case "table": + return printSessionTable(session) + default: + return printSessionJSON(session) + } +} + +func printSessionJSON(session *agent_api.AgentSessionResource) error { + data, err := json.MarshalIndent(session, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal session to JSON: %w", err) + } + fmt.Println(string(data)) + return nil +} + +func printSessionTable(session *agent_api.AgentSessionResource) error { + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(w, "FIELD\tVALUE") + fmt.Fprintln(w, "-----\t-----") + + fmt.Fprintf(w, "Session ID\t%s\n", session.AgentSessionID) + fmt.Fprintf(w, "Status\t%s\n", session.Status) + fmt.Fprintf( + w, "Version\t%s (type: %s)\n", + session.VersionIndicator.AgentVersion, + session.VersionIndicator.Type, + ) + fmt.Fprintf( + w, "Created At\t%s\n", formatUnixTimestamp(session.CreatedAt), + ) + fmt.Fprintf( + w, "Last Accessed\t%s\n", + formatUnixTimestamp(session.LastAccessedAt), + ) + fmt.Fprintf( + w, "Expires At\t%s\n", formatUnixTimestamp(session.ExpiresAt), + ) + + return w.Flush() +} + +func printSessionList( + result *agent_api.SessionListResult, format string, +) error { + switch format { + case "table": + return printSessionListTable(result) + default: + return printSessionListJSON(result) + } +} + +func printSessionListJSON(result *agent_api.SessionListResult) error { + data, err := json.MarshalIndent(result, "", " ") + if err != nil { + return fmt.Errorf( + "failed to marshal session list to JSON: %w", err, + ) + } + fmt.Println(string(data)) + return nil +} + +func printSessionListTable(result *agent_api.SessionListResult) error { + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln( + w, + "SESSION ID\tSTATUS\tVERSION\tCREATED\tLAST ACCESSED\tEXPIRES", + ) + fmt.Fprintln( + w, + "----------\t------\t-------\t-------\t-------------\t-------", + ) + + for _, s := range result.Data { + fmt.Fprintf( + w, "%s\t%s\t%s\t%s\t%s\t%s\n", + s.AgentSessionID, + s.Status, + s.VersionIndicator.AgentVersion, + formatUnixTimestamp(s.CreatedAt), + formatUnixTimestamp(s.LastAccessedAt), + formatUnixTimestamp(s.ExpiresAt), + ) + } + + if err := w.Flush(); err != nil { + return err + } + + if result.PaginationToken != nil && *result.PaginationToken != "" { + fmt.Printf( + "\nMore results available. "+ + "Use --pagination-token %q to fetch the next page.\n", + *result.PaginationToken, + ) + } + + return nil +} + +// formatUnixTimestamp converts a Unix epoch timestamp (seconds) to a +// human-readable UTC string. Returns "-" for zero values. +func formatUnixTimestamp(epoch int64) string { + if epoch == 0 { + return "-" + } + return time.Unix(epoch, 0).UTC().Format(time.RFC3339) +} + +// persistSessionID saves the session ID to .foundry-agent.json for reuse. +func persistSessionID(ctx context.Context, agentName, sessionID string) { + if sessionID == "" { + return + } + + azdClient, err := azdext.NewAzdClient() + if err != nil { + return + } + defer azdClient.Close() + + saveContextValue(ctx, azdClient, agentName, sessionID, "sessions") +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/session_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/session_test.go new file mode 100644 index 00000000000..7f20c2e0a8f --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/session_test.go @@ -0,0 +1,498 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "testing" + + "azureaiagent/internal/exterrors" + "azureaiagent/internal/pkg/agents/agent_api" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// Command structure tests +// --------------------------------------------------------------------------- + +func TestSessionCommand_HasSubcommands(t *testing.T) { + cmd := newSessionCommand() + + names := make([]string, 0, len(cmd.Commands())) + for _, sub := range cmd.Commands() { + names = append(names, sub.Name()) + } + + assert.Contains(t, names, "create") + assert.Contains(t, names, "show") + assert.Contains(t, names, "delete") + assert.Contains(t, names, "list") +} + +func TestSessionShowCommand_RequiresOneArg(t *testing.T) { + cmd := newSessionShowCommand() + + assert.NoError(t, cmd.Args(cmd, []string{"my-session"})) + assert.Error(t, cmd.Args(cmd, []string{})) + assert.Error(t, cmd.Args(cmd, []string{"a", "b"})) +} + +func TestSessionDeleteCommand_RequiresOneArg(t *testing.T) { + cmd := newSessionDeleteCommand() + + assert.NoError(t, cmd.Args(cmd, []string{"my-session"})) + assert.Error(t, cmd.Args(cmd, []string{})) + assert.Error(t, cmd.Args(cmd, []string{"a", "b"})) +} + +func TestSessionCreateCommand_DefaultFlags(t *testing.T) { + cmd := newSessionCreateCommand() + + output, _ := cmd.Flags().GetString("output") + assert.Equal(t, "json", output) + + sessionID, _ := cmd.Flags().GetString("session-id") + assert.Equal(t, "", sessionID) + + version, _ := cmd.Flags().GetString("version") + assert.Equal(t, "", version) + + isolationKey, _ := cmd.Flags().GetString("isolation-key") + assert.Equal(t, "", isolationKey) +} + +func TestSessionCreateCommand_AcceptsPositionalArgs(t *testing.T) { + cmd := newSessionCreateCommand() + + assert.NoError(t, cmd.Args(cmd, []string{})) + assert.NoError(t, cmd.Args(cmd, []string{"my-agent"})) + assert.NoError(t, cmd.Args(cmd, []string{"my-agent", "3"})) + assert.NoError(t, cmd.Args(cmd, []string{"my-agent", "3", "sk-key"})) + assert.Error(t, cmd.Args(cmd, []string{"a", "b", "c", "d"})) +} + +func TestSessionListCommand_DefaultFlags(t *testing.T) { + cmd := newSessionListCommand() + + output, _ := cmd.Flags().GetString("output") + assert.Equal(t, "json", output) + + limit, _ := cmd.Flags().GetInt32("limit") + assert.Equal(t, int32(0), limit) + + token, _ := cmd.Flags().GetString("pagination-token") + assert.Equal(t, "", token) +} + +// --------------------------------------------------------------------------- +// Output formatting tests +// --------------------------------------------------------------------------- + +func TestPrintSessionJSON(t *testing.T) { + session := &agent_api.AgentSessionResource{ + AgentSessionID: "test-session-1", + VersionIndicator: agent_api.VersionIndicator{ + Type: "version_ref", + AgentVersion: "3", + }, + Status: agent_api.AgentSessionStatusActive, + CreatedAt: 1710234567, + LastAccessedAt: 1710234567, + ExpiresAt: 1712826567, + } + + err := printSessionJSON(session) + require.NoError(t, err) +} + +func TestPrintSessionJSON_Format(t *testing.T) { + session := &agent_api.AgentSessionResource{ + AgentSessionID: "test-session-1", + VersionIndicator: agent_api.VersionIndicator{ + Type: "version_ref", + AgentVersion: "3", + }, + Status: agent_api.AgentSessionStatusActive, + CreatedAt: 1710234567, + LastAccessedAt: 1710234567, + ExpiresAt: 1712826567, + } + + data, err := json.MarshalIndent(session, "", " ") + require.NoError(t, err) + + var result map[string]any + err = json.Unmarshal(data, &result) + require.NoError(t, err) + + assert.Equal(t, "test-session-1", result["agent_session_id"]) + assert.Equal(t, "active", result["status"]) + + vi := result["version_indicator"].(map[string]any) + assert.Equal(t, "version_ref", vi["type"]) + assert.Equal(t, "3", vi["agent_version"]) + + assert.Equal(t, float64(1710234567), result["created_at"]) +} + +func TestPrintSessionTable(t *testing.T) { + session := &agent_api.AgentSessionResource{ + AgentSessionID: "test-session-1", + VersionIndicator: agent_api.VersionIndicator{ + Type: "version_ref", + AgentVersion: "3", + }, + Status: agent_api.AgentSessionStatusActive, + CreatedAt: 1710234567, + LastAccessedAt: 1710234567, + ExpiresAt: 1712826567, + } + + err := printSessionTable(session) + require.NoError(t, err) +} + +func TestPrintSessionListJSON(t *testing.T) { + nextToken := "abc123" + result := &agent_api.SessionListResult{ + Data: []agent_api.AgentSessionResource{ + { + AgentSessionID: "session-1", + VersionIndicator: agent_api.VersionIndicator{ + Type: "version_ref", + AgentVersion: "3", + }, + Status: agent_api.AgentSessionStatusActive, + CreatedAt: 1710234567, + LastAccessedAt: 1710234567, + ExpiresAt: 1712826567, + }, + { + AgentSessionID: "session-2", + VersionIndicator: agent_api.VersionIndicator{ + Type: "version_ref", + AgentVersion: "1", + }, + Status: agent_api.AgentSessionStatusIdle, + CreatedAt: 1710230000, + LastAccessedAt: 1710231000, + ExpiresAt: 1712822000, + }, + }, + PaginationToken: &nextToken, + } + + err := printSessionListJSON(result) + require.NoError(t, err) +} + +func TestPrintSessionListTable(t *testing.T) { + result := &agent_api.SessionListResult{ + Data: []agent_api.AgentSessionResource{ + { + AgentSessionID: "session-1", + VersionIndicator: agent_api.VersionIndicator{ + Type: "version_ref", + AgentVersion: "3", + }, + Status: agent_api.AgentSessionStatusActive, + CreatedAt: 1710234567, + }, + }, + } + + err := printSessionListTable(result) + require.NoError(t, err) +} + +func TestPrintSessionListTable_Empty(t *testing.T) { + result := &agent_api.SessionListResult{ + Data: []agent_api.AgentSessionResource{}, + } + + err := printSessionListTable(result) + require.NoError(t, err) +} + +// --------------------------------------------------------------------------- +// Timestamp formatting tests +// --------------------------------------------------------------------------- + +func TestFormatUnixTimestamp(t *testing.T) { + tests := []struct { + name string + epoch int64 + expected string + }{ + { + name: "zero returns dash", + epoch: 0, + expected: "-", + }, + { + name: "known timestamp", + epoch: 1710234567, + expected: "2024-03-12T09:09:27Z", + }, + { + name: "unix epoch start", + epoch: 1, + expected: "1970-01-01T00:00:01Z", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := formatUnixTimestamp(tt.epoch) + assert.Equal(t, tt.expected, result) + }) + } +} + +// --------------------------------------------------------------------------- +// Model serialization tests +// --------------------------------------------------------------------------- + +func TestAgentSessionResourceJSON_RoundTrip(t *testing.T) { + session := agent_api.AgentSessionResource{ + AgentSessionID: "my-session", + VersionIndicator: agent_api.VersionIndicator{ + Type: "version_ref", + AgentVersion: "3", + }, + Status: agent_api.AgentSessionStatusActive, + CreatedAt: 1710234567, + LastAccessedAt: 1710234567, + ExpiresAt: 1712826567, + } + + data, err := json.Marshal(session) + require.NoError(t, err) + + var decoded agent_api.AgentSessionResource + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + + assert.Equal(t, session, decoded) +} + +func TestCreateAgentSessionRequestJSON(t *testing.T) { + sessionID := "my-session" + request := agent_api.CreateAgentSessionRequest{ + AgentSessionID: &sessionID, + VersionIndicator: &agent_api.VersionIndicator{ + Type: "version_ref", + AgentVersion: "3", + }, + } + + data, err := json.Marshal(request) + require.NoError(t, err) + + var result map[string]any + err = json.Unmarshal(data, &result) + require.NoError(t, err) + + assert.Equal(t, "my-session", result["agent_session_id"]) + vi := result["version_indicator"].(map[string]any) + assert.Equal(t, "version_ref", vi["type"]) + assert.Equal(t, "3", vi["agent_version"]) +} + +func TestCreateAgentSessionRequestJSON_NoSessionID(t *testing.T) { + request := agent_api.CreateAgentSessionRequest{ + VersionIndicator: &agent_api.VersionIndicator{ + Type: "version_ref", + AgentVersion: "5", + }, + } + + data, err := json.Marshal(request) + require.NoError(t, err) + + var result map[string]any + err = json.Unmarshal(data, &result) + require.NoError(t, err) + + _, hasID := result["agent_session_id"] + assert.False(t, hasID, "agent_session_id should be omitted when nil") +} + +func TestCreateAgentSessionRequestJSON_NoVersion(t *testing.T) { + request := agent_api.CreateAgentSessionRequest{} + + data, err := json.Marshal(request) + require.NoError(t, err) + + var result map[string]any + err = json.Unmarshal(data, &result) + require.NoError(t, err) + + _, hasVI := result["version_indicator"] + assert.False(t, hasVI, "version_indicator should be omitted when nil") + + _, hasID := result["agent_session_id"] + assert.False(t, hasID, "agent_session_id should be omitted when nil") +} + +func TestSessionListResultJSON_WithPaginationToken(t *testing.T) { + token := "next-page" + result := agent_api.SessionListResult{ + Data: []agent_api.AgentSessionResource{ + { + AgentSessionID: "s1", + Status: agent_api.AgentSessionStatusActive, + }, + }, + PaginationToken: &token, + } + + data, err := json.Marshal(result) + require.NoError(t, err) + + var decoded map[string]any + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + + assert.Equal(t, "next-page", decoded["pagination_token"]) + items := decoded["data"].([]any) + assert.Len(t, items, 1) +} + +func TestSessionListResultJSON_NoPaginationToken(t *testing.T) { + result := agent_api.SessionListResult{ + Data: []agent_api.AgentSessionResource{}, + } + + data, err := json.Marshal(result) + require.NoError(t, err) + + var decoded map[string]any + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + + _, hasToken := decoded["pagination_token"] + assert.False(t, hasToken, "pagination_token should be omitted when nil") +} + +// --------------------------------------------------------------------------- +// Error classification tests — mirrors logic from session show / delete +// --------------------------------------------------------------------------- + +// classifyGetSessionError reproduces the error handling from newSessionShowCommand +// so we can test the classification without an end-to-end context. +func classifyGetSessionError(err error, sessionID string) error { + if respErr, ok := errors.AsType[*azcore.ResponseError](err); ok && + respErr.StatusCode == http.StatusNotFound { + return exterrors.Validation( + exterrors.CodeSessionNotFound, + fmt.Sprintf( + "session %q not found or has been deleted", + sessionID, + ), + "use 'azd ai agent sessions list' to see "+ + "available sessions", + ) + } + return exterrors.ServiceFromAzure(err, exterrors.OpGetSession) +} + +func TestGetSession_404_ProducesValidationError(t *testing.T) { + azErr := &azcore.ResponseError{ + StatusCode: http.StatusNotFound, + ErrorCode: "session_not_found", + } + + result := classifyGetSessionError(azErr, "my-session-id") + require.Error(t, result) + + // Should produce a LocalError (validation), not a ServiceError. + var localErr *azdext.LocalError + require.True( + t, errors.As(result, &localErr), + "404 should produce a LocalError, got: %T", result, + ) + assert.Equal(t, exterrors.CodeSessionNotFound, localErr.Code) + assert.Contains(t, localErr.Message, "my-session-id") + assert.Contains(t, localErr.Message, "not found") +} + +func TestGetSession_500_ProducesServiceError(t *testing.T) { + azErr := &azcore.ResponseError{ + StatusCode: http.StatusInternalServerError, + ErrorCode: "internal_error", + } + + result := classifyGetSessionError(azErr, "sess-1") + require.Error(t, result) + + // Non-404 errors remain as ServiceError. + var svcErr *azdext.ServiceError + require.True( + t, errors.As(result, &svcErr), + "500 should produce a ServiceError, got: %T", result, + ) +} + +// classifyDeleteSessionError reproduces the error handling from newSessionDeleteCommand +// so we can test the classification without an end-to-end context. +func classifyDeleteSessionError(err error, sessionID string) error { + if respErr, ok := errors.AsType[*azcore.ResponseError](err); ok && + respErr.StatusCode == http.StatusNotFound { + return exterrors.Validation( + exterrors.CodeSessionNotFound, + fmt.Sprintf( + "session %q not found or has already been deleted", + sessionID, + ), + "use 'azd ai agent sessions list' to see "+ + "available sessions", + ) + } + return exterrors.ServiceFromAzure(err, exterrors.OpDeleteSession) +} + +func TestDeleteSession_404_ProducesValidationError(t *testing.T) { + azErr := &azcore.ResponseError{ + StatusCode: http.StatusNotFound, + ErrorCode: "session_not_found", + } + + result := classifyDeleteSessionError(azErr, "my-session-id") + require.Error(t, result) + + // Should produce a LocalError (validation), not a ServiceError. + var localErr *azdext.LocalError + require.True( + t, errors.As(result, &localErr), + "404 should produce a LocalError, got: %T", result, + ) + assert.Equal(t, exterrors.CodeSessionNotFound, localErr.Code) + assert.Contains(t, localErr.Message, "my-session-id") + assert.Contains(t, localErr.Message, "not found") +} + +func TestDeleteSession_500_ProducesServiceError(t *testing.T) { + azErr := &azcore.ResponseError{ + StatusCode: http.StatusInternalServerError, + ErrorCode: "internal_error", + } + + result := classifyDeleteSessionError(azErr, "sess-1") + require.Error(t, result) + + // Non-404 errors remain as ServiceError. + var svcErr *azdext.ServiceError + require.True( + t, errors.As(result, &svcErr), + "500 should produce a ServiceError, got: %T", result, + ) +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/exterrors/codes.go b/cli/azd/extensions/azure.ai.agents/internal/exterrors/codes.go index fe700c880e6..ed5c127afc0 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/exterrors/codes.go +++ b/cli/azd/extensions/azure.ai.agents/internal/exterrors/codes.go @@ -20,6 +20,8 @@ const ( CodeInvalidAiProjectId = "invalid_ai_project_id" CodeInvalidServiceConfig = "invalid_service_config" CodeInvalidAgentRequest = "invalid_agent_request" + CodeInvalidAgentName = "invalid_agent_name" + CodeInvalidAgentVersion = "invalid_agent_version" CodeInvalidSessionId = "invalid_session_id" CodeInvalidParameter = "invalid_parameter" CodeUnsupportedHost = "unsupported_host" @@ -80,6 +82,11 @@ const ( CodeModelResolutionFailed = "model_resolution_failed" ) +// Error codes for session errors. +const ( + CodeSessionNotFound = "session_not_found" +) + // Error codes for file operation errors. const ( CodeFileNotFound = "file_not_found" @@ -108,4 +115,8 @@ const ( OpCreateAgent = "create_agent" OpStartContainer = "start_container" OpGetContainerOperation = "get_container_operation" + OpCreateSession = "create_session" + OpGetSession = "get_session" + OpDeleteSession = "delete_session" + OpListSessions = "list_sessions" ) diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_api/models.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_api/models.go index ee848f8dd44..14c08815b28 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_api/models.go +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_api/models.go @@ -699,3 +699,49 @@ type SessionFileList struct { Path string `json:"path"` Entries []SessionFileInfo `json:"entries"` } + +// --------------------------------------------------------------------------- +// Session Lifecycle Models +// --------------------------------------------------------------------------- + +// AgentSessionStatus represents the status of an agent session. +type AgentSessionStatus string + +const ( + AgentSessionStatusCreating AgentSessionStatus = "creating" + AgentSessionStatusActive AgentSessionStatus = "active" + AgentSessionStatusIdle AgentSessionStatus = "idle" + AgentSessionStatusUpdating AgentSessionStatus = "updating" + AgentSessionStatusFailed AgentSessionStatus = "failed" + AgentSessionStatusDeleting AgentSessionStatus = "deleting" + AgentSessionStatusDeleted AgentSessionStatus = "deleted" + AgentSessionStatusExpired AgentSessionStatus = "expired" +) + +// VersionIndicator determines which agent version backs a session. +type VersionIndicator struct { + Type string `json:"type"` + AgentVersion string `json:"agent_version,omitempty"` +} + +// AgentSessionResource represents an agent session. +type AgentSessionResource struct { + AgentSessionID string `json:"agent_session_id"` + VersionIndicator VersionIndicator `json:"version_indicator"` + Status AgentSessionStatus `json:"status"` + CreatedAt int64 `json:"created_at"` + LastAccessedAt int64 `json:"last_accessed_at"` + ExpiresAt int64 `json:"expires_at"` +} + +// CreateAgentSessionRequest is the request body for creating a session. +type CreateAgentSessionRequest struct { + AgentSessionID *string `json:"agent_session_id,omitempty"` + VersionIndicator *VersionIndicator `json:"version_indicator,omitempty"` +} + +// SessionListResult is the paged result for session list operations. +type SessionListResult struct { + Data []AgentSessionResource `json:"data"` + PaginationToken *string `json:"pagination_token,omitempty"` +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_api/operations.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_api/operations.go index a0cfd5dbf75..cb612889551 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_api/operations.go +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_api/operations.go @@ -996,7 +996,7 @@ func (c *AgentClient) UploadSessionFile( } u.Path += fmt.Sprintf( - "/agents/%s/endpoint/sessions/%s/files", + "/agents/%s/endpoint/sessions/%s/files/content", agentName, sessionID, ) @@ -1042,7 +1042,7 @@ func (c *AgentClient) DownloadSessionFile( } u.Path += fmt.Sprintf( - "/agents/%s/endpoint/sessions/%s/files", + "/agents/%s/endpoint/sessions/%s/files/content", agentName, sessionID, ) @@ -1270,3 +1270,222 @@ func (c *AgentClient) StatSessionFile( return &fileInfo, nil } + +// --------------------------------------------------------------------------- +// Session Lifecycle Operations +// --------------------------------------------------------------------------- + +// CreateSession creates a new session for an agent endpoint. +func (c *AgentClient) CreateSession( + ctx context.Context, + agentName, isolationKey string, + request *CreateAgentSessionRequest, + apiVersion string, +) (*AgentSessionResource, error) { + u, err := url.Parse(c.endpoint) + if err != nil { + return nil, fmt.Errorf("invalid endpoint URL: %w", err) + } + + u.Path += fmt.Sprintf("/agents/%s/endpoint/sessions", agentName) + + query := u.Query() + query.Set("api-version", apiVersion) + u.RawQuery = query.Encode() + + if request == nil { + request = &CreateAgentSessionRequest{} + } + + payload, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := runtime.NewRequest(ctx, http.MethodPost, u.String()) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + if err := req.SetBody( + streaming.NopCloser(bytes.NewReader(payload)), + "application/json", + ); err != nil { + return nil, fmt.Errorf("failed to set request body: %w", err) + } + + req.Raw().Header.Set("Foundry-Features", "AgentEndpoints=V1Preview") + if isolationKey != "" { + req.Raw().Header.Set("x-session-isolation-key", isolationKey) + } + + resp, err := c.pipeline.Do(req) + if err != nil { + return nil, fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + if !runtime.HasStatusCode(resp, http.StatusOK, http.StatusCreated) { + return nil, runtime.NewResponseError(resp) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + var session AgentSessionResource + if err := json.Unmarshal(body, &session); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + return &session, nil +} + +// GetSession retrieves a session by ID. +func (c *AgentClient) GetSession( + ctx context.Context, + agentName, sessionID, apiVersion string, +) (*AgentSessionResource, error) { + u, err := url.Parse(c.endpoint) + if err != nil { + return nil, fmt.Errorf("invalid endpoint URL: %w", err) + } + + u.Path += fmt.Sprintf( + "/agents/%s/endpoint/sessions/%s", + agentName, sessionID, + ) + + query := u.Query() + query.Set("api-version", apiVersion) + u.RawQuery = query.Encode() + + req, err := runtime.NewRequest(ctx, http.MethodGet, u.String()) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Raw().Header.Set("Foundry-Features", "AgentEndpoints=V1Preview") + + resp, err := c.pipeline.Do(req) + if err != nil { + return nil, fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + if !runtime.HasStatusCode(resp, http.StatusOK) { + return nil, runtime.NewResponseError(resp) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + var session AgentSessionResource + if err := json.Unmarshal(body, &session); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + return &session, nil +} + +// DeleteSession deletes a session synchronously. +func (c *AgentClient) DeleteSession( + ctx context.Context, + agentName, sessionID, isolationKey, apiVersion string, +) error { + u, err := url.Parse(c.endpoint) + if err != nil { + return fmt.Errorf("invalid endpoint URL: %w", err) + } + + u.Path += fmt.Sprintf( + "/agents/%s/endpoint/sessions/%s", + agentName, sessionID, + ) + + query := u.Query() + query.Set("api-version", apiVersion) + u.RawQuery = query.Encode() + + req, err := runtime.NewRequest(ctx, http.MethodDelete, u.String()) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Raw().Header.Set("Foundry-Features", "AgentEndpoints=V1Preview") + if isolationKey != "" { + req.Raw().Header.Set("x-session-isolation-key", isolationKey) + } + + resp, err := c.pipeline.Do(req) + if err != nil { + return fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + if !runtime.HasStatusCode( + resp, http.StatusOK, http.StatusNoContent, + ) { + return runtime.NewResponseError(resp) + } + + return nil +} + +// ListSessions returns a list of sessions for the specified agent. +func (c *AgentClient) ListSessions( + ctx context.Context, + agentName string, + limit *int32, + paginationToken *string, + apiVersion string, +) (*SessionListResult, error) { + u, err := url.Parse(c.endpoint) + if err != nil { + return nil, fmt.Errorf("invalid endpoint URL: %w", err) + } + + u.Path += fmt.Sprintf("/agents/%s/endpoint/sessions", agentName) + + query := u.Query() + query.Set("api-version", apiVersion) + if limit != nil { + query.Set("limit", strconv.Itoa(int(*limit))) + } + if paginationToken != nil && *paginationToken != "" { + query.Set("pagination_token", *paginationToken) + } + u.RawQuery = query.Encode() + + req, err := runtime.NewRequest(ctx, http.MethodGet, u.String()) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Raw().Header.Set("Foundry-Features", "AgentEndpoints=V1Preview") + + resp, err := c.pipeline.Do(req) + if err != nil { + return nil, fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + if !runtime.HasStatusCode(resp, http.StatusOK) { + return nil, runtime.NewResponseError(resp) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + var result SessionListResult + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + return &result, nil +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_api/operations_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_api/operations_test.go new file mode 100644 index 00000000000..82fb859f718 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_api/operations_test.go @@ -0,0 +1,174 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package agent_api + +import ( + "io" + "net/http" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/stretchr/testify/require" +) + +// fakeTransport is a test HTTP transport that returns a canned response. +type fakeTransport struct { + statusCode int +} + +func (f *fakeTransport) Do(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: f.statusCode, + Header: http.Header{"Content-Type": {"application/json"}}, + Body: http.NoBody, + Request: req, + }, nil +} + +// newTestClient creates an AgentClient backed by fakeTransport (no auth). +func newTestClient(endpoint string, transport policy.Transporter) *AgentClient { + pipeline := runtime.NewPipeline( + "test", "v0.0.0-test", + runtime.PipelineOptions{}, + &policy.ClientOptions{Transport: transport}, + ) + return &AgentClient{ + endpoint: endpoint, + pipeline: pipeline, + } +} + +func TestDeleteSession_Accepts200(t *testing.T) { + client := newTestClient( + "https://test.example.com/api/projects/proj", + &fakeTransport{statusCode: http.StatusOK}, + ) + + err := client.DeleteSession( + t.Context(), "my-agent", "sess-1", "", "2025-11-15-preview", + ) + require.NoError(t, err, "200 OK should be treated as success") +} + +func TestDeleteSession_Accepts204(t *testing.T) { + client := newTestClient( + "https://test.example.com/api/projects/proj", + &fakeTransport{statusCode: http.StatusNoContent}, + ) + + err := client.DeleteSession( + t.Context(), "my-agent", "sess-1", "", "2025-11-15-preview", + ) + require.NoError(t, err, "204 No Content should be treated as success") +} + +func TestDeleteSession_Rejects500(t *testing.T) { + client := newTestClient( + "https://test.example.com/api/projects/proj", + &fakeTransport{statusCode: http.StatusInternalServerError}, + ) + + err := client.DeleteSession( + t.Context(), "my-agent", "sess-1", "", "2025-11-15-preview", + ) + require.Error(t, err, "500 should be an error") +} + +func TestGetSession_404ReturnsError(t *testing.T) { + client := newTestClient( + "https://test.example.com/api/projects/proj", + &fakeTransport{statusCode: http.StatusNotFound}, + ) + + _, err := client.GetSession( + t.Context(), "my-agent", "sess-1", "2025-11-15-preview", + ) + require.Error(t, err, "404 should be an error from GetSession") +} + +// fakeBodyTransport returns a canned status code and JSON body. +type fakeBodyTransport struct { + statusCode int + body string +} + +func (f *fakeBodyTransport) Do(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: f.statusCode, + Header: http.Header{"Content-Type": {"application/json"}}, + Body: io.NopCloser(strings.NewReader(f.body)), + Request: req, + }, nil +} + +func TestCreateSession_Returns201WithBody(t *testing.T) { + body := `{ + "agent_session_id": "sess-new", + "version_indicator": {"type": "version_ref", "agent_version": "3"}, + "status": "running", + "created_at": 1700000000, + "last_accessed_at": 1700000100, + "expires_at": 1700086400 + }` + + client := newTestClient( + "https://test.example.com/api/projects/proj", + &fakeBodyTransport{ + statusCode: http.StatusCreated, + body: body, + }, + ) + + session, err := client.CreateSession( + t.Context(), "my-agent", "", + &CreateAgentSessionRequest{ + VersionIndicator: &VersionIndicator{ + Type: "version_ref", + AgentVersion: "3", + }, + }, + "2025-11-15-preview", + ) + + require.NoError(t, err) + require.Equal(t, "sess-new", session.AgentSessionID) + require.Equal(t, "3", session.VersionIndicator.AgentVersion) + require.Equal(t, AgentSessionStatus("running"), session.Status) +} + +func TestListSessions_Returns200WithPagination(t *testing.T) { + body := `{ + "data": [ + { + "agent_session_id": "sess-1", + "version_indicator": {"type": "version_ref", "agent_version": "2"}, + "status": "running", + "created_at": 1700000000, + "last_accessed_at": 1700000100, + "expires_at": 1700086400 + } + ], + "pagination_token": "next-page-abc" + }` + + client := newTestClient( + "https://test.example.com/api/projects/proj", + &fakeBodyTransport{ + statusCode: http.StatusOK, + body: body, + }, + ) + + result, err := client.ListSessions( + t.Context(), "my-agent", nil, nil, "2025-11-15-preview", + ) + + require.NoError(t, err) + require.Len(t, result.Data, 1) + require.Equal(t, "sess-1", result.Data[0].AgentSessionID) + require.NotNil(t, result.PaginationToken) + require.Equal(t, "next-page-abc", *result.PaginationToken) +}