11import enum
2+ from functools import singledispatchmethod
23import ipaddress
34import socket
4- from typing import Optional , Union
5+ from typing import Optional , Type , Union
56from dataclasses import dataclass , field
67
78from .errors import ReplyError
89from .._helpers import is_ip_address
910
10-
1111RSV = NULL = AUTH_GRANTED = 0x00
1212SOCKS_VER = 0x05
1313
@@ -302,34 +302,36 @@ class StateServerConnected:
302302
303303
304304class Connection :
305- _state : ConnectionState
306-
307305 def __init__ (self ):
308306 self ._state = StateServerWaitingForAuthMethods ()
309307
308+ @singledispatchmethod
310309 def send (self , request : Request ) -> bytes :
311- if type (request ) is AuthMethodsRequest :
312- if type (self ._state ) is not StateServerWaitingForAuthMethods :
313- raise RuntimeError ('Server is not currently waiting for auth methods' )
314- self ._state = StateClientSentAuthMethods (request )
315- return request .dumps ()
316-
317- if type (request ) is AuthRequest :
318- if type (self ._state ) is not StateServerWaitingForAuth :
319- raise RuntimeError ('Server is not currently waiting for authentication' )
320- self ._state = StateClientSentAuthRequest (request )
321- return request .dumps ()
322-
323- if type (request ) is ConnectRequest :
324- if type (self ._state ) is not StateClientAuthenticated :
325- raise RuntimeError ('Client is not authenticated' )
326- self ._state = StateClientSentConnectRequest (request )
327- return request .dumps ()
328-
329- raise RuntimeError (f'Invalid request type: { type (request )} ' )
310+ raise RuntimeError (f'Invalid request type: { request .__class__ } ' )
311+
312+ @send .register
313+ def _send_auth_methods (self , request : AuthMethodsRequest ) -> bytes :
314+ if not self ._state_is (StateServerWaitingForAuthMethods ):
315+ raise RuntimeError ('Server is not currently waiting for auth methods' )
316+ self ._state = StateClientSentAuthMethods (request )
317+ return request .dumps ()
318+
319+ @send .register
320+ def _send_auth (self , request : AuthRequest ) -> bytes :
321+ if not self ._state_is (StateServerWaitingForAuth ):
322+ raise RuntimeError ('Server is not currently waiting for authentication' )
323+ self ._state = StateClientSentAuthRequest (request )
324+ return request .dumps ()
325+
326+ @send .register
327+ def _send_connect (self , request : ConnectRequest ) -> bytes :
328+ if not self ._state_is (StateClientAuthenticated ):
329+ raise RuntimeError ('Client is not authenticated' )
330+ self ._state = StateClientSentConnectRequest (request )
331+ return request .dumps ()
330332
331333 def receive (self , data : bytes ) -> Reply :
332- if type ( self ._state ) is StateClientSentAuthMethods :
334+ if self ._state_is ( StateClientSentAuthMethods ) :
333335 reply = AuthMethodReply .loads (data )
334336 reply .validate (self ._state .data )
335337 if reply .method == AuthMethod .USERNAME_PASSWORD :
@@ -338,18 +340,21 @@ def receive(self, data: bytes) -> Reply:
338340 self ._state = StateClientAuthenticated ()
339341 return reply
340342
341- if type ( self ._state ) is StateClientSentAuthRequest :
343+ if self ._state_is ( StateClientSentAuthRequest ) :
342344 reply = AuthReply .loads (data )
343345 self ._state = StateClientAuthenticated (data = reply )
344346 return reply
345347
346- if type ( self ._state ) is StateClientSentConnectRequest :
348+ if self ._state_is ( StateClientSentConnectRequest ) :
347349 reply = ConnectReply .loads (data )
348350 self ._state = StateServerConnected (data = reply )
349351 return reply
350352
351353 raise RuntimeError (f'Invalid connection state: { self ._state } ' )
352354
355+ def _state_is (self , state_cls : Type [ConnectionState ]):
356+ return self .state .__class__ is state_cls
357+
353358 @property
354359 def state (self ):
355360 return self ._state
0 commit comments