diff --git a/openhexa/sdk/workspaces/connection.py b/openhexa/sdk/workspaces/connection.py index e7c47a8e..e51e86d1 100644 --- a/openhexa/sdk/workspaces/connection.py +++ b/openhexa/sdk/workspaces/connection.py @@ -7,7 +7,8 @@ class Connection: """Abstract base class for connections.""" - pass + _: dataclasses.KW_ONLY + identifier: str = "" @dataclasses.dataclass @@ -100,7 +101,7 @@ def __repr__(self): @dataclasses.dataclass -class IASOConnection: +class IASOConnection(Connection): """IASO connection. See https://github.com/BLSQ/iaso for more information. diff --git a/openhexa/sdk/workspaces/current_workspace.py b/openhexa/sdk/workspaces/current_workspace.py index 3da3d893..20d9834f 100644 --- a/openhexa/sdk/workspaces/current_workspace.py +++ b/openhexa/sdk/workspaces/current_workspace.py @@ -224,8 +224,8 @@ def _get_local_connection_fields(self, env_variable_prefix: str): connection_fields = {} connection_type = os.getenv(env_variable_prefix).upper() - # Get fields for the connection type - _fields = fields(ConnectionClasses[connection_type]) + # Get fields for the connection type, excluding base Connection fields + _fields = [f for f in fields(ConnectionClasses[connection_type]) if f.name != "identifier"] if _fields: for field in _fields: @@ -305,7 +305,7 @@ def get_connection( # different from the offline ones if connection_type == "S3": secret_access_key = connection_fields.pop("access_key_secret") - return S3Connection(secret_access_key=secret_access_key, **connection_fields) + return S3Connection(secret_access_key=secret_access_key, identifier=identifier, **connection_fields) if connection_type == "POSTGRESQL": db_name = connection_fields.pop("db_name") @@ -313,6 +313,7 @@ def get_connection( return PostgreSQLConnection( database_name=db_name, port=port, + identifier=identifier, **connection_fields, ) @@ -323,9 +324,9 @@ def get_connection( bases=(CustomConnection,), repr=False, ) - return dataclass(**connection_fields) + return dataclass(identifier=identifier, **connection_fields) - return ConnectionClasses[connection_type](**connection_fields) + return ConnectionClasses[connection_type](identifier=identifier, **connection_fields) def dhis2_connection(self, identifier: str = None, slug: str = None) -> DHIS2Connection: """Get a DHIS2 connection by its identifier.