Skip to content

Commit 7e79bfb

Browse files
committed
Update run_concurrently from main branch
1 parent d92a7fb commit 7e79bfb

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

Lib/test/support/threading_helper.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -250,21 +250,32 @@ def requires_working_threading(*, module=False):
250250
return unittest.skipUnless(can_start_thread, msg)
251251

252252

253-
def run_concurrently(worker_func, nthreads, args=(), kwargs={}):
253+
def run_concurrently(worker_func, nthreads=None, args=(), kwargs={}):
254254
"""
255-
Run the worker function concurrently in multiple threads.
255+
Run the worker function(s) concurrently in multiple threads.
256+
257+
If `worker_func` is a single callable, it is used for all threads.
258+
If it is a list of callables, each callable is used for one thread.
256259
"""
260+
from collections.abc import Iterable
261+
262+
if nthreads is None:
263+
nthreads = len(worker_func)
264+
if not isinstance(worker_func, Iterable):
265+
worker_func = [worker_func] * nthreads
266+
assert len(worker_func) == nthreads
267+
257268
barrier = threading.Barrier(nthreads)
258269

259-
def wrapper_func(*args, **kwargs):
270+
def wrapper_func(func, *args, **kwargs):
260271
# Wait for all threads to reach this point before proceeding.
261272
barrier.wait()
262-
worker_func(*args, **kwargs)
273+
func(*args, **kwargs)
263274

264275
with catch_threading_exception() as cm:
265276
workers = [
266-
threading.Thread(target=wrapper_func, args=args, kwargs=kwargs)
267-
for _ in range(nthreads)
277+
threading.Thread(target=wrapper_func, args=(func, *args), kwargs=kwargs)
278+
for func in worker_func
268279
]
269280
with start_threads(workers):
270281
pass

0 commit comments

Comments
 (0)