Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions cmd/dmsg-discovery/commands/dmsg-discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

proxyproto "github.com/pires/go-proxyproto"
"github.com/sirupsen/logrus"
"github.com/skycoin/skywire/deployment"
"github.com/skycoin/skywire/pkg/skywire-utilities/pkg/buildinfo"
"github.com/skycoin/skywire/pkg/skywire-utilities/pkg/cipher"
"github.com/skycoin/skywire/pkg/skywire-utilities/pkg/cmdutil"
Expand Down Expand Up @@ -206,6 +207,17 @@ Example:
cancel()
}
}()

// Serve pprof debug interface over dmsg
wl := deployment.Prod.SurveyWhitelist
if testEnvironment {
wl = deployment.Test.SurveyWhitelist
}
go func() {
if debugErr := dmsghttp.ServeDebug(ctx, dmsgDC, log, wl); debugErr != nil {
log.Errorf("dmsghttp.ServeDebug: %v", debugErr)
}
}()
}

<-ctx.Done()
Expand Down
35 changes: 35 additions & 0 deletions cmd/dmsg-server/commands/start/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,21 @@ import (

chi "github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/skycoin/skywire/deployment"
"github.com/skycoin/skywire/pkg/skywire-utilities/pkg/buildinfo"
"github.com/skycoin/skywire/pkg/skywire-utilities/pkg/cipher"
"github.com/skycoin/skywire/pkg/skywire-utilities/pkg/cmdutil"
"github.com/skycoin/skywire/pkg/skywire-utilities/pkg/logging"
"github.com/skycoin/skywire/pkg/skywire-utilities/pkg/metricsutil"
"github.com/spf13/cobra"

dmsgcmdutil "github.com/skycoin/dmsg/pkg/cmdutil"
"github.com/skycoin/dmsg/pkg/direct"
"github.com/skycoin/dmsg/pkg/disc"
dmsg "github.com/skycoin/dmsg/pkg/dmsg"
"github.com/skycoin/dmsg/pkg/dmsg/metrics"
"github.com/skycoin/dmsg/pkg/dmsgclient"
"github.com/skycoin/dmsg/pkg/dmsghttp"
"github.com/skycoin/dmsg/pkg/dmsgserver"
)

Expand Down Expand Up @@ -127,6 +131,37 @@ var RootCmd = &cobra.Command{
}
}()

// Serve pprof debug interface over dmsg using a direct client through ourselves
go func() {
// Wait for the dmsg server to be ready before connecting the debug client
<-srv.Ready()

serverEntry := &disc.Entry{
Version: "0.0.1",
Static: conf.PubKey,
Server: &disc.Server{
Address: conf.PublicAddress,
AvailableSessions: conf.MaxSessions,
},
}
entries := direct.GetAllEntries(cipher.PubKeys{conf.PubKey}, []*disc.Entry{serverEntry})
dClient := direct.NewClient(entries, log)

debugConfig := &dmsg.Config{
MinSessions: 0,
}
dmsgC, closeDebug, err := direct.StartDmsg(ctx, log, conf.PubKey, conf.SecKey, dClient, debugConfig)
if err != nil {
log.WithError(err).Error("failed to start debug dmsg client")
return
}
defer closeDebug()

if debugErr := dmsghttp.ServeDebug(ctx, dmsgC, log, deployment.Prod.SurveyWhitelist); debugErr != nil {
log.Errorf("dmsghttp.ServeDebug: %v", debugErr)
}
}()

<-ctx.Done()
},
}
Expand Down
99 changes: 99 additions & 0 deletions pkg/dmsghttp/debug.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// Package dmsghttp pkg/dmsghttp/debug.go
package dmsghttp

import (
"context"
"fmt"
"net"
"net/http"
"net/http/pprof"
"time"

"github.com/skycoin/skywire/pkg/skywire-utilities/pkg/cipher"
"github.com/skycoin/skywire/pkg/skywire-utilities/pkg/logging"

dmsg "github.com/skycoin/dmsg/pkg/dmsg"
)

// DefaultDebugPort is the dmsg port used for serving debug/pprof endpoints.
const DefaultDebugPort = uint16(81)

// DebugMux returns an http.ServeMux with standard pprof endpoints registered.
func DebugMux() *http.ServeMux {
mux := http.NewServeMux()
mux.HandleFunc("/debug/pprof/", pprof.Index)
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
for _, p := range []string{"heap", "goroutine", "threadcreate", "block", "mutex", "allocs"} {
mux.Handle("/debug/pprof/"+p, pprof.Handler(p))
}
return mux
}

// WhitelistMiddleware wraps an http.Handler with public-key-based access control.
// When serving over dmsg, RemoteAddr is in the format "<pk>:<port>".
// If whitelistedPKs is empty, all requests are allowed.
func WhitelistMiddleware(whitelistedPKs []cipher.PubKey, next http.Handler) http.Handler {
if len(whitelistedPKs) == 0 {
return next
}
allowed := make(map[string]struct{}, len(whitelistedPKs))
for _, pk := range whitelistedPKs {
allowed[pk.String()] = struct{}{}
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
remotePK, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
http.Error(w, "500 Internal Server Error", http.StatusInternalServerError)
return
}
if _, ok := allowed[remotePK]; !ok {
http.Error(w, "401 Unauthorized", http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
})
}

// ServeDebug serves pprof endpoints over dmsg on DefaultDebugPort, gated by the
// provided whitelist public keys. It blocks until the context is canceled or
// an error occurs.
func ServeDebug(ctx context.Context, dmsgC *dmsg.Client, log *logging.Logger, whitelistPKs []cipher.PubKey) error {
handler := WhitelistMiddleware(whitelistPKs, DebugMux())

lis, err := dmsgC.Listen(DefaultDebugPort)
if err != nil {
return fmt.Errorf("debug dmsg listen on port %d: %w", DefaultDebugPort, err)
}

log.WithField("dmsg_addr", fmt.Sprintf("dmsg://%v", lis.Addr().String())).
Info("Serving debug/pprof over dmsg")

srv := &http.Server{
ReadTimeout: 5 * time.Second,
WriteTimeout: 60 * time.Second, // pprof profile collection takes 30s
IdleTimeout: 30 * time.Second,
ReadHeaderTimeout: 5 * time.Second,
MaxHeaderBytes: 1 << 14, // 16KB
Handler: handler,
}

done := make(chan struct{})
go func() { //nolint:gosec
select {
case <-ctx.Done():
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) //nolint:gosec
defer cancel()
if shutdownErr := srv.Shutdown(shutdownCtx); shutdownErr != nil {
log.WithError(shutdownErr).Error("debug server shutdown error")
}
case <-done:
}
}()

err = srv.Serve(lis)
close(done)
return err
}
158 changes: 158 additions & 0 deletions pkg/dmsghttp/debug_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
// Package dmsghttp_test pkg/dmsghttp/debug_test.go
package dmsghttp_test

import (
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/skycoin/skywire/pkg/skywire-utilities/pkg/cipher"
"github.com/skycoin/skywire/pkg/skywire-utilities/pkg/logging"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/nettest"

"github.com/skycoin/dmsg/pkg/disc"
"github.com/skycoin/dmsg/pkg/dmsg"
"github.com/skycoin/dmsg/pkg/dmsghttp"
)

func TestDebugMux_PprofEndpoints(t *testing.T) {
mux := dmsghttp.DebugMux()

endpoints := []string{
"/debug/pprof/",
"/debug/pprof/cmdline",
"/debug/pprof/symbol",
"/debug/pprof/heap",
"/debug/pprof/goroutine",
"/debug/pprof/allocs",
}

for _, ep := range endpoints {
t.Run(ep, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, ep, nil)
w := httptest.NewRecorder()
mux.ServeHTTP(w, req)
assert.NotEqual(t, http.StatusNotFound, w.Code, "endpoint %s should be registered", ep)
})
}
}

func TestWhitelistMiddleware_EmptyWhitelistAllowsAll(t *testing.T) {
inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})

handler := dmsghttp.WhitelistMiddleware(nil, inner)

req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "somekey:1234"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}

func TestWhitelistMiddleware_AllowsWhitelistedPK(t *testing.T) {
pk, _ := cipher.GenerateKeyPair()

inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})

handler := dmsghttp.WhitelistMiddleware([]cipher.PubKey{pk}, inner)

req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = pk.String() + ":1234"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}

func TestWhitelistMiddleware_RejectsNonWhitelistedPK(t *testing.T) {
whitelisted, _ := cipher.GenerateKeyPair()
nonWhitelisted, _ := cipher.GenerateKeyPair()

inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})

handler := dmsghttp.WhitelistMiddleware([]cipher.PubKey{whitelisted}, inner)

req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = nonWhitelisted.String() + ":1234"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code)
}

func TestWhitelistMiddleware_InvalidRemoteAddr(t *testing.T) {
pk, _ := cipher.GenerateKeyPair()

inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})

handler := dmsghttp.WhitelistMiddleware([]cipher.PubKey{pk}, inner)

req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "invalid-no-port"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, http.StatusInternalServerError, w.Code)
}

func TestServeDebug_Integration(t *testing.T) {
// Integration test: start a dmsg server + client, serve debug, and fetch pprof index
dc := disc.NewMock(0)

// Create and start server
srvPK, srvSK := cipher.GenerateKeyPair()
srvConf := dmsg.ServerConfig{MaxSessions: 100}
srv := dmsg.NewServer(srvPK, srvSK, dc, &srvConf, nil)

lis, err := nettest.NewLocalListener("tcp")
require.NoError(t, err)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

go srv.Serve(lis, "") //nolint:errcheck
<-srv.Ready()
defer srv.Close() //nolint:errcheck

// Create debug client that connects through the server
clientPK, clientSK := cipher.GenerateKeyPair()
dmsgC := dmsg.NewClient(clientPK, clientSK, dc, &dmsg.Config{MinSessions: 1})
go dmsgC.Serve(ctx) //nolint:errcheck
<-dmsgC.Ready()
defer dmsgC.Close() //nolint:errcheck

// Serve debug on the client
log := logging.MustGetLogger("test-debug")
go dmsghttp.ServeDebug(ctx, dmsgC, log, nil) //nolint:errcheck

// Allow listener to start
time.Sleep(100 * time.Millisecond)

// Create a second client to access the debug interface
fetchPK, fetchSK := cipher.GenerateKeyPair()
fetchC := dmsg.NewClient(fetchPK, fetchSK, dc, &dmsg.Config{MinSessions: 1})
go fetchC.Serve(ctx) //nolint:errcheck
<-fetchC.Ready()
defer fetchC.Close() //nolint:errcheck

httpC := http.Client{Transport: dmsghttp.MakeHTTPTransport(ctx, fetchC)}
resp, err := httpC.Get(fmt.Sprintf("dmsg://%s:%d/debug/pprof/", clientPK.Hex(), dmsghttp.DefaultDebugPort))
require.NoError(t, err)
defer resp.Body.Close() //nolint:errcheck

assert.Equal(t, http.StatusOK, resp.StatusCode)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Contains(t, string(body), "pprof")
}