Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
584 lines
15 KiB
Python
584 lines
15 KiB
Python
"""Interface for parallelism."""
|
|
|
|
import atexit
|
|
import collections
|
|
import functools
|
|
import importlib
|
|
import inspect
|
|
import numbers
|
|
import operator
|
|
import warnings
|
|
|
|
_AUTO_BACKEND = None
|
|
|
|
# check for loky, joblib (vendors loky), then default to concurrent.futures
|
|
have_loky = importlib.util.find_spec("loky") is not None
|
|
have_joblib = importlib.util.find_spec("joblib") is not None
|
|
if have_loky or have_joblib:
|
|
_DEFAULT_BACKEND = "loky"
|
|
else:
|
|
_DEFAULT_BACKEND = "concurrent.futures"
|
|
|
|
|
|
@functools.lru_cache(None)
|
|
def choose_default_num_workers():
|
|
import os
|
|
|
|
if "COTENGRA_NUM_WORKERS" in os.environ:
|
|
return int(os.environ["COTENGRA_NUM_WORKERS"])
|
|
|
|
if "OMP_NUM_THREADS" in os.environ:
|
|
return int(os.environ["OMP_NUM_THREADS"])
|
|
|
|
return os.cpu_count()
|
|
|
|
|
|
def get_pool(n_workers=None, maybe_create=False, backend=None):
|
|
"""Get a parallel pool."""
|
|
if backend is None:
|
|
backend = _DEFAULT_BACKEND
|
|
|
|
if backend == "dask":
|
|
return _get_pool_dask(n_workers=n_workers, maybe_create=maybe_create)
|
|
|
|
if backend == "ray":
|
|
return _get_pool_ray(n_workers=n_workers, maybe_create=maybe_create)
|
|
|
|
# above backends are distributed, don't specify n_workers
|
|
if n_workers is None:
|
|
n_workers = choose_default_num_workers()
|
|
|
|
if backend == "loky":
|
|
get_reusable_executor = get_loky_get_reusable_executor()
|
|
return get_reusable_executor(max_workers=n_workers)
|
|
|
|
if backend == "concurrent.futures":
|
|
return _get_process_pool_cf(n_workers=n_workers)
|
|
|
|
if backend == "threads":
|
|
return _get_thread_pool_cf(n_workers=n_workers)
|
|
|
|
|
|
@functools.lru_cache(None)
|
|
def _infer_backed_cached(pool_class):
|
|
if pool_class.__name__ == "RayExecutor":
|
|
return "ray"
|
|
|
|
path = pool_class.__module__.split(".")
|
|
|
|
if path[0] == "concurrent":
|
|
return "concurrent.futures"
|
|
|
|
if path[0] == "joblib":
|
|
return "loky"
|
|
|
|
if path[0] == "distributed":
|
|
return "dask"
|
|
|
|
return path[0]
|
|
|
|
|
|
def _infer_backend(pool):
|
|
"""Return the backend type of ``pool`` - cached for speed."""
|
|
return _infer_backed_cached(pool.__class__)
|
|
|
|
|
|
def get_n_workers(pool=None):
|
|
"""Extract how many workers our pool has (mostly for working out how many
|
|
tasks to pre-dispatch).
|
|
"""
|
|
if pool is None:
|
|
pool = get_pool()
|
|
|
|
try:
|
|
return pool._max_workers
|
|
except AttributeError:
|
|
pass
|
|
|
|
backend = _infer_backend(pool)
|
|
|
|
if backend == "dask":
|
|
workers = pool.scheduler_info(n_workers=-1)["workers"]
|
|
return sum(int(w.get("nthreads", 1) or 1) for w in workers.values())
|
|
|
|
if backend == "ray":
|
|
while True:
|
|
try:
|
|
return int(get_ray().available_resources()["CPU"])
|
|
except KeyError:
|
|
import time
|
|
|
|
time.sleep(1e-3)
|
|
|
|
if backend == "mpi4py":
|
|
from mpi4py import MPI
|
|
|
|
return MPI.COMM_WORLD.size
|
|
|
|
raise ValueError(f"Can't find number of workers in pool {pool}.")
|
|
|
|
|
|
def parse_parallel_arg(parallel):
|
|
""" """
|
|
global _AUTO_BACKEND
|
|
|
|
if parallel == "auto":
|
|
return get_pool(maybe_create=False, backend=_AUTO_BACKEND)
|
|
|
|
if parallel is False:
|
|
return None
|
|
|
|
if parallel is True:
|
|
if _AUTO_BACKEND is None:
|
|
_AUTO_BACKEND = _DEFAULT_BACKEND
|
|
parallel = _AUTO_BACKEND
|
|
|
|
if isinstance(parallel, numbers.Integral):
|
|
_AUTO_BACKEND = _DEFAULT_BACKEND
|
|
return get_pool(
|
|
n_workers=parallel, maybe_create=True, backend=_DEFAULT_BACKEND
|
|
)
|
|
|
|
if parallel == "loky":
|
|
return get_pool(maybe_create=True, backend="loky")
|
|
|
|
if parallel == "concurrent.futures":
|
|
return get_pool(maybe_create=True, backend="concurrent.futures")
|
|
|
|
if parallel == "threads":
|
|
return get_pool(maybe_create=True, backend="threads")
|
|
|
|
if parallel == "dask":
|
|
_AUTO_BACKEND = "dask"
|
|
return get_pool(maybe_create=True, backend="dask")
|
|
|
|
if parallel == "ray":
|
|
_AUTO_BACKEND = "ray"
|
|
return get_pool(maybe_create=True, backend="ray")
|
|
|
|
return parallel
|
|
|
|
|
|
def set_parallel_backend(backend):
|
|
"""Create a parallel pool of type ``backend`` which registers it as the
|
|
default for ``'auto'`` parallel.
|
|
"""
|
|
return parse_parallel_arg(backend)
|
|
|
|
|
|
def maybe_leave_pool(pool):
|
|
"""Logic required for nested parallelism in dask.distributed."""
|
|
if _infer_backend(pool) == "dask":
|
|
return _maybe_leave_pool_dask()
|
|
|
|
|
|
def maybe_rejoin_pool(is_worker, pool):
|
|
"""Logic required for nested parallelism in dask.distributed."""
|
|
if is_worker and _infer_backend(pool) == "dask":
|
|
_rejoin_pool_dask()
|
|
|
|
|
|
def submit(pool, fn, *args, **kwargs):
|
|
"""Interface for submitting ``fn(*args, **kwargs)`` to ``pool``."""
|
|
if _infer_backend(pool) == "dask":
|
|
kwargs.setdefault("pure", False)
|
|
return pool.submit(fn, *args, **kwargs)
|
|
|
|
|
|
def scatter(pool, data):
|
|
"""Interface for maybe turning ``data`` into a remote object or reference."""
|
|
if _infer_backend(pool) in ("dask", "ray"):
|
|
return pool.scatter(data)
|
|
return data
|
|
|
|
|
|
def can_scatter(pool):
|
|
"""Whether ``pool`` can make objects remote."""
|
|
return _infer_backend(pool) in ("dask", "ray")
|
|
|
|
|
|
def should_nest(pool):
|
|
"""Given argument ``pool`` should we try nested parallelism."""
|
|
if pool is None:
|
|
return False
|
|
backend = _infer_backend(pool)
|
|
if backend in ("ray", "dask"):
|
|
return backend
|
|
return False
|
|
|
|
|
|
# ---------------------------------- loky ----------------------------------- #
|
|
|
|
|
|
@functools.lru_cache(1)
|
|
def get_loky_get_reusable_executor():
|
|
try:
|
|
from loky import get_reusable_executor
|
|
except ImportError:
|
|
from joblib.externals.loky import get_reusable_executor
|
|
return get_reusable_executor
|
|
|
|
|
|
# --------------------------- concurrent.futures ---------------------------- #
|
|
|
|
|
|
class CachedProcessPoolExecutor:
|
|
def __init__(self):
|
|
self._pool = None
|
|
self._n_workers = -1
|
|
atexit.register(self.shutdown)
|
|
|
|
def __call__(self, n_workers=None):
|
|
if n_workers != self._n_workers:
|
|
from concurrent.futures import ProcessPoolExecutor
|
|
|
|
self.shutdown()
|
|
self._pool = ProcessPoolExecutor(n_workers)
|
|
self._n_workers = n_workers
|
|
return self._pool
|
|
|
|
def is_initialized(self):
|
|
return self._pool is not None
|
|
|
|
def shutdown(self):
|
|
if self._pool is not None:
|
|
self._pool.shutdown()
|
|
self._pool = None
|
|
|
|
def __del__(self):
|
|
self.shutdown()
|
|
|
|
|
|
ProcessPoolHandler = CachedProcessPoolExecutor()
|
|
|
|
|
|
def _get_process_pool_cf(n_workers=None):
|
|
return ProcessPoolHandler(n_workers)
|
|
|
|
|
|
class CachedThreadPoolExecutor:
|
|
def __init__(self):
|
|
self._pool = None
|
|
self._n_workers = -1
|
|
atexit.register(self.shutdown)
|
|
|
|
def __call__(self, n_workers=None):
|
|
if n_workers != self._n_workers:
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
self.shutdown()
|
|
self._pool = ThreadPoolExecutor(n_workers)
|
|
self._n_workers = n_workers
|
|
return self._pool
|
|
|
|
def is_initialized(self):
|
|
return self._pool is not None
|
|
|
|
def shutdown(self):
|
|
if self._pool is not None:
|
|
self._pool.shutdown()
|
|
self._pool = None
|
|
|
|
def __del__(self):
|
|
self.shutdown()
|
|
|
|
|
|
ThreadPoolHandler = CachedThreadPoolExecutor()
|
|
|
|
|
|
def _get_thread_pool_cf(n_workers=None):
|
|
return ThreadPoolHandler(n_workers)
|
|
|
|
|
|
# ---------------------------------- DASK ----------------------------------- #
|
|
|
|
|
|
def _get_pool_dask(n_workers=None, maybe_create=False):
|
|
"""Maybe get an existing or create a new dask.distrbuted client.
|
|
|
|
Parameters
|
|
----------
|
|
n_workers : None or int, optional
|
|
The number of workers to request if creating a new client.
|
|
maybe_create : bool, optional
|
|
Whether to create an new local cluster and client if no existing client
|
|
is found.
|
|
|
|
Returns
|
|
-------
|
|
None or dask.distributed.Client
|
|
"""
|
|
try:
|
|
from dask.distributed import get_client
|
|
except ImportError:
|
|
if not maybe_create:
|
|
return None
|
|
else:
|
|
raise
|
|
|
|
try:
|
|
client = get_client()
|
|
except ValueError:
|
|
if not maybe_create:
|
|
return None
|
|
|
|
import shutil
|
|
import tempfile
|
|
|
|
from dask.distributed import Client, LocalCluster
|
|
|
|
local_directory = tempfile.mkdtemp()
|
|
lc = LocalCluster(
|
|
n_workers=n_workers,
|
|
threads_per_worker=1,
|
|
local_directory=local_directory,
|
|
memory_limit=0,
|
|
)
|
|
client = Client(lc)
|
|
|
|
warnings.warn(
|
|
"Parallel specified but no existing global dask client found... "
|
|
"created one (with {} workers).".format(get_n_workers(client))
|
|
)
|
|
|
|
@atexit.register
|
|
def delete_local_dask_directory():
|
|
shutil.rmtree(local_directory, ignore_errors=True)
|
|
|
|
if n_workers is not None:
|
|
current_n_workers = get_n_workers(client)
|
|
if n_workers != current_n_workers:
|
|
warnings.warn(
|
|
"Found existing client (with {} workers which) doesn't match "
|
|
"the requested {}... using it instead.".format(
|
|
current_n_workers, n_workers
|
|
)
|
|
)
|
|
|
|
return client
|
|
|
|
|
|
def _maybe_leave_pool_dask():
|
|
try:
|
|
from dask.distributed import secede
|
|
|
|
secede() # for nested parallelism
|
|
is_dask_worker = True
|
|
except (ImportError, ValueError):
|
|
is_dask_worker = False
|
|
return is_dask_worker
|
|
|
|
|
|
def _rejoin_pool_dask():
|
|
from dask.distributed import rejoin
|
|
|
|
rejoin()
|
|
|
|
|
|
# ----------------------------------- RAY ----------------------------------- #
|
|
|
|
|
|
@functools.lru_cache(None)
|
|
def get_ray():
|
|
""" """
|
|
import ray
|
|
|
|
return ray
|
|
|
|
|
|
class RayFuture:
|
|
"""Basic ``concurrent.futures`` like future wrapping a ray ``ObjectRef``."""
|
|
|
|
__slots__ = ("_obj", "_cancelled")
|
|
|
|
def __init__(self, obj):
|
|
self._obj = obj
|
|
self._cancelled = False
|
|
|
|
def result(self, timeout=None):
|
|
return get_ray().get(self._obj, timeout=timeout)
|
|
|
|
def done(self):
|
|
return self._cancelled or bool(
|
|
get_ray().wait([self._obj], timeout=0)[0]
|
|
)
|
|
|
|
def cancel(self):
|
|
get_ray().cancel(self._obj)
|
|
self._cancelled = True
|
|
|
|
|
|
def _unpack_futures_tuple(x):
|
|
return tuple(map(_unpack_futures, x))
|
|
|
|
|
|
def _unpack_futures_list(x):
|
|
return list(map(_unpack_futures, x))
|
|
|
|
|
|
def _unpack_futures_dict(x):
|
|
return {k: _unpack_futures(v) for k, v in x.items()}
|
|
|
|
|
|
def _unpack_futures_identity(x):
|
|
return x
|
|
|
|
|
|
_unpack_dispatch = collections.defaultdict(
|
|
lambda: _unpack_futures_identity,
|
|
{
|
|
RayFuture: operator.attrgetter("_obj"),
|
|
tuple: _unpack_futures_tuple,
|
|
list: _unpack_futures_list,
|
|
dict: _unpack_futures_dict,
|
|
},
|
|
)
|
|
|
|
|
|
def _unpack_futures(x):
|
|
"""Allows passing futures by reference - takes e.g. args and kwargs and
|
|
replaces all ``RayFuture`` objects with their underyling ``ObjectRef``
|
|
within all nested tuples, lists and dicts.
|
|
|
|
[Subclassing ``ObjectRef`` might avoid needing this.]
|
|
"""
|
|
return _unpack_dispatch[x.__class__](x)
|
|
|
|
|
|
@functools.lru_cache(2**14)
|
|
def get_remote_fn(fn, **remote_opts):
|
|
"""Cached retrieval of remote function."""
|
|
ray = get_ray()
|
|
if remote_opts:
|
|
return ray.remote(**remote_opts)(fn)
|
|
return ray.remote(fn)
|
|
|
|
|
|
@functools.lru_cache(2**14)
|
|
def get_fn_as_remote_object(fn):
|
|
ray = get_ray()
|
|
return ray.put(fn)
|
|
|
|
|
|
@functools.lru_cache(None)
|
|
def get_deploy(**remote_opts):
|
|
"""Alternative for 'non-function' callables - e.g. partial
|
|
functions - pass the callable object too.
|
|
"""
|
|
ray = get_ray()
|
|
|
|
def deploy(fn, *args, **kwargs):
|
|
return fn(*args, **kwargs)
|
|
|
|
if remote_opts:
|
|
return ray.remote(**remote_opts)(deploy)
|
|
return ray.remote(deploy)
|
|
|
|
|
|
class RayExecutor:
|
|
"""Basic ``concurrent.futures`` like interface using ``ray``."""
|
|
|
|
def __init__(self, *args, default_remote_opts=None, **kwargs):
|
|
ray = get_ray()
|
|
if not ray.is_initialized():
|
|
ray.init(*args, **kwargs)
|
|
|
|
self.default_remote_opts = (
|
|
{} if default_remote_opts is None else dict(default_remote_opts)
|
|
)
|
|
|
|
def _maybe_inject_remote_opts(self, remote_opts=None):
|
|
"""Return the default remote options, possibly overriding some with
|
|
those supplied by a ``submit call``.
|
|
"""
|
|
ropts = self.default_remote_opts
|
|
if remote_opts is not None:
|
|
ropts = {**ropts, **remote_opts}
|
|
return ropts
|
|
|
|
def submit(self, fn, *args, pure=False, remote_opts=None, **kwargs):
|
|
"""Remotely run ``fn(*args, **kwargs)``, returning a ``RayFuture``."""
|
|
# want to pass futures by reference
|
|
args = _unpack_futures_tuple(args)
|
|
kwargs = _unpack_futures_dict(kwargs)
|
|
|
|
ropts = self._maybe_inject_remote_opts(remote_opts)
|
|
|
|
# this is the same test ray uses to accept functions
|
|
if inspect.isfunction(fn):
|
|
# can use the faster cached remote function
|
|
obj = get_remote_fn(fn, **ropts).remote(*args, **kwargs)
|
|
else:
|
|
fn_obj = get_fn_as_remote_object(fn)
|
|
obj = get_deploy(**ropts).remote(fn_obj, *args, **kwargs)
|
|
|
|
return RayFuture(obj)
|
|
|
|
def map(self, func, *iterables, remote_opts=None):
|
|
"""Remote map ``func`` over arguments ``iterables``."""
|
|
ropts = self._maybe_inject_remote_opts(remote_opts)
|
|
remote_fn = get_remote_fn(func, **ropts)
|
|
objs = tuple(map(remote_fn.remote, *iterables))
|
|
ray = get_ray()
|
|
return map(ray.get, objs)
|
|
|
|
def scatter(self, data):
|
|
"""Push ``data`` into the distributed store, returning an ``ObjectRef``
|
|
that can be supplied to ``submit`` calls for example.
|
|
"""
|
|
ray = get_ray()
|
|
return ray.put(data)
|
|
|
|
def shutdown(self):
|
|
"""Shutdown the parent ray cluster, this ``RayExecutor`` instance
|
|
itself does not need any cleanup.
|
|
"""
|
|
get_ray().shutdown()
|
|
|
|
|
|
_RAY_EXECUTOR = None
|
|
|
|
|
|
def _get_pool_ray(n_workers=None, maybe_create=False):
|
|
"""Maybe get an existing or create a new RayExecutor, thus initializing,
|
|
ray.
|
|
|
|
Parameters
|
|
----------
|
|
n_workers : None or int, optional
|
|
The number of workers to request if creating a new client.
|
|
maybe_create : bool, optional
|
|
Whether to create initialize ray and return a RayExecutor if not
|
|
initialized already.
|
|
|
|
Returns
|
|
-------
|
|
None or RayExecutor
|
|
"""
|
|
try:
|
|
import ray
|
|
except ImportError:
|
|
if not maybe_create:
|
|
return None
|
|
else:
|
|
raise
|
|
|
|
global _RAY_EXECUTOR
|
|
|
|
if (_RAY_EXECUTOR is None) or (not ray.is_initialized()):
|
|
if not maybe_create:
|
|
return None
|
|
_RAY_EXECUTOR = RayExecutor(num_cpus=n_workers)
|
|
|
|
if n_workers is not None:
|
|
current_n_workers = get_n_workers(_RAY_EXECUTOR)
|
|
if n_workers != current_n_workers:
|
|
warnings.warn(
|
|
"Found initialized ray (with {} workers which) doesn't match "
|
|
"the requested {}... sticking with old number.".format(
|
|
current_n_workers, n_workers
|
|
)
|
|
)
|
|
|
|
return _RAY_EXECUTOR
|