Skip to content

Commit 2a47219

Browse files
committed
socks5 protocol refactoring
1 parent 8794dfc commit 2a47219

1 file changed

Lines changed: 31 additions & 26 deletions

File tree

python_socks/_protocols/socks5.py

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import enum
2+
from functools import singledispatchmethod
23
import ipaddress
34
import socket
4-
from typing import Optional, Union
5+
from typing import Optional, Type, Union
56
from dataclasses import dataclass, field
67

78
from .errors import ReplyError
89
from .._helpers import is_ip_address
910

10-
1111
RSV = NULL = AUTH_GRANTED = 0x00
1212
SOCKS_VER = 0x05
1313

@@ -302,34 +302,36 @@ class StateServerConnected:
302302

303303

304304
class Connection:
305-
_state: ConnectionState
306-
307305
def __init__(self):
308306
self._state = StateServerWaitingForAuthMethods()
309307

308+
@singledispatchmethod
310309
def send(self, request: Request) -> bytes:
311-
if type(request) is AuthMethodsRequest:
312-
if type(self._state) is not StateServerWaitingForAuthMethods:
313-
raise RuntimeError('Server is not currently waiting for auth methods')
314-
self._state = StateClientSentAuthMethods(request)
315-
return request.dumps()
316-
317-
if type(request) is AuthRequest:
318-
if type(self._state) is not StateServerWaitingForAuth:
319-
raise RuntimeError('Server is not currently waiting for authentication')
320-
self._state = StateClientSentAuthRequest(request)
321-
return request.dumps()
322-
323-
if type(request) is ConnectRequest:
324-
if type(self._state) is not StateClientAuthenticated:
325-
raise RuntimeError('Client is not authenticated')
326-
self._state = StateClientSentConnectRequest(request)
327-
return request.dumps()
328-
329-
raise RuntimeError(f'Invalid request type: {type(request)}')
310+
raise RuntimeError(f'Invalid request type: {request.__class__}')
311+
312+
@send.register
313+
def _send_auth_methods(self, request: AuthMethodsRequest) -> bytes:
314+
if not self._state_is(StateServerWaitingForAuthMethods):
315+
raise RuntimeError('Server is not currently waiting for auth methods')
316+
self._state = StateClientSentAuthMethods(request)
317+
return request.dumps()
318+
319+
@send.register
320+
def _send_auth(self, request: AuthRequest) -> bytes:
321+
if not self._state_is(StateServerWaitingForAuth):
322+
raise RuntimeError('Server is not currently waiting for authentication')
323+
self._state = StateClientSentAuthRequest(request)
324+
return request.dumps()
325+
326+
@send.register
327+
def _send_connect(self, request: ConnectRequest) -> bytes:
328+
if not self._state_is(StateClientAuthenticated):
329+
raise RuntimeError('Client is not authenticated')
330+
self._state = StateClientSentConnectRequest(request)
331+
return request.dumps()
330332

331333
def receive(self, data: bytes) -> Reply:
332-
if type(self._state) is StateClientSentAuthMethods:
334+
if self._state_is(StateClientSentAuthMethods):
333335
reply = AuthMethodReply.loads(data)
334336
reply.validate(self._state.data)
335337
if reply.method == AuthMethod.USERNAME_PASSWORD:
@@ -338,18 +340,21 @@ def receive(self, data: bytes) -> Reply:
338340
self._state = StateClientAuthenticated()
339341
return reply
340342

341-
if type(self._state) is StateClientSentAuthRequest:
343+
if self._state_is(StateClientSentAuthRequest):
342344
reply = AuthReply.loads(data)
343345
self._state = StateClientAuthenticated(data=reply)
344346
return reply
345347

346-
if type(self._state) is StateClientSentConnectRequest:
348+
if self._state_is(StateClientSentConnectRequest):
347349
reply = ConnectReply.loads(data)
348350
self._state = StateServerConnected(data=reply)
349351
return reply
350352

351353
raise RuntimeError(f'Invalid connection state: {self._state}')
352354

355+
def _state_is(self, state_cls: Type[ConnectionState]):
356+
return self.state.__class__ is state_cls
357+
353358
@property
354359
def state(self):
355360
return self._state

0 commit comments

Comments
 (0)