Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions openhexa/sdk/workspaces/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
class Connection:
"""Abstract base class for connections."""

pass
_: dataclasses.KW_ONLY
identifier: str = ""


@dataclasses.dataclass
Expand Down Expand Up @@ -100,7 +101,7 @@ def __repr__(self):


@dataclasses.dataclass
class IASOConnection:
class IASOConnection(Connection):
"""IASO connection.

See https://github.com/BLSQ/iaso for more information.
Expand Down
11 changes: 6 additions & 5 deletions openhexa/sdk/workspaces/current_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,8 @@ def _get_local_connection_fields(self, env_variable_prefix: str):
connection_fields = {}
connection_type = os.getenv(env_variable_prefix).upper()

# Get fields for the connection type
_fields = fields(ConnectionClasses[connection_type])
# Get fields for the connection type, excluding base Connection fields
_fields = [f for f in fields(ConnectionClasses[connection_type]) if f.name != "identifier"]

if _fields:
for field in _fields:
Expand Down Expand Up @@ -305,14 +305,15 @@ def get_connection(
# different from the offline ones
if connection_type == "S3":
secret_access_key = connection_fields.pop("access_key_secret")
return S3Connection(secret_access_key=secret_access_key, **connection_fields)
return S3Connection(secret_access_key=secret_access_key, identifier=identifier, **connection_fields)

if connection_type == "POSTGRESQL":
db_name = connection_fields.pop("db_name")
port = int(connection_fields.pop("port"))
return PostgreSQLConnection(
database_name=db_name,
port=port,
identifier=identifier,
**connection_fields,
)

Expand All @@ -323,9 +324,9 @@ def get_connection(
bases=(CustomConnection,),
repr=False,
)
return dataclass(**connection_fields)
return dataclass(identifier=identifier, **connection_fields)

return ConnectionClasses[connection_type](**connection_fields)
return ConnectionClasses[connection_type](identifier=identifier, **connection_fields)

def dhis2_connection(self, identifier: str = None, slug: str = None) -> DHIS2Connection:
"""Get a DHIS2 connection by its identifier.
Expand Down
Loading