@@ -7,13 +7,15 @@ import (
77 "net"
88 "runtime/debug"
99 "strings"
10+ "sync"
1011 "time"
1112
1213 "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/recovery"
1314 "github.com/sirupsen/logrus"
1415 "golang.org/x/net/proxy"
1516 "google.golang.org/grpc"
1617 "google.golang.org/grpc/codes"
18+ "google.golang.org/grpc/connectivity"
1719 "google.golang.org/grpc/credentials"
1820 "google.golang.org/grpc/credentials/insecure"
1921 "google.golang.org/grpc/keepalive"
@@ -30,73 +32,66 @@ func LoggerRecoveryHandler(log *logrus.Entry) recovery.RecoveryHandlerFunc {
3032 }
3133}
3234
33- // BlockingDial is a helper method to dial the given address, using optional TLS credentials,
35+ // BlockingNewClient is a helper method to dial the given address, using optional TLS credentials,
3436// and blocking until the returned connection is ready. If the given credentials are nil, the
3537// connection will be insecure (plain-text).
3638// Lifted from: https://github.com/fullstorydev/grpcurl/blob/master/grpcurl.go
37- func BlockingDial (ctx context.Context , network , address string , creds credentials.TransportCredentials , opts ... grpc.DialOption ) (* grpc.ClientConn , error ) {
38- // grpc.Dial doesn't provide any information on permanent connection errors (like
39- // TLS handshake failures). So in order to provide good error messages, we need a
40- // custom dialer that can provide that info. That means we manage the TLS handshake.
41- result := make (chan any , 1 )
42- writeResult := func (res any ) {
43- // non-blocking write: we only need the first result
44- select {
45- case result <- res :
46- default :
47- }
39+ func BlockingNewClient (ctx context.Context , network , address string , creds credentials.TransportCredentials , opts ... grpc.DialOption ) (* grpc.ClientConn , error ) {
40+ proxyDialer := proxy .FromEnvironment ()
41+ rawConn , err := proxyDialer .Dial (network , address )
42+ if err != nil {
43+ return nil , fmt .Errorf ("proxy dial failed: %w" , err )
4844 }
4945
50- dialer := func (ctx context.Context , address string ) (net.Conn , error ) {
51- proxyDialer := proxy .FromEnvironment ()
52- conn , err := proxyDialer .Dial (network , address )
46+ if creds != nil {
47+ rawConn , _ , err = creds .ClientHandshake (ctx , address , rawConn )
5348 if err != nil {
54- writeResult (err )
55- return nil , fmt .Errorf ("error dial proxy: %w" , err )
49+ return nil , fmt .Errorf ("TLS handshake failed: %w" , err )
5650 }
57- if creds != nil {
58- conn , _ , err = creds .ClientHandshake (ctx , address , conn )
59- if err != nil {
60- writeResult (err )
61- return nil , fmt .Errorf ("error creating connection: %w" , err )
62- }
51+ }
52+
53+ var once sync.Once
54+ connUsed := false
55+ oneShot := func (ctx context.Context , target string ) (net.Conn , error ) {
56+ var conn net.Conn
57+ once .Do (func () {
58+ conn = rawConn
59+ connUsed = true
60+ })
61+ if ! connUsed {
62+ return nil , fmt .Errorf ("connection already consumed" )
6363 }
6464 return conn , nil
6565 }
6666
67- // Even with grpc.FailOnNonTempDialError, this call will usually timeout in
68- // the face of TLS handshake errors. So we can't rely on grpc.WithBlock() to
69- // know when we're done. So we run it in a goroutine and then use result
70- // channel to either get the channel or fail-fast.
71- go func () {
72- opts = append (opts ,
73- //nolint:staticcheck
74- grpc .WithBlock (),
75- //nolint:staticcheck
76- grpc .FailOnNonTempDialError (true ),
77- grpc .WithContextDialer (dialer ),
78- grpc .WithTransportCredentials (insecure .NewCredentials ()), // we are handling TLS, so tell grpc not to
79- grpc .WithKeepaliveParams (keepalive.ClientParameters {Time : common .GetGRPCKeepAliveTime ()}),
80- )
81- //nolint:staticcheck
82- conn , err := grpc .DialContext (ctx , address , opts ... )
83- var res any
84- if err != nil {
85- res = err
86- } else {
87- res = conn
88- }
89- writeResult (res )
90- }()
67+ opts = append (opts ,
68+ grpc .WithContextDialer (oneShot ),
69+ grpc .WithTransportCredentials (insecure .NewCredentials ()),
70+ grpc .WithKeepaliveParams (keepalive.ClientParameters {Time : common .GetGRPCKeepAliveTime ()}),
71+ )
9172
92- select {
93- case res := <- result :
94- if conn , ok := res .(* grpc.ClientConn ); ok {
95- return conn , nil
73+ cc , err := grpc .NewClient ("passthrough:" + address , opts ... )
74+ if err != nil {
75+ return nil , fmt .Errorf ("grpc.NewClient failed: %w" , err )
76+ }
77+
78+ cc .Connect ()
79+ if err := waitForReady (ctx , cc ); err != nil {
80+ return nil , fmt .Errorf ("gRPC connection not ready: %w" , err )
81+ }
82+
83+ return cc , nil
84+ }
85+
86+ func waitForReady (ctx context.Context , conn * grpc.ClientConn ) error {
87+ for {
88+ state := conn .GetState ()
89+ if state == connectivity .Ready {
90+ return nil
91+ }
92+ if ! conn .WaitForStateChange (ctx , state ) {
93+ return ctx .Err () // context timeout or cancellation
9694 }
97- return nil , res .(error )
98- case <- ctx .Done ():
99- return nil , ctx .Err ()
10095 }
10196}
10297
@@ -120,15 +115,15 @@ func TestTLS(address string, dialTime time.Duration) (*TLSTestResult, error) {
120115 ctx , cancel := context .WithTimeout (context .Background (), dialTime )
121116 defer cancel ()
122117
123- conn , err := BlockingDial (ctx , "tcp" , address , creds )
118+ conn , err := BlockingNewClient (ctx , "tcp" , address , creds )
124119 if err == nil {
125120 _ = conn .Close ()
126121 testResult .TLS = true
127122 creds := credentials .NewTLS (& tls.Config {})
128123 ctx , cancel := context .WithTimeout (context .Background (), dialTime )
129124 defer cancel ()
130125
131- conn , err := BlockingDial (ctx , "tcp" , address , creds )
126+ conn , err := BlockingNewClient (ctx , "tcp" , address , creds )
132127 if err == nil {
133128 _ = conn .Close ()
134129 } else {
@@ -143,7 +138,7 @@ func TestTLS(address string, dialTime time.Duration) (*TLSTestResult, error) {
143138 // refused). Test if server accepts plain-text connections
144139 ctx , cancel = context .WithTimeout (context .Background (), dialTime )
145140 defer cancel ()
146- conn , err = BlockingDial (ctx , "tcp" , address , nil )
141+ conn , err = BlockingNewClient (ctx , "tcp" , address , nil )
147142 if err == nil {
148143 _ = conn .Close ()
149144 testResult .TLS = false
0 commit comments