Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions guest/sshd/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,116 @@
assert.Contains(t, sockPath, "/tmp/ssh-", "agent socket should be in /tmp/ssh-*")
}

func TestAgentForwardingEndToEnd(t *testing.T) {
t.Parallel()

// 1. Create a test key and add it to an in-memory agent.
ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)

keyring := agent.NewKeyring()
require.NoError(t, keyring.Add(agent.AddedKey{PrivateKey: ecKey}))

// 2. Start server with agent forwarding enabled.
signer, pubKey := generateTestKeyPair(t)
_, addr := startTestServerWithConfig(t, Config{
Port: 0,
AuthorizedKeys: []ssh.PublicKey{pubKey},
Env: []string{"PATH=/usr/bin:/bin"},
DefaultUID: uint32(os.Getuid()),
DefaultGID: uint32(os.Getgid()),
DefaultUser: "testuser",
DefaultHome: os.TempDir(),
DefaultShell: "/bin/sh",
DefaultWorkDir: t.TempDir(),
AgentForwarding: true,
Logger: slog.Default(),
})

// 3. Connect SSH client.
client := dialSSH(t, addr, signer)

// 4. Register handler for auth-agent@openssh.com channels BEFORE
// requesting forwarding — otherwise the server's channel open
// will be rejected.
agentChans := client.HandleChannelOpen("auth-agent@openssh.com")
go func() {
for newCh := range agentChans {
ch, reqs, err := newCh.Accept()
if err != nil {
continue
}
go ssh.DiscardRequests(reqs)
go func() {
agent.ServeAgent(keyring, ch)

Check failure on line 455 in guest/sshd/server_test.go

View workflow job for this annotation

GitHub Actions / Lint (Linux)

Error return value of `agent.ServeAgent` is not checked (errcheck)
_ = ch.Close()
}()
}
}()

// 5. Open a session, request forwarding, run ssh-add -l.
session, err := client.NewSession()
require.NoError(t, err)
defer func() { _ = session.Close() }()

err = agent.RequestAgentForwarding(session)
require.NoError(t, err)

output, err := session.CombinedOutput("ssh-add -l")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if the test is supposed to be run automatically, we might want to skip it if ssh-add is not available (it is by default on Ubuntu runners in github)

require.NoError(t, err)

result := string(output)
assert.NotContains(t, result, "The agent has no identities")
assert.NotContains(t, result, "Could not open a connection")
assert.Contains(t, result, "ECDSA", "expected forwarded ECDSA key in ssh-add output")
}

func TestAgentForwardingEndToEnd_NoClientHandler(t *testing.T) {
t.Parallel()

signer, pubKey := generateTestKeyPair(t)
_, addr := startTestServerWithConfig(t, Config{
Port: 0,
AuthorizedKeys: []ssh.PublicKey{pubKey},
Env: []string{"PATH=/usr/bin:/bin"},
DefaultUID: uint32(os.Getuid()),
DefaultGID: uint32(os.Getgid()),
DefaultUser: "testuser",
DefaultHome: os.TempDir(),
DefaultShell: "/bin/sh",
DefaultWorkDir: t.TempDir(),
AgentForwarding: true,
Logger: slog.Default(),
})

client := dialSSH(t, addr, signer)

// Do NOT register HandleChannelOpen — the proxy channel will be rejected.

session, err := client.NewSession()
require.NoError(t, err)
defer func() { _ = session.Close() }()

err = agent.RequestAgentForwarding(session)
require.NoError(t, err)

output, err := session.CombinedOutput("ssh-add -l 2>&1")
result := strings.TrimSpace(string(output))

// Without a client-side handler, ssh-add should fail.
if err == nil {
// Some versions of ssh-add exit 0 but report no agent.
assert.True(t,
strings.Contains(result, "Could not open a connection") ||
strings.Contains(result, "The agent has no identities") ||
strings.Contains(result, "Error connecting to agent") ||
strings.Contains(result, "error"),
"ssh-add should fail without client-side agent handler, got: %s", result,
)
}
// If err != nil, the command exited non-zero — that's the expected case.
}

func TestNoSocketWithoutForwardingRequest(t *testing.T) {
t.Parallel()

Expand Down
Loading