@@ -14,6 +14,7 @@ import (
1414 "golang.org/x/net/proxy"
1515 "google.golang.org/grpc"
1616 "google.golang.org/grpc/codes"
17+ "google.golang.org/grpc/connectivity"
1718 "google.golang.org/grpc/credentials"
1819 "google.golang.org/grpc/credentials/insecure"
1920 "google.golang.org/grpc/keepalive"
@@ -30,61 +31,54 @@ func LoggerRecoveryHandler(log *logrus.Entry) recovery.RecoveryHandlerFunc {
3031 }
3132}
3233
33- // BlockingDial is a helper method to dial the given address, using optional TLS credentials,
34+ // BlockingNewClient is a helper method to dial the given address, using optional TLS credentials,
3435// and blocking until the returned connection is ready. If the given credentials are nil, the
3536// connection will be insecure (plain-text).
3637// Lifted from: https://github.com/fullstorydev/grpcurl/blob/master/grpcurl.go
37- func BlockingDial (ctx context.Context , network , address string , creds credentials.TransportCredentials , opts ... grpc.DialOption ) (* grpc.ClientConn , error ) {
38- // grpc.Dial doesn't provide any information on permanent connection errors (like
39- // TLS handshake failures). So in order to provide good error messages, we need a
40- // custom dialer that can provide that info. That means we manage the TLS handshake.
38+ func BlockingNewClient (ctx context.Context , network , address string , creds credentials.TransportCredentials , opts ... grpc.DialOption ) (* grpc.ClientConn , error ) {
4139 result := make (chan any , 1 )
4240 writeResult := func (res any ) {
43- // non-blocking write: we only need the first result
4441 select {
4542 case result <- res :
4643 default :
4744 }
4845 }
4946
50- dialer := func (ctx context.Context , address string ) (net.Conn , error ) {
47+ customDialer := func (ctx context.Context , address string ) (net.Conn , error ) {
5148 proxyDialer := proxy .FromEnvironment ()
5249 conn , err := proxyDialer .Dial (network , address )
5350 if err != nil {
54- writeResult (err )
5551 return nil , fmt .Errorf ("error dial proxy: %w" , err )
5652 }
53+
5754 if creds != nil {
5855 conn , _ , err = creds .ClientHandshake (ctx , address , conn )
5956 if err != nil {
60- writeResult (err )
6157 return nil , fmt .Errorf ("error creating connection: %w" , err )
6258 }
6359 }
60+
6461 return conn , nil
6562 }
6663
67- // Even with grpc.FailOnNonTempDialError, this call will usually timeout in
68- // the face of TLS handshake errors. So we can't rely on grpc.WithBlock() to
69- // know when we're done. So we run it in a goroutine and then use result
70- // channel to either get the channel or fail-fast.
7164 go func () {
7265 opts = append (opts ,
73- //nolint:staticcheck
74- grpc .WithBlock (),
75- //nolint:staticcheck
76- grpc .FailOnNonTempDialError (true ),
77- grpc .WithContextDialer (dialer ),
78- grpc .WithTransportCredentials (insecure .NewCredentials ()), // we are handling TLS, so tell grpc not to
66+ grpc .WithContextDialer (customDialer ),
67+ grpc .WithTransportCredentials (insecure .NewCredentials ()),
7968 grpc .WithKeepaliveParams (keepalive.ClientParameters {Time : common .GetGRPCKeepAliveTime ()}),
8069 )
81- //nolint:staticcheck
82- conn , err := grpc .DialContext ( ctx , address , opts ... )
70+
71+ conn , err := grpc .NewClient ( "passthrough:" + address , opts ... )
8372 var res any
8473 if err != nil {
8574 res = err
8675 } else {
87- res = conn
76+ conn .Connect ()
77+ if err := waitForReady (ctx , conn ); err != nil {
78+ res = err
79+ } else {
80+ res = conn
81+ }
8882 }
8983 writeResult (res )
9084 }()
@@ -100,6 +94,18 @@ func BlockingDial(ctx context.Context, network, address string, creds credential
10094 }
10195}
10296
97+ func waitForReady (ctx context.Context , conn * grpc.ClientConn ) error {
98+ for {
99+ state := conn .GetState ()
100+ if state == connectivity .Ready {
101+ return nil
102+ }
103+ if ! conn .WaitForStateChange (ctx , state ) {
104+ return ctx .Err () // context timeout or cancellation
105+ }
106+ }
107+ }
108+
103109type TLSTestResult struct {
104110 TLS bool
105111 InsecureErr error
@@ -120,15 +126,15 @@ func TestTLS(address string, dialTime time.Duration) (*TLSTestResult, error) {
120126 ctx , cancel := context .WithTimeout (context .Background (), dialTime )
121127 defer cancel ()
122128
123- conn , err := BlockingDial (ctx , "tcp" , address , creds )
129+ conn , err := BlockingNewClient (ctx , "tcp" , address , creds )
124130 if err == nil {
125131 _ = conn .Close ()
126132 testResult .TLS = true
127133 creds := credentials .NewTLS (& tls.Config {})
128134 ctx , cancel := context .WithTimeout (context .Background (), dialTime )
129135 defer cancel ()
130136
131- conn , err := BlockingDial (ctx , "tcp" , address , creds )
137+ conn , err := BlockingNewClient (ctx , "tcp" , address , creds )
132138 if err == nil {
133139 _ = conn .Close ()
134140 } else {
@@ -143,7 +149,7 @@ func TestTLS(address string, dialTime time.Duration) (*TLSTestResult, error) {
143149 // refused). Test if server accepts plain-text connections
144150 ctx , cancel = context .WithTimeout (context .Background (), dialTime )
145151 defer cancel ()
146- conn , err = BlockingDial (ctx , "tcp" , address , nil )
152+ conn , err = BlockingNewClient (ctx , "tcp" , address , nil )
147153 if err == nil {
148154 _ = conn .Close ()
149155 testResult .TLS = false
0 commit comments