Skip to content

Commit 8c445cf

Browse files
fix: replace grpc.NewClient
Signed-off-by: Jack-R-lantern <tjdfkr2421@gmail.com>
1 parent bfe8b30 commit 8c445cf

8 files changed

Lines changed: 60 additions & 68 deletions

File tree

cmd/argocd-git-ask-pass/commands/argocd_git_ask_pass.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ func NewCommand() *cobra.Command {
3535
if nonce == "" {
3636
errors.CheckError(fmt.Errorf("%s is not set", askpass.ASKPASS_NONCE_ENV))
3737
}
38-
conn, err := grpc_util.BlockingDial(ctx, "unix", askpass.SocketPath, nil, grpc.WithTransportCredentials(insecure.NewCredentials()))
38+
conn, err := grpc_util.BlockingNewClient(ctx, "unix", askpass.SocketPath, nil, grpc.WithTransportCredentials(insecure.NewCredentials()))
3939
errors.CheckError(err)
4040
defer utilio.Close(conn)
4141
client := askpass.NewAskPassServiceClient(conn)

cmpserver/apiclient/clientset.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ func NewConnection(address string) (*grpc.ClientConn, error) {
5252
}
5353

5454
dialOpts = append(dialOpts, grpc.WithTransportCredentials(insecure.NewCredentials()))
55-
conn, err := grpc_util.BlockingDial(context.Background(), "unix", address, nil, dialOpts...)
55+
conn, err := grpc_util.BlockingNewClient(context.Background(), "unix", address, nil, dialOpts...)
5656
if err != nil {
5757
log.Errorf("Unable to connect to config management plugin service with address %s", address)
5858
return nil, err

commitserver/apiclient/clientset.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,7 @@ func NewConnection(address string) (*grpc.ClientConn, error) {
4040
var opts []grpc.DialOption
4141
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
4242

43-
// TODO: switch to grpc.NewClient.
44-
//nolint:staticcheck
45-
conn, err := grpc.Dial(address, opts...)
43+
conn, err := grpc.NewClient(address, opts...)
4644
if err != nil {
4745
log.Errorf("Unable to connect to commit service with address %s", address)
4846
return nil, err

pkg/apiclient/apiclient.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,7 @@ func (c *client) newConn() (*grpc.ClientConn, io.Closer, error) {
542542
if c.UserAgent != "" {
543543
dialOpts = append(dialOpts, grpc.WithUserAgent(c.UserAgent))
544544
}
545-
conn, e := grpc_util.BlockingDial(ctx, network, serverAddr, creds, dialOpts...)
545+
conn, e := grpc_util.BlockingNewClient(ctx, network, serverAddr, creds, dialOpts...)
546546
closers = append(closers, conn)
547547
return conn, utilio.NewCloser(func() error {
548548
var firstErr error

reposerver/apiclient/clientset.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,7 @@ func NewConnection(address string, timeoutSeconds int, tlsConfig *TLSConfigurati
8282
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
8383
}
8484

85-
//nolint:staticcheck
86-
conn, err := grpc.Dial(address, opts...)
85+
conn, err := grpc.NewClient(address, opts...)
8786
if err != nil {
8887
log.Errorf("Unable to connect to repository service with address %s", address)
8988
return nil, err

server/server.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -530,8 +530,8 @@ func (server *ArgoCDServer) Listen() (*Listeners, error) {
530530
} else {
531531
dOpts = append(dOpts, grpc.WithTransportCredentials(insecure.NewCredentials()))
532532
}
533-
//nolint:staticcheck
534-
conn, err := grpc.Dial(fmt.Sprintf("localhost:%d", server.ListenPort), dOpts...)
533+
534+
conn, err := grpc.NewClient(fmt.Sprintf("localhost:%d", server.ListenPort), dOpts...)
535535
if err != nil {
536536
utilio.Close(mainLn)
537537
utilio.Close(metricsLn)

util/grpc/grpc.go

Lines changed: 52 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@ import (
77
"net"
88
"runtime/debug"
99
"strings"
10+
"sync"
1011
"time"
1112

1213
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/recovery"
1314
"github.com/sirupsen/logrus"
1415
"golang.org/x/net/proxy"
1516
"google.golang.org/grpc"
1617
"google.golang.org/grpc/codes"
18+
"google.golang.org/grpc/connectivity"
1719
"google.golang.org/grpc/credentials"
1820
"google.golang.org/grpc/credentials/insecure"
1921
"google.golang.org/grpc/keepalive"
@@ -30,73 +32,66 @@ func LoggerRecoveryHandler(log *logrus.Entry) recovery.RecoveryHandlerFunc {
3032
}
3133
}
3234

33-
// BlockingDial is a helper method to dial the given address, using optional TLS credentials,
35+
// BlockingNewClient is a helper method to dial the given address, using optional TLS credentials,
3436
// and blocking until the returned connection is ready. If the given credentials are nil, the
3537
// connection will be insecure (plain-text).
3638
// Lifted from: https://github.com/fullstorydev/grpcurl/blob/master/grpcurl.go
37-
func BlockingDial(ctx context.Context, network, address string, creds credentials.TransportCredentials, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
38-
// grpc.Dial doesn't provide any information on permanent connection errors (like
39-
// TLS handshake failures). So in order to provide good error messages, we need a
40-
// custom dialer that can provide that info. That means we manage the TLS handshake.
41-
result := make(chan any, 1)
42-
writeResult := func(res any) {
43-
// non-blocking write: we only need the first result
44-
select {
45-
case result <- res:
46-
default:
47-
}
39+
func BlockingNewClient(ctx context.Context, network, address string, creds credentials.TransportCredentials, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
40+
proxyDialer := proxy.FromEnvironment()
41+
rawConn, err := proxyDialer.Dial(network, address)
42+
if err != nil {
43+
return nil, fmt.Errorf("proxy dial failed: %w", err)
4844
}
4945

50-
dialer := func(ctx context.Context, address string) (net.Conn, error) {
51-
proxyDialer := proxy.FromEnvironment()
52-
conn, err := proxyDialer.Dial(network, address)
46+
if creds != nil {
47+
rawConn, _, err = creds.ClientHandshake(ctx, address, rawConn)
5348
if err != nil {
54-
writeResult(err)
55-
return nil, fmt.Errorf("error dial proxy: %w", err)
49+
return nil, fmt.Errorf("TLS handshake failed: %w", err)
5650
}
57-
if creds != nil {
58-
conn, _, err = creds.ClientHandshake(ctx, address, conn)
59-
if err != nil {
60-
writeResult(err)
61-
return nil, fmt.Errorf("error creating connection: %w", err)
62-
}
51+
}
52+
53+
var once sync.Once
54+
connUsed := false
55+
oneShot := func(ctx context.Context, target string) (net.Conn, error) {
56+
var conn net.Conn
57+
once.Do(func() {
58+
conn = rawConn
59+
connUsed = true
60+
})
61+
if !connUsed {
62+
return nil, fmt.Errorf("connection already consumed")
6363
}
6464
return conn, nil
6565
}
6666

67-
// Even with grpc.FailOnNonTempDialError, this call will usually timeout in
68-
// the face of TLS handshake errors. So we can't rely on grpc.WithBlock() to
69-
// know when we're done. So we run it in a goroutine and then use result
70-
// channel to either get the channel or fail-fast.
71-
go func() {
72-
opts = append(opts,
73-
//nolint:staticcheck
74-
grpc.WithBlock(),
75-
//nolint:staticcheck
76-
grpc.FailOnNonTempDialError(true),
77-
grpc.WithContextDialer(dialer),
78-
grpc.WithTransportCredentials(insecure.NewCredentials()), // we are handling TLS, so tell grpc not to
79-
grpc.WithKeepaliveParams(keepalive.ClientParameters{Time: common.GetGRPCKeepAliveTime()}),
80-
)
81-
//nolint:staticcheck
82-
conn, err := grpc.DialContext(ctx, address, opts...)
83-
var res any
84-
if err != nil {
85-
res = err
86-
} else {
87-
res = conn
88-
}
89-
writeResult(res)
90-
}()
67+
opts = append(opts,
68+
grpc.WithContextDialer(oneShot),
69+
grpc.WithTransportCredentials(insecure.NewCredentials()),
70+
grpc.WithKeepaliveParams(keepalive.ClientParameters{Time: common.GetGRPCKeepAliveTime()}),
71+
)
9172

92-
select {
93-
case res := <-result:
94-
if conn, ok := res.(*grpc.ClientConn); ok {
95-
return conn, nil
73+
cc, err := grpc.NewClient("passthrough:"+address, opts...)
74+
if err != nil {
75+
return nil, fmt.Errorf("grpc.NewClient failed: %w", err)
76+
}
77+
78+
cc.Connect()
79+
if err := waitForReady(ctx, cc); err != nil {
80+
return nil, fmt.Errorf("gRPC connection not ready: %w", err)
81+
}
82+
83+
return cc, nil
84+
}
85+
86+
func waitForReady(ctx context.Context, conn *grpc.ClientConn) error {
87+
for {
88+
state := conn.GetState()
89+
if state == connectivity.Ready {
90+
return nil
91+
}
92+
if !conn.WaitForStateChange(ctx, state) {
93+
return ctx.Err() // context timeout or cancellation
9694
}
97-
return nil, res.(error)
98-
case <-ctx.Done():
99-
return nil, ctx.Err()
10095
}
10196
}
10297

@@ -120,15 +115,15 @@ func TestTLS(address string, dialTime time.Duration) (*TLSTestResult, error) {
120115
ctx, cancel := context.WithTimeout(context.Background(), dialTime)
121116
defer cancel()
122117

123-
conn, err := BlockingDial(ctx, "tcp", address, creds)
118+
conn, err := BlockingNewClient(ctx, "tcp", address, creds)
124119
if err == nil {
125120
_ = conn.Close()
126121
testResult.TLS = true
127122
creds := credentials.NewTLS(&tls.Config{})
128123
ctx, cancel := context.WithTimeout(context.Background(), dialTime)
129124
defer cancel()
130125

131-
conn, err := BlockingDial(ctx, "tcp", address, creds)
126+
conn, err := BlockingNewClient(ctx, "tcp", address, creds)
132127
if err == nil {
133128
_ = conn.Close()
134129
} else {
@@ -143,7 +138,7 @@ func TestTLS(address string, dialTime time.Duration) (*TLSTestResult, error) {
143138
// refused). Test if server accepts plain-text connections
144139
ctx, cancel = context.WithTimeout(context.Background(), dialTime)
145140
defer cancel()
146-
conn, err = BlockingDial(ctx, "tcp", address, nil)
141+
conn, err = BlockingNewClient(ctx, "tcp", address, nil)
147142
if err == nil {
148143
_ = conn.Close()
149144
testResult.TLS = false

util/grpc/grpc_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ func TestBlockingDial_ProxyEnvironmentHandling(t *testing.T) {
9393
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
9494
defer cancel()
9595

96-
conn, err := BlockingDial(ctx, "tcp", tt.address, nil)
96+
conn, err := BlockingNewClient(ctx, "tcp", tt.address, nil)
9797

9898
if tt.expectError {
9999
require.Error(t, err)

0 commit comments

Comments
 (0)