Files
qibotn/.venv/lib/python3.12/site-packages/cotengra/parallel.py
jaunatisblue 28080dff1d
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
简化代码;加入.venv下内容
2026-05-18 02:47:40 +08:00

584 lines
15 KiB
Python

"""Interface for parallelism."""
import atexit
import collections
import functools
import importlib
import inspect
import numbers
import operator
import warnings
# ``import importlib`` alone does not guarantee that the ``importlib.util``
# submodule has been loaded, so import it explicitly before ``find_spec``.
import importlib.util

# backend used for ``parallel="auto"``; stays ``None`` until one is registered
_AUTO_BACKEND = None

# check for loky, joblib (vendors loky), then default to concurrent.futures
have_loky = importlib.util.find_spec("loky") is not None
have_joblib = importlib.util.find_spec("joblib") is not None

if have_loky or have_joblib:
    _DEFAULT_BACKEND = "loky"
else:
    _DEFAULT_BACKEND = "concurrent.futures"
@functools.lru_cache(None)
def choose_default_num_workers():
    """Pick the default number of workers for a local pool.

    The environment variables ``COTENGRA_NUM_WORKERS`` and then
    ``OMP_NUM_THREADS`` are consulted in that order; if neither is set the
    value of ``os.cpu_count()`` is returned. The result is cached, so the
    environment is only inspected on the first call.
    """
    import os

    for variable in ("COTENGRA_NUM_WORKERS", "OMP_NUM_THREADS"):
        value = os.environ.get(variable)
        if value is not None:
            return int(value)

    return os.cpu_count()
def get_pool(n_workers=None, maybe_create=False, backend=None):
    """Get a parallel pool.

    Parameters
    ----------
    n_workers : None or int, optional
        Requested number of workers. For the local backends a missing value
        is filled in with ``choose_default_num_workers()``.
    maybe_create : bool, optional
        Forwarded to the dask and ray getters only.
    backend : None or str, optional
        One of ``"dask"``, ``"ray"``, ``"loky"``, ``"concurrent.futures"``
        or ``"threads"``. ``None`` selects the module default backend.

    Returns
    -------
    The pool for the chosen backend, or ``None`` if ``backend`` is not one
    of the recognised names.
    """
    chosen = _DEFAULT_BACKEND if backend is None else backend

    # distributed backends, don't fill in a default n_workers for these
    if chosen == "dask":
        return _get_pool_dask(n_workers=n_workers, maybe_create=maybe_create)
    elif chosen == "ray":
        return _get_pool_ray(n_workers=n_workers, maybe_create=maybe_create)

    workers = choose_default_num_workers() if n_workers is None else n_workers

    if chosen == "loky":
        return get_loky_get_reusable_executor()(max_workers=workers)
    elif chosen == "concurrent.futures":
        return _get_process_pool_cf(n_workers=workers)
    elif chosen == "threads":
        return _get_thread_pool_cf(n_workers=workers)

    return None
@functools.lru_cache(None)
def _infer_backed_cached(pool_class):
    """Map a pool class to a backend name, cached per class.

    A class literally named ``RayExecutor`` is ``"ray"``; otherwise the top
    level package of the class's module decides, with ``concurrent``,
    ``joblib`` and ``distributed`` translated to ``"concurrent.futures"``,
    ``"loky"`` and ``"dask"`` respectively. Any other top level package name
    is returned unchanged.
    """
    if pool_class.__name__ == "RayExecutor":
        return "ray"

    root = pool_class.__module__.split(".")[0]
    aliases = {
        "concurrent": "concurrent.futures",
        "joblib": "loky",
        "distributed": "dask",
    }
    return aliases.get(root, root)
def _infer_backend(pool):
    """Return the backend name for the instance ``pool``.

    The lookup is keyed on ``pool.__class__`` and delegated to the cached
    ``_infer_backed_cached`` so repeated calls are cheap.
    """
    pool_class = pool.__class__
    return _infer_backed_cached(pool_class)
def get_n_workers(pool=None):
    """Extract how many workers ``pool`` has (mostly for working out how
    many tasks to pre-dispatch).

    Parameters
    ----------
    pool : optional
        The pool to inspect. If ``None`` the result of ``get_pool()`` is used.

    Returns
    -------
    The pool's ``_max_workers`` attribute if it has one, otherwise a count
    worked out per backend (dask, ray or mpi4py).

    Raises
    ------
    ValueError
        If the pool has no ``_max_workers`` and its backend is not handled.
    """
    if pool is None:
        pool = get_pool()

    # concurrent.futures style executors expose the count directly
    missing = object()
    max_workers = getattr(pool, "_max_workers", missing)
    if max_workers is not missing:
        return max_workers

    backend = _infer_backend(pool)

    if backend == "dask":
        info = pool.scheduler_info(n_workers=-1)
        total = 0
        for worker in info["workers"].values():
            # a missing or falsy thread count is counted as one
            total += int(worker.get("nthreads", 1) or 1)
        return total

    if backend == "ray":
        import time

        # keep polling until the "CPU" entry is present in the resources
        while True:
            try:
                return int(get_ray().available_resources()["CPU"])
            except KeyError:
                time.sleep(1e-3)

    if backend == "mpi4py":
        from mpi4py import MPI

        return MPI.COMM_WORLD.size

    raise ValueError(f"Can't find number of workers in pool {pool}.")
def parse_parallel_arg(parallel):
    """Turn the user facing ``parallel`` option into a pool (or ``None``).

    Parameters
    ----------
    parallel : bool, int, str or pool
        - ``False``: no parallelism, returns ``None``.
        - ``"auto"``: the pool of the currently registered auto backend,
          requested with ``maybe_create=False``.
        - ``True``: the registered auto backend, registering the module
          default first if none is registered yet.
        - an integer: a default backend pool with that many workers; this
          also registers the default backend as the auto backend.
        - ``"loky"``, ``"concurrent.futures"``, ``"threads"``: a pool of
          that backend.
        - ``"dask"``, ``"ray"``: a pool of that backend, which is also
          registered as the auto backend.
        - anything else is returned unchanged.
    """
    global _AUTO_BACKEND

    if parallel is False:
        return None

    if parallel == "auto":
        return get_pool(maybe_create=False, backend=_AUTO_BACKEND)

    if parallel is True:
        # swap ``True`` for a backend name, handled by the checks below
        if _AUTO_BACKEND is None:
            _AUTO_BACKEND = _DEFAULT_BACKEND
        parallel = _AUTO_BACKEND

    if isinstance(parallel, numbers.Integral):
        _AUTO_BACKEND = _DEFAULT_BACKEND
        return get_pool(
            n_workers=parallel, maybe_create=True, backend=_DEFAULT_BACKEND
        )

    # local backends: create the pool without touching the auto backend
    for name in ("loky", "concurrent.futures", "threads"):
        if parallel == name:
            return get_pool(maybe_create=True, backend=name)

    # distributed backends: also become the backend used by "auto"
    for name in ("dask", "ray"):
        if parallel == name:
            _AUTO_BACKEND = name
            return get_pool(maybe_create=True, backend=name)

    return parallel
def set_parallel_backend(backend):
    """Create a parallel pool for ``backend`` by passing it through
    ``parse_parallel_arg``, which performs any registration of the backend
    as the default for ``'auto'`` parallel, and return the result.
    """
    pool = parse_parallel_arg(backend)
    return pool
def maybe_leave_pool(pool):
    """Logic required for nested parallelism in dask.distributed.

    For a dask ``pool`` this returns the result of
    ``_maybe_leave_pool_dask()``; for any other backend it returns ``None``.
    """
    if _infer_backend(pool) != "dask":
        return None
    return _maybe_leave_pool_dask()
def maybe_rejoin_pool(is_worker, pool):
    """Logic required for nested parallelism in dask.distributed.

    Rejoins the dask pool, but only when ``is_worker`` is truthy and
    ``pool`` is a dask pool. The backend is not inspected at all when
    ``is_worker`` is falsy.
    """
    if not is_worker:
        return
    if _infer_backend(pool) == "dask":
        _rejoin_pool_dask()
def submit(pool, fn, *args, **kwargs):
    """Interface for submitting ``fn(*args, **kwargs)`` to ``pool``.

    For a dask pool the keyword ``pure`` defaults to ``False`` unless the
    caller supplied it. Returns whatever ``pool.submit`` returns.
    """
    on_dask = _infer_backend(pool) == "dask"
    if on_dask and "pure" not in kwargs:
        kwargs["pure"] = False
    return pool.submit(fn, *args, **kwargs)
def scatter(pool, data):
    """Interface for maybe turning ``data`` into a remote object or
    reference.

    Dask and ray pools have ``pool.scatter(data)`` called on them; for every
    other backend ``data`` is handed back untouched.
    """
    backend = _infer_backend(pool)
    if backend == "dask" or backend == "ray":
        return pool.scatter(data)
    return data
def can_scatter(pool):
    """Whether ``pool`` can make objects remote, i.e. whether its backend
    is dask or ray.
    """
    backend = _infer_backend(pool)
    return backend == "dask" or backend == "ray"
def should_nest(pool):
    """Given argument ``pool`` should we try nested parallelism.

    Returns the backend name (``"ray"`` or ``"dask"``) when nesting should
    be attempted, and ``False`` otherwise, including when ``pool`` is
    ``None``.
    """
    if pool is not None:
        backend = _infer_backend(pool)
        if backend in ("ray", "dask"):
            return backend
    return False
# ---------------------------------- loky ----------------------------------- #
@functools.lru_cache(1)
def get_loky_get_reusable_executor():
    """Return loky's ``get_reusable_executor`` function.

    The standalone ``loky`` package is tried first, falling back to the copy
    under ``joblib.externals``. An ``ImportError`` from the fallback is not
    caught. The successful lookup is cached.
    """
    try:
        from loky import get_reusable_executor as factory
    except ImportError:
        from joblib.externals.loky import get_reusable_executor as factory
    return factory
# --------------------------- concurrent.futures ---------------------------- #
class CachedProcessPoolExecutor:
    """Callable that lazily creates and caches one ``ProcessPoolExecutor``.

    Calling the instance with ``n_workers`` returns the cached pool if it is
    still alive and was created with the same ``n_workers``; otherwise any
    existing pool is shut down and a new one is created. ``shutdown`` is
    registered with ``atexit`` when the instance is constructed.
    """

    def __init__(self):
        self._pool = None
        # worker count the cached pool was created with, -1 means never
        self._n_workers = -1
        atexit.register(self.shutdown)

    def __call__(self, n_workers=None):
        """Return a process pool with ``n_workers`` workers.

        A new pool is created when the requested size changes, and also when
        the cached pool has been shut down in the meantime, so that a
        ``None`` pool is never returned.
        """
        if self._pool is None or n_workers != self._n_workers:
            from concurrent.futures import ProcessPoolExecutor

            self.shutdown()
            self._pool = ProcessPoolExecutor(n_workers)
            self._n_workers = n_workers
        return self._pool

    def is_initialized(self):
        """Whether a pool is currently cached."""
        return self._pool is not None

    def shutdown(self):
        """Shut down and forget the cached pool, if there is one."""
        if self._pool is not None:
            self._pool.shutdown()
            self._pool = None

    def __del__(self):
        self.shutdown()
# single module level cache for the ``concurrent.futures`` process pool
ProcessPoolHandler = CachedProcessPoolExecutor()


def _get_process_pool_cf(n_workers=None):
    """Return the cached process pool for ``n_workers`` workers, as handed
    out by the module level ``ProcessPoolHandler``.
    """
    process_pool = ProcessPoolHandler(n_workers)
    return process_pool
class CachedThreadPoolExecutor:
    """Callable that lazily creates and caches one ``ThreadPoolExecutor``.

    Calling the instance with ``n_workers`` returns the cached pool if it is
    still alive and was created with the same ``n_workers``; otherwise any
    existing pool is shut down and a new one is created. ``shutdown`` is
    registered with ``atexit`` when the instance is constructed.
    """

    def __init__(self):
        self._pool = None
        # worker count the cached pool was created with, -1 means never
        self._n_workers = -1
        atexit.register(self.shutdown)

    def __call__(self, n_workers=None):
        """Return a thread pool with ``n_workers`` workers.

        A new pool is created when the requested size changes, and also when
        the cached pool has been shut down in the meantime, so that a
        ``None`` pool is never returned.
        """
        if self._pool is None or n_workers != self._n_workers:
            from concurrent.futures import ThreadPoolExecutor

            self.shutdown()
            self._pool = ThreadPoolExecutor(n_workers)
            self._n_workers = n_workers
        return self._pool

    def is_initialized(self):
        """Whether a pool is currently cached."""
        return self._pool is not None

    def shutdown(self):
        """Shut down and forget the cached pool, if there is one."""
        if self._pool is not None:
            self._pool.shutdown()
            self._pool = None

    def __del__(self):
        self.shutdown()
# single module level cache for the ``concurrent.futures`` thread pool
ThreadPoolHandler = CachedThreadPoolExecutor()


def _get_thread_pool_cf(n_workers=None):
    """Return the cached thread pool for ``n_workers`` workers, as handed
    out by the module level ``ThreadPoolHandler``.
    """
    thread_pool = ThreadPoolHandler(n_workers)
    return thread_pool
# ---------------------------------- DASK ----------------------------------- #
def _get_pool_dask(n_workers=None, maybe_create=False):
    """Maybe get an existing or create a new dask.distributed client.

    Parameters
    ----------
    n_workers : None or int, optional
        The number of workers to request if creating a new client.
    maybe_create : bool, optional
        Whether to create a new local cluster and client if no existing
        client is found.

    Returns
    -------
    None or dask.distributed.Client
        ``None`` is returned only when ``maybe_create`` is false and either
        ``dask.distributed`` cannot be imported or no client exists.

    Raises
    ------
    ImportError
        If ``dask.distributed`` is missing and ``maybe_create`` is true.
    """
    try:
        from dask.distributed import get_client
    except ImportError:
        if not maybe_create:
            return None
        raise

    try:
        client = get_client()
    except ValueError:
        # ``get_client`` found no existing client
        if not maybe_create:
            return None

        import shutil
        import tempfile

        from dask.distributed import Client, LocalCluster

        local_directory = tempfile.mkdtemp()

        # Register the cleanup before anything else can fail, so the
        # directory is removed at exit even if starting the cluster raises
        # or the warning below is escalated to an error.
        @atexit.register
        def delete_local_dask_directory():
            shutil.rmtree(local_directory, ignore_errors=True)

        lc = LocalCluster(
            n_workers=n_workers,
            threads_per_worker=1,
            local_directory=local_directory,
            memory_limit=0,
        )
        client = Client(lc)

        warnings.warn(
            "Parallel specified but no existing global dask client found... "
            "created one (with {} workers).".format(get_n_workers(client))
        )

    if n_workers is not None:
        current_n_workers = get_n_workers(client)
        if n_workers != current_n_workers:
            warnings.warn(
                "Found existing client (with {} workers which) doesn't match "
                "the requested {}... using it instead.".format(
                    current_n_workers, n_workers
                )
            )

    return client
def _maybe_leave_pool_dask():
    """Call dask's ``secede`` and report whether that worked.

    Returns ``True`` if ``secede()`` completed, and ``False`` if importing
    it raised ``ImportError`` or calling it raised ``ValueError``.
    """
    try:
        from dask.distributed import secede

        secede()  # for nested parallelism
    except (ImportError, ValueError):
        return False
    return True
def _rejoin_pool_dask():
    """Call ``dask.distributed.rejoin``, imported lazily at call time."""
    from dask.distributed import rejoin as rejoin_pool

    rejoin_pool()
# ----------------------------------- RAY ----------------------------------- #
@functools.lru_cache(None)
def get_ray():
    """Import ``ray`` lazily and return the module; the result is cached so
    later calls skip the import statement.
    """
    import ray as ray_module

    return ray_module
class RayFuture:
    """Basic ``concurrent.futures`` like future wrapping a ray
    ``ObjectRef``.
    """

    __slots__ = ("_obj", "_cancelled")

    def __init__(self, obj):
        self._cancelled = False
        self._obj = obj

    def result(self, timeout=None):
        """Fetch the value of the wrapped object via ``ray.get``, passing
        ``timeout`` through.
        """
        ray = get_ray()
        return ray.get(self._obj, timeout=timeout)

    def done(self):
        """Whether this future was cancelled through ``cancel``, or
        ``ray.wait`` with a zero timeout reports the object as ready.
        """
        if self._cancelled:
            return True
        ready = get_ray().wait([self._obj], timeout=0)[0]
        return bool(ready)

    def cancel(self):
        """Ask ray to cancel the wrapped object and mark this future as
        cancelled.
        """
        get_ray().cancel(self._obj)
        self._cancelled = True
def _unpack_futures_tuple(x):
    """Apply ``_unpack_futures`` to every item of ``x``, returning a new
    tuple.
    """
    return tuple(_unpack_futures(item) for item in x)
def _unpack_futures_list(x):
    """Apply ``_unpack_futures`` to every item of ``x``, returning a new
    list.
    """
    return [_unpack_futures(item) for item in x]
def _unpack_futures_dict(x):
    """Apply ``_unpack_futures`` to every value of ``x``, returning a new
    dict with the same keys.
    """
    unpacked = {}
    for key, value in x.items():
        unpacked[key] = _unpack_futures(value)
    return unpacked
def _unpack_futures_identity(x):
    """Fallback unpacker: return ``x`` itself, unchanged."""
    return x
# Maps an exact class to the function that unpacks instances of it. Classes
# without an entry fall back to the identity unpacker via the default factory.
_unpack_dispatch = collections.defaultdict(lambda: _unpack_futures_identity)
_unpack_dispatch.update(
    {
        RayFuture: operator.attrgetter("_obj"),
        tuple: _unpack_futures_tuple,
        list: _unpack_futures_list,
        dict: _unpack_futures_dict,
    }
)
def _unpack_futures(x):
    """Allows passing futures by reference - takes e.g. args and kwargs and
    replaces all ``RayFuture`` objects with their underlying ``ObjectRef``
    within all nested tuples, lists and dicts.

    The handler is looked up by the exact class of ``x``, so anything that
    is not exactly one of those types is returned as is.

    [Subclassing ``ObjectRef`` might avoid needing this.]
    """
    handler = _unpack_dispatch[x.__class__]
    return handler(x)
@functools.lru_cache(2**14)
def get_remote_fn(fn, **remote_opts):
    """Cached retrieval of the ray remote version of ``fn``.

    With ``remote_opts`` the function is wrapped by
    ``ray.remote(**remote_opts)``, otherwise by plain ``ray.remote``. The
    cache is keyed on ``fn`` and the options, so both must be hashable.
    """
    make_remote = get_ray().remote
    if remote_opts:
        make_remote = make_remote(**remote_opts)
    return make_remote(fn)
@functools.lru_cache(2**14)
def get_fn_as_remote_object(fn):
    """Cached ``ray.put`` of the callable ``fn``, returning the reference
    that ``ray.put`` gives back.
    """
    return get_ray().put(fn)
@functools.lru_cache(None)
def get_deploy(**remote_opts):
    """Alternative for 'non-function' callables - e.g. partial
    functions - pass the callable object too.

    Returns a ray remote function which, given ``fn`` and its arguments,
    simply calls ``fn(*args, **kwargs)``. Cached per set of ``remote_opts``.
    """

    def deploy(fn, *args, **kwargs):
        return fn(*args, **kwargs)

    make_remote = get_ray().remote
    if remote_opts:
        make_remote = make_remote(**remote_opts)
    return make_remote(deploy)
class RayExecutor:
    """Basic ``concurrent.futures`` like interface using ``ray``.

    Parameters
    ----------
    *args, **kwargs
        Passed to ``ray.init``, which is only called if ray is not already
        initialized.
    default_remote_opts : None or dict, optional
        Options applied to every remote call unless overridden per call.
    """

    def __init__(self, *args, default_remote_opts=None, **kwargs):
        ray = get_ray()
        if not ray.is_initialized():
            ray.init(*args, **kwargs)

        self.default_remote_opts = (
            {} if default_remote_opts is None else dict(default_remote_opts)
        )

    def _maybe_inject_remote_opts(self, remote_opts=None):
        """Return the default remote options, possibly overriding some with
        those supplied by a ``submit call``.
        """
        ropts = self.default_remote_opts
        if remote_opts is not None:
            ropts = {**ropts, **remote_opts}
        return ropts

    def _remote_call(self, fn, ropts, args, kwargs):
        """Launch ``fn(*args, **kwargs)`` remotely with options ``ropts``
        and return the object that the ``.remote`` call gives back.
        """
        # this is the same test ray uses to accept functions
        if inspect.isfunction(fn):
            # can use the faster cached remote function
            return get_remote_fn(fn, **ropts).remote(*args, **kwargs)

        # other callables (e.g. partials) are shipped as an object and
        # called by the generic remote ``deploy`` function
        fn_obj = get_fn_as_remote_object(fn)
        return get_deploy(**ropts).remote(fn_obj, *args, **kwargs)

    def submit(self, fn, *args, pure=False, remote_opts=None, **kwargs):
        """Remotely run ``fn(*args, **kwargs)``, returning a ``RayFuture``.

        ``pure`` is accepted but not used by this executor.
        """
        # want to pass futures by reference
        args = _unpack_futures_tuple(args)
        kwargs = _unpack_futures_dict(kwargs)

        ropts = self._maybe_inject_remote_opts(remote_opts)
        return RayFuture(self._remote_call(fn, ropts, args, kwargs))

    def map(self, func, *iterables, remote_opts=None):
        """Remote map ``func`` over arguments ``iterables``.

        All calls are launched immediately; the returned iterator fetches
        each result with ``ray.get`` as it is consumed. ``func`` goes
        through the same dispatch as ``submit``, so callables that are not
        plain functions are accepted here too.
        """
        ropts = self._maybe_inject_remote_opts(remote_opts)
        objs = tuple(
            self._remote_call(func, ropts, args, {})
            for args in zip(*iterables)
        )
        ray = get_ray()
        return map(ray.get, objs)

    def scatter(self, data):
        """Push ``data`` into the distributed store, returning an ``ObjectRef``
        that can be supplied to ``submit`` calls for example.
        """
        ray = get_ray()
        return ray.put(data)

    def shutdown(self):
        """Shutdown the parent ray cluster, this ``RayExecutor`` instance
        itself does not need any cleanup.
        """
        get_ray().shutdown()
# the module wide ``RayExecutor``, created on demand by ``_get_pool_ray``
_RAY_EXECUTOR = None


def _get_pool_ray(n_workers=None, maybe_create=False):
    """Maybe get an existing or create a new RayExecutor, thus initializing
    ray.

    Parameters
    ----------
    n_workers : None or int, optional
        The number of workers to request if creating a new client.
    maybe_create : bool, optional
        Whether to initialize ray and return a RayExecutor if not
        initialized already.

    Returns
    -------
    None or RayExecutor
    """
    global _RAY_EXECUTOR

    try:
        import ray
    except ImportError:
        if maybe_create:
            raise
        return None

    needs_setup = (_RAY_EXECUTOR is None) or (not ray.is_initialized())
    if needs_setup:
        if not maybe_create:
            return None
        _RAY_EXECUTOR = RayExecutor(num_cpus=n_workers)

    if n_workers is not None:
        current_n_workers = get_n_workers(_RAY_EXECUTOR)
        mismatch = n_workers != current_n_workers
        if mismatch:
            warnings.warn(
                "Found initialized ray (with {} workers which) doesn't match "
                "the requested {}... sticking with old number.".format(
                    current_n_workers, n_workers
                )
            )

    return _RAY_EXECUTOR