Skip to content

Commit 950d917

Browse files
committed
feat: api improved following code review
1 parent 3bec7fb commit 950d917

File tree

2 files changed

+99
-34
lines changed

2 files changed

+99
-34
lines changed

datashield/api.py

Lines changed: 84 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
DataSHIELD API.
33
"""
44

5+
import logging
56
from datashield.interface import DSLoginInfo, DSConnection, DSDriver, DSError
67
import time
78

@@ -72,13 +73,17 @@ class DSSession:
7273
DataSHIELD session, establishes connections with remote servers and performs commands.
7374
"""
7475

75-
def __init__(self, logins: list[DSLoginInfo]):
76+
def __init__(self, logins: list[DSLoginInfo], start_timeout: float = 300.0, start_delay: float = 0.1):
7677
"""
7778
Create a session, with connection information. Does not open the connections.
7879
7980
:param logins: A list of login details
81+
:param start_timeout: The maximum time in seconds to wait for R sessions to start, default is 300 seconds (5 minutes)
82+
:param start_delay: The delay in seconds between checking if R sessions are started, default is 0.1 seconds
8083
"""
8184
self.logins = logins
85+
self.start_timeout = start_timeout
86+
self.start_delay = start_delay
8287
self.conns: list[DSConnection] = None
8388
self.errors: dict = None
8489

@@ -234,61 +239,107 @@ def workspaces(self) -> dict:
234239
rval[conn.name] = conn.list_workspaces()
235240
return rval
236241

237-
def workspace_save(self, name: str) -> dict:
242+
def workspace_save(self, name: str) -> None:
238243
"""
239244
Save the DataSHIELD R session in a workspace on the remote data repository.
240245
241246
:param name: The name of the workspace
242-
:return: The list of DataSHIELD workspaces, that have been saved on the remote data repository after saving the workspace, per remote server name
243247
"""
244248
for conn in self.conns:
245249
conn.save_workspace(f"{conn.name}:{name}")
246-
return self.workspaces()
247250

248-
def workspace_restore(self, name: str) -> dict:
251+
def workspace_restore(self, name: str) -> None:
249252
"""
250253
Restore a saved DataSHIELD R session from the remote data repository. When restoring a workspace,
251254
any existing symbol or file with same name will be overridden.
252255
253256
:param name: The name of the workspace
254-
:return: The list of DataSHIELD workspaces, that have been saved on the remote data repository after restoring the workspace, per remote server name
255257
"""
256258
for conn in self.conns:
257259
conn.restore_workspace(f"{conn.name}:{name}")
258-
return self.workspaces()
259260

260-
def workspace_rm(self, name: str) -> dict:
261+
def workspace_rm(self, name: str) -> None:
261262
"""
262263
Remove a DataSHIELD workspace from the remote data repository. Ignored if no
263264
such workspace exists.
264265
265266
:param name: The name of the workspace
266-
:return: The list of DataSHIELD workspaces, that have been saved on the remote data repository after removing the workspace, per remote server name
267267
"""
268268
for conn in self.conns:
269269
conn.rm_workspace(f"{conn.name}:{name}")
270-
return self.workspaces()
271270

272271
#
273272
# R session
274273
#
275274

276275
def sessions(self) -> dict:
    """
    Ensure R sessions are started on the remote servers and wait until they are ready.

    Session start is requested on every connection that does not have a session yet,
    then the connections are polled every ``start_delay`` seconds until each of them
    reports a started session. A connection that raises while starting or while being
    polled is logged, excluded from the result and removed from ``self.conns``, so that
    subsequent operations only target working servers.

    The returned session objects (``RSession`` instances) are primarily intended for
    status inspection and not for direct interaction with the remote R environment;
    use the higher-level ``DSSession`` methods (assignment, aggregation, workspaces...)
    for that purpose.

    :return: A dictionary mapping remote server names to their corresponding R session
        objects, only for the sessions that were successfully started
    :raises DSError: If the sessions are not all started within ``start_timeout`` seconds,
        or if no session could be started at all
    """
    rval = {}
    self._init_errors()
    started_conns = []  # names of the connections whose session is started
    excluded_conns = []  # names of the connections that failed to start or to report status

    def _poll_pending() -> None:
        # Check once every connection that is neither started nor excluded yet.
        for conn in self.conns:
            if conn.name in started_conns or conn.name in excluded_conns:
                continue
            try:
                if conn.is_session_started():
                    started_conns.append(conn.name)
            except Exception as e:
                logging.warning(f"Failed to check session status: {conn.name} - {e}")
                excluded_conns.append(conn.name)

    # start sessions asynchronously if supported, otherwise synchronously
    for conn in self.conns:
        try:
            if not conn.has_session():
                conn.start_session(asynchronous=True)
        except Exception as e:
            logging.warning(f"Failed to start session: {conn.name} - {e}")
            excluded_conns.append(conn.name)

    # first status check is immediate, so that already started sessions do not wait
    _poll_pending()

    # wait until all non-excluded sessions are started; a monotonic clock is used so
    # that system clock adjustments cannot shorten or extend the timeout
    start_time = time.monotonic()
    while len(started_conns) < len(self.conns) - len(excluded_conns):
        if time.monotonic() - start_time > self.start_timeout:
            raise DSError("Timed out waiting for R sessions to start")
        time.sleep(self.start_delay)
        _poll_pending()

    # started_conns holds connection names, so the session objects must be looked up
    # from the connections themselves
    for conn in self.conns:
        if conn.name in started_conns:
            rval[conn.name] = conn.get_session()

    if excluded_conns:
        logging.error(f"Some sessions have been excluded due to errors: {', '.join(excluded_conns)}")
        self.conns = [conn for conn in self.conns if conn.name not in excluded_conns]
    if not self.conns:
        raise DSError("No sessions could be started successfully.")
    return rval
293344

294345
def ls(self) -> dict:
@@ -297,8 +348,8 @@ def ls(self) -> dict:
297348
298349
:return: The symbols that live in the DataSHIELD R session on the server side, per remote server name
299350
"""
300-
self._init_errors()
301-
self.sessions() # ensure sessions are started and available
351+
# ensure sessions are started and available
352+
self.sessions()
302353
rval = {}
303354
for conn in self.conns:
304355
try:
@@ -315,8 +366,8 @@ def rm(self, symbol: str) -> None:
315366
316367
:param symbol: The name of the symbol to remove
317368
"""
318-
self._init_errors()
319-
self.sessions() # ensure sessions are started and available
369+
# ensure sessions are started and available
370+
self.sessions()
320371
for conn in self.conns:
321372
try:
322373
conn.rm_symbol(symbol)
@@ -343,8 +394,8 @@ def assign_table(
343394
:param tables: The name of the table to assign, per server name. If not defined, 'table' is used.
344395
:param asynchronous: Whether the operation is asynchronous (if supported by the DataSHIELD server)
345396
"""
346-
self._init_errors()
347-
self.sessions() # ensure sessions are started and available
397+
# ensure sessions are started and available
398+
self.sessions()
348399
cmd = {}
349400
for conn in self.conns:
350401
name = table
@@ -370,8 +421,8 @@ def assign_resource(
370421
:param resources: The name of the resource to assign, per server name. If not defined, 'resource' is used.
371422
:param asynchronous: Whether the operation is asynchronous (if supported by the DataSHIELD server)
372423
"""
373-
self._init_errors()
374-
self.sessions() # ensure sessions are started and available
424+
# ensure sessions are started and available
425+
self.sessions()
375426
cmd = {}
376427
for conn in self.conns:
377428
name = resource
@@ -394,8 +445,8 @@ def assign_expr(self, symbol: str, expr: str, asynchronous: bool = True) -> None
394445
:param expr: The R expression to evaluate and which result will be assigned
395446
:param asynchronous: Whether the operation is asynchronous (if supported by the DataSHIELD server)
396447
"""
397-
self._init_errors()
398-
self.sessions() # ensure sessions are started and available
448+
# ensure sessions are started and available
449+
self.sessions()
399450
cmd = {}
400451
for conn in self.conns:
401452
try:
@@ -415,8 +466,8 @@ def aggregate(self, expr: str, asynchronous: bool = True) -> dict:
415466
:param asynchronous: Whether the operation is asynchronous (if supported by the DataSHIELD server)
416467
:return: The result of the aggregation expression evaluation, per remote server name
417468
"""
418-
self._init_errors()
419-
self.sessions() # ensure sessions are started and available
469+
# ensure sessions are started and available
470+
self.sessions()
420471
cmd = {}
421472
rval = {}
422473
for conn in self.conns:
@@ -465,6 +516,7 @@ def _append_error(self, conn: DSConnection, error: Exception) -> None:
465516
"""
466517
Append an error.
467518
"""
519+
logging.error(f"[{conn.name}] {error}")
468520
self.errors[conn.name] = error
469521

470522
def _check_errors(self) -> None:

datashield/interface.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,18 @@ def start_session(self, asynchronous: bool = True) -> RSession:
187187
"""
188188
raise NotImplementedError("DSConnection function not available")
189189

190+
def is_session_started(self) -> bool:
191+
"""
192+
Get whether the session with the DataSHIELD server is started. If the session start was asynchronous, this function
193+
can be used to check whether the session is started without waiting for it to be started. If the last call was positive,
194+
subsequent calls will not request the server for session status, but will return True directly. If the last call was negative,
195+
subsequent calls will request the server for session status until a positive response is obtained.
196+
197+
:return: Whether the session is started
198+
:raises DSError: If the session was not started or session information is not available
199+
"""
200+
raise NotImplementedError("DSConnection function not available")
201+
190202
def get_session(self) -> RSession:
191203
"""
192204
Get the R session with the DataSHIELD server. If no session is established, an error will be raised.
@@ -359,9 +371,10 @@ def is_async(self) -> dict:
359371
the raw result can be accessed asynchronously, allowing parallelization of DataSHIELD calls
360372
over multiple servers. The returned named list of logicals will specify if asynchronicity is supported for:
361373
aggregation operation ('aggregate'), table assignment operation ('assign_table'),
362-
resource assignment operation ('assign_resource') and expression assignment operation ('assign_expr').
374+
resource assignment operation ('assign_resource'), expression assignment operation ('assign_expr')
375+
and R session creation ('session').
363376
364-
:return: A named list of logicals specifying if asynchronicity is supported for aggregation operation ('aggregate'), table assignment operation ('assign_table'),
377+
:return: A named list of logicals specifying if asynchronicity is supported.
365378
"""
366379
raise NotImplementedError("DSConnection function not available")
367380

0 commit comments

Comments
 (0)