Skip to content

Commit b2e018c

Browse files
committed
feat: api improved following code review
1 parent 3bec7fb commit b2e018c

File tree

2 files changed

+100
-34
lines changed

2 files changed

+100
-34
lines changed

datashield/api.py

Lines changed: 85 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
DataSHIELD API.
33
"""
44

5+
import logging
56
from datashield.interface import DSLoginInfo, DSConnection, DSDriver, DSError
67
import time
78

@@ -72,13 +73,17 @@ class DSSession:
7273
DataSHIELD session, establishes connections with remote servers and performs commands.
7374
"""
7475

75-
def __init__(self, logins: list[DSLoginInfo]):
76+
def __init__(self, logins: list[DSLoginInfo], start_timeout: float = 300.0, start_delay: float = 0.1):
7677
"""
7778
Create a session, with connection information. Does not open the connections.
7879
7980
:param logins: A list of login details
81+
:param start_timeout: The maximum time in seconds to wait for R sessions to start, default is 300 seconds (5 minutes)
82+
:param start_delay: The delay in seconds between checking if R sessions are started, default is 0.1 seconds
8083
"""
8184
self.logins = logins
85+
self.start_timeout = start_timeout
86+
self.start_delay = start_delay
8287
self.conns: list[DSConnection] = None
8388
self.errors: dict = None
8489

@@ -234,61 +239,108 @@ def workspaces(self) -> dict:
234239
rval[conn.name] = conn.list_workspaces()
235240
return rval
236241

237-
def workspace_save(self, name: str) -> dict:
242+
def workspace_save(self, name: str) -> None:
238243
"""
239244
Save the DataSHIELD R session in a workspace on the remote data repository.
240245
241246
:param name: The name of the workspace
242-
:return: The list of DataSHIELD workspaces, that have been saved on the remote data repository after saving the workspace, per remote server name
243247
"""
244248
for conn in self.conns:
245249
conn.save_workspace(f"{conn.name}:{name}")
246-
return self.workspaces()
247250

248-
def workspace_restore(self, name: str) -> dict:
251+
def workspace_restore(self, name: str) -> None:
249252
"""
250253
Restore a saved DataSHIELD R session from the remote data repository. When restoring a workspace,
251254
any existing symbol or file with same name will be overridden.
252255
253256
:param name: The name of the workspace
254-
:return: The list of DataSHIELD workspaces, that have been saved on the remote data repository after restoring the workspace, per remote server name
255257
"""
256258
for conn in self.conns:
257259
conn.restore_workspace(f"{conn.name}:{name}")
258-
return self.workspaces()
259260

260-
def workspace_rm(self, name: str) -> dict:
261+
def workspace_rm(self, name: str) -> None:
261262
"""
262263
Remove a DataSHIELD workspace from the remote data repository. Ignored if no
263264
such workspace exists.
264265
265266
:param name: The name of the workspace
266-
:return: The list of DataSHIELD workspaces, that have been saved on the remote data repository after removing the workspace, per remote server name
267267
"""
268268
for conn in self.conns:
269269
conn.rm_workspace(f"{conn.name}:{name}")
270-
return self.workspaces()
271270

272271
#
273272
# R session
274273
#
275274

276275
def sessions(self) -> dict:
277276
"""
278-
Ensure R sessions are started on the remote servers and get their information.
279-
280-
:return: The R session information, per remote server name
277+
Ensure R sessions are started on the remote servers and wait until they are ready.
278+
This method returns a dictionary mapping each remote server name to its underlying
279+
R session object (an ``RSession`` instance). These session objects are
280+
primarily intended for status inspection (e.g. ``is_started()``, ``is_ready()``,
281+
``is_pending()``, ``is_failed()``, ``is_terminated()``, ``get_last_message()``) and
282+
not for direct interaction with the remote R environment.
283+
In normal use, you do not need to work with the returned session objects directly.
284+
Instead, interact with the remote R sessions through the higher-level ``DSSession``
285+
methods (such as assignment, aggregation, workspace and other helpers), which
286+
operate on all underlying sessions. Important: only sessions that have been successfully started
287+
and are ready will be included in the returned dictionary and used for subsequent operations.
288+
If a session fails to start or check status, it will be excluded from the returned dictionary
289+
and from subsequent operations, and an error will be logged. If no sessions can be started successfully,
290+
an exception will be raised.
291+
292+
:return: A dictionary mapping remote server names to their corresponding R session
293+
objects, intended mainly for internal use and status monitoring.
281294
"""
282295
rval = {}
296+
self._init_errors()
297+
started_conns = []
298+
excluded_conns = []
299+
300+
# start sessions asynchronously if supported, otherwise synchronously
283301
for conn in self.conns:
284-
if not conn.has_session():
285-
conn.start_session(asynchronous=True)
286-
# check for session status and wait until all are complete
287-
while any(conn.get_session().is_pending() for conn in self.conns):
288-
time.sleep(0.1)
302+
try:
303+
if not conn.has_session():
304+
conn.start_session(asynchronous=True)
305+
except Exception as e:
306+
logging.warning(f"Failed to start session: {conn.name} - {e}")
307+
excluded_conns.append(conn.name)
308+
309+
# check for session status and wait until all are started
310+
for conn in [c for c in self.conns if c.name not in excluded_conns]:
311+
try:
312+
if conn.is_session_started():
313+
started_conns.append(conn.name)
314+
except Exception as e:
315+
logging.warning(f"Failed to check session status: {conn.name} - {e}")
316+
excluded_conns.append(conn.name)
317+
318+
# wait until all sessions are started, excluding those that have failed to start or check status
319+
start_time = time.time()
320+
while len(started_conns) < len(self.conns) - len(excluded_conns):
321+
if time.time() - start_time > self.start_timeout:
322+
raise DSError("Timed out waiting for R sessions to start")
323+
time.sleep(self.start_delay)
324+
remaining_conns = [
325+
conn for conn in self.conns if conn.name not in started_conns and conn.name not in excluded_conns
326+
]
327+
for conn in remaining_conns:
328+
try:
329+
if conn.is_session_started():
330+
started_conns.append(conn.name)
331+
except Exception as e:
332+
logging.warning(f"Failed to check session status: {conn.name} - {e}")
333+
excluded_conns.append(conn.name)
334+
335+
# at this point, all sessions that could be started have been started, and those that failed to start or check status have been excluded
289336
for conn in self.conns:
290-
rval[conn.name] = conn.get_session()
291-
self._check_errors()
337+
if conn.name in started_conns:
338+
rval[conn.name] = conn.get_session()
339+
if len(excluded_conns) > 0:
340+
logging.error(f"Some sessions have been excluded due to errors: {', '.join(excluded_conns)}")
341+
self.conns = [conn for conn in self.conns if conn.name not in excluded_conns]
342+
if len(self.conns) == 0:
343+
raise DSError("No sessions could be started successfully.")
292344
return rval
293345

294346
def ls(self) -> dict:
@@ -297,8 +349,8 @@ def ls(self) -> dict:
297349
298350
:return: The symbols that live in the DataSHIELD R session on the server side, per remote server name
299351
"""
300-
self._init_errors()
301-
self.sessions() # ensure sessions are started and available
352+
# ensure sessions are started and available
353+
self.sessions()
302354
rval = {}
303355
for conn in self.conns:
304356
try:
@@ -315,8 +367,8 @@ def rm(self, symbol: str) -> None:
315367
316368
:param symbol: The name of the symbol to remove
317369
"""
318-
self._init_errors()
319-
self.sessions() # ensure sessions are started and available
370+
# ensure sessions are started and available
371+
self.sessions()
320372
for conn in self.conns:
321373
try:
322374
conn.rm_symbol(symbol)
@@ -343,8 +395,8 @@ def assign_table(
343395
:param tables: The name of the table to assign, per server name. If not defined, 'table' is used.
344396
:param asynchronous: Whether the operation is asynchronous (if supported by the DataSHIELD server)
345397
"""
346-
self._init_errors()
347-
self.sessions() # ensure sessions are started and available
398+
# ensure sessions are started and available
399+
self.sessions()
348400
cmd = {}
349401
for conn in self.conns:
350402
name = table
@@ -370,8 +422,8 @@ def assign_resource(
370422
:param resources: The name of the resource to assign, per server name. If not defined, 'resource' is used.
371423
:param asynchronous: Whether the operation is asynchronous (if supported by the DataSHIELD server)
372424
"""
373-
self._init_errors()
374-
self.sessions() # ensure sessions are started and available
425+
# ensure sessions are started and available
426+
self.sessions()
375427
cmd = {}
376428
for conn in self.conns:
377429
name = resource
@@ -394,8 +446,8 @@ def assign_expr(self, symbol: str, expr: str, asynchronous: bool = True) -> None
394446
:param expr: The R expression to evaluate and which result will be assigned
395447
:param asynchronous: Whether the operation is asynchronous (if supported by the DataSHIELD server)
396448
"""
397-
self._init_errors()
398-
self.sessions() # ensure sessions are started and available
449+
# ensure sessions are started and available
450+
self.sessions()
399451
cmd = {}
400452
for conn in self.conns:
401453
try:
@@ -415,8 +467,8 @@ def aggregate(self, expr: str, asynchronous: bool = True) -> dict:
415467
:param asynchronous: Whether the operation is asynchronous (if supported by the DataSHIELD server)
416468
:return: The result of the aggregation expression evaluation, per remote server name
417469
"""
418-
self._init_errors()
419-
self.sessions() # ensure sessions are started and available
470+
# ensure sessions are started and available
471+
self.sessions()
420472
cmd = {}
421473
rval = {}
422474
for conn in self.conns:
@@ -465,6 +517,7 @@ def _append_error(self, conn: DSConnection, error: Exception) -> None:
465517
"""
466518
Append an error.
467519
"""
520+
logging.error(f"[{conn.name}] {error}")
468521
self.errors[conn.name] = error
469522

470523
def _check_errors(self) -> None:

datashield/interface.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,18 @@ def start_session(self, asynchronous: bool = True) -> RSession:
187187
"""
188188
raise NotImplementedError("DSConnection function not available")
189189

190+
def is_session_started(self) -> bool:
191+
"""
192+
Get whether the session with the DataSHIELD server is started. If the session start was asynchronous, this function
193+
can be used to check whether the session is started without waiting for it to be started. If the last call was positive,
194+
subsequent calls will not query the server for session status, but will return True directly. If the last call was negative,
195+
subsequent calls will query the server for session status until a positive response is obtained.
196+
197+
:return: Whether the session is started
198+
:raises DSError: if the session was not started or session information is not available
199+
"""
200+
raise NotImplementedError("DSConnection function not available")
201+
190202
def get_session(self) -> RSession:
191203
"""
192204
Get the R session with the DataSHIELD server. If no session is established, an error will be raised.
@@ -359,9 +371,10 @@ def is_async(self) -> dict:
359371
the raw result can be accessed asynchronously, allowing parallelization of DataSHIELD calls
360372
over multiple servers. The returned named list of logicals will specify if asynchronicity is supported for:
361373
aggregation operation ('aggregate'), table assignment operation ('assign_table'),
362-
resource assignment operation ('assign_resource') and expression assignment operation ('assign_expr').
374+
resource assignment operation ('assign_resource'), expression assignment operation ('assign_expr')
375+
and R session creation ('session').
363376
364-
:return: A named list of logicals specifying if asynchronicity is supported for aggregation operation ('aggregate'), table assignment operation ('assign_table'),
377+
:return: A named list of logicals specifying if asynchronicity is supported.
365378
"""
366379
raise NotImplementedError("DSConnection function not available")
367380

0 commit comments

Comments
 (0)