@@ -250,21 +250,32 @@ def requires_working_threading(*, module=False):
250250 return unittest .skipUnless (can_start_thread , msg )
251251
252252
253- def run_concurrently (worker_func , nthreads , args = (), kwargs = {}):
253+ def run_concurrently (worker_func , nthreads = None , args = (), kwargs = {}):
254254 """
255- Run the worker function concurrently in multiple threads.
255+ Run the worker function(s) concurrently in multiple threads.
256+
257+ If `worker_func` is a single callable, it is used for all threads.
258+ If it is a list of callables, each callable is used for one thread.
256259 """
260+ from collections .abc import Iterable
261+
262+ if nthreads is None :
263+ nthreads = len (worker_func )
264+ if not isinstance (worker_func , Iterable ):
265+ worker_func = [worker_func ] * nthreads
266+ assert len (worker_func ) == nthreads
267+
257268 barrier = threading .Barrier (nthreads )
258269
259- def wrapper_func (* args , ** kwargs ):
270+ def wrapper_func (func , * args , ** kwargs ):
260271 # Wait for all threads to reach this point before proceeding.
261272 barrier .wait ()
262- worker_func (* args , ** kwargs )
273+ func (* args , ** kwargs )
263274
264275 with catch_threading_exception () as cm :
265276 workers = [
266- threading .Thread (target = wrapper_func , args = args , kwargs = kwargs )
267- for _ in range ( nthreads )
277+ threading .Thread (target = wrapper_func , args = ( func , * args ) , kwargs = kwargs )
278+ for func in worker_func
268279 ]
269280 with start_threads (workers ):
270281 pass
0 commit comments