简化代码;加入.venv下内容
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
This commit is contained in:
1288
.venv/lib/python3.12/site-packages/cotengra/contract.py
Normal file
1288
.venv/lib/python3.12/site-packages/cotengra/contract.py
Normal file
File diff suppressed because it is too large
Load Diff
4130
.venv/lib/python3.12/site-packages/cotengra/core.py
Normal file
4130
.venv/lib/python3.12/site-packages/cotengra/core.py
Normal file
File diff suppressed because it is too large
Load Diff
1168
.venv/lib/python3.12/site-packages/cotengra/hyperoptimizers/hyper.py
Normal file
1168
.venv/lib/python3.12/site-packages/cotengra/hyperoptimizers/hyper.py
Normal file
File diff suppressed because it is too large
Load Diff
583
.venv/lib/python3.12/site-packages/cotengra/parallel.py
Normal file
583
.venv/lib/python3.12/site-packages/cotengra/parallel.py
Normal file
@@ -0,0 +1,583 @@
|
||||
"""Interface for parallelism."""
|
||||
|
||||
import atexit
|
||||
import collections
|
||||
import functools
|
||||
import importlib
|
||||
import inspect
|
||||
import numbers
|
||||
import operator
|
||||
import warnings
|
||||
|
||||
_AUTO_BACKEND = None
|
||||
|
||||
# check for loky, joblib (vendors loky), then default to concurrent.futures
|
||||
have_loky = importlib.util.find_spec("loky") is not None
|
||||
have_joblib = importlib.util.find_spec("joblib") is not None
|
||||
if have_loky or have_joblib:
|
||||
_DEFAULT_BACKEND = "loky"
|
||||
else:
|
||||
_DEFAULT_BACKEND = "concurrent.futures"
|
||||
|
||||
|
||||
@functools.lru_cache(None)
def choose_default_num_workers():
    """Return the default number of workers (result is cached).

    The environment variables ``COTENGRA_NUM_WORKERS`` and then
    ``OMP_NUM_THREADS`` are consulted in that order; the first one that is
    set is converted with ``int``. If neither is set, ``os.cpu_count()`` is
    returned.
    """
    import os

    for variable in ("COTENGRA_NUM_WORKERS", "OMP_NUM_THREADS"):
        if variable in os.environ:
            return int(os.environ[variable])

    return os.cpu_count()
|
||||
|
||||
|
||||
def get_pool(n_workers=None, maybe_create=False, backend=None):
    """Get a parallel pool for ``backend``.

    Parameters
    ----------
    n_workers : None or int, optional
        Requested number of workers. For the non-distributed backends a
        value of ``None`` is replaced by ``choose_default_num_workers()``.
    maybe_create : bool, optional
        Forwarded to the ``"dask"`` and ``"ray"`` getters only.
    backend : None or str, optional
        One of ``"dask"``, ``"ray"``, ``"loky"``, ``"concurrent.futures"``
        or ``"threads"``. ``None`` selects ``_DEFAULT_BACKEND``.

    Returns
    -------
    The pool for the backend, or ``None`` if ``backend`` is not recognised
    (the distributed getters may also return ``None``).
    """
    chosen = _DEFAULT_BACKEND if backend is None else backend

    # distributed backends: pass ``n_workers`` through untouched
    if chosen == "dask":
        return _get_pool_dask(n_workers=n_workers, maybe_create=maybe_create)
    elif chosen == "ray":
        return _get_pool_ray(n_workers=n_workers, maybe_create=maybe_create)

    # local backends need a concrete worker count
    if n_workers is None:
        n_workers = choose_default_num_workers()

    if chosen == "loky":
        make_executor = get_loky_get_reusable_executor()
        return make_executor(max_workers=n_workers)
    elif chosen == "concurrent.futures":
        return _get_process_pool_cf(n_workers=n_workers)
    elif chosen == "threads":
        return _get_thread_pool_cf(n_workers=n_workers)

    return None
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def _infer_backed_cached(pool_class):
|
||||
if pool_class.__name__ == "RayExecutor":
|
||||
return "ray"
|
||||
|
||||
path = pool_class.__module__.split(".")
|
||||
|
||||
if path[0] == "concurrent":
|
||||
return "concurrent.futures"
|
||||
|
||||
if path[0] == "joblib":
|
||||
return "loky"
|
||||
|
||||
if path[0] == "distributed":
|
||||
return "dask"
|
||||
|
||||
return path[0]
|
||||
|
||||
|
||||
def _infer_backend(pool):
    """Return the backend name for ``pool``.

    The lookup is delegated to ``_infer_backed_cached`` keyed on the pool's
    class, so repeated calls for the same class hit the cache.
    """
    pool_class = pool.__class__
    return _infer_backed_cached(pool_class)
|
||||
|
||||
|
||||
def get_n_workers(pool=None):
    """Work out how many workers ``pool`` has (mostly used to decide how
    many tasks to pre-dispatch).

    Parameters
    ----------
    pool : optional
        The pool to inspect; ``None`` means ``get_pool()`` is used.

    Returns
    -------
    ``pool._max_workers`` when that attribute exists, otherwise a count
    derived from the inferred backend (dask, ray or mpi4py).

    Raises
    ------
    ValueError
        If none of the above strategies apply.
    """
    if pool is None:
        pool = get_pool()

    # executors exposing ``_max_workers`` answer directly
    not_found = object()
    max_workers = getattr(pool, "_max_workers", not_found)
    if max_workers is not not_found:
        return max_workers

    backend = _infer_backend(pool)

    if backend == "dask":
        info = pool.scheduler_info(n_workers=-1)
        # a missing or falsy ``nthreads`` entry counts as one thread
        return sum(
            int(worker.get("nthreads", 1) or 1)
            for worker in info["workers"].values()
        )

    if backend == "ray":
        # retry until the "CPU" resource is reported
        while True:
            try:
                resources = get_ray().available_resources()
                return int(resources["CPU"])
            except KeyError:
                import time

                time.sleep(1e-3)

    if backend == "mpi4py":
        from mpi4py import MPI

        return MPI.COMM_WORLD.size

    raise ValueError(f"Can't find number of workers in pool {pool}.")
|
||||
|
||||
|
||||
def parse_parallel_arg(parallel):
    """Turn a ``parallel`` specification into a pool (or ``None``).

    Parameters
    ----------
    parallel : "auto", bool, int, str or pool
        - ``"auto"``: return the pool for the currently registered auto
          backend without creating one (``maybe_create=False``).
        - ``False``: return ``None``.
        - ``True``: use the registered auto backend, registering
          ``_DEFAULT_BACKEND`` first if none is set.
        - integer: register ``_DEFAULT_BACKEND`` as the auto backend and
          request a pool with that many workers.
        - ``"loky"``, ``"concurrent.futures"``, ``"threads"``: request a pool
          of that backend.
        - ``"dask"``, ``"ray"``: register that backend as the auto backend
          and request a pool of it.
        - anything else is returned unchanged.
    """
    global _AUTO_BACKEND

    if parallel == "auto":
        return get_pool(maybe_create=False, backend=_AUTO_BACKEND)

    if parallel is False:
        return None

    if parallel is True:
        # booleans are handled before the integer check below
        if _AUTO_BACKEND is None:
            _AUTO_BACKEND = _DEFAULT_BACKEND
        parallel = _AUTO_BACKEND

    if isinstance(parallel, numbers.Integral):
        _AUTO_BACKEND = _DEFAULT_BACKEND
        return get_pool(
            n_workers=parallel,
            maybe_create=True,
            backend=_DEFAULT_BACKEND,
        )

    # local backends: create the pool, leave the auto backend alone
    if parallel in ("loky", "concurrent.futures", "threads"):
        return get_pool(maybe_create=True, backend=parallel)

    # distributed backends: also remember them for later "auto" requests
    if parallel in ("dask", "ray"):
        _AUTO_BACKEND = parallel
        return get_pool(maybe_create=True, backend=parallel)

    return parallel
|
||||
|
||||
|
||||
def set_parallel_backend(backend):
    """Create a parallel pool of type ``backend``.

    This simply forwards to ``parse_parallel_arg``, which records some
    backends as the default used for ``'auto'`` parallel, and returns
    whatever that call returns.
    """
    pool = parse_parallel_arg(backend)
    return pool
|
||||
|
||||
|
||||
def maybe_leave_pool(pool):
    """Logic required for nested parallelism in dask.distributed.

    Returns the result of ``_maybe_leave_pool_dask()`` when ``pool`` is a
    dask pool, otherwise ``None``.
    """
    if _infer_backend(pool) != "dask":
        return None
    return _maybe_leave_pool_dask()
|
||||
|
||||
|
||||
def maybe_rejoin_pool(is_worker, pool):
    """Logic required for nested parallelism in dask.distributed.

    Calls ``_rejoin_pool_dask()`` only when ``is_worker`` is truthy and
    ``pool`` is a dask pool; otherwise does nothing.
    """
    if not is_worker:
        return
    if _infer_backend(pool) == "dask":
        _rejoin_pool_dask()
|
||||
|
||||
|
||||
def submit(pool, fn, *args, **kwargs):
    """Interface for submitting ``fn(*args, **kwargs)`` to ``pool``.

    For dask pools the keyword ``pure`` defaults to ``False`` unless the
    caller supplied it. Returns whatever ``pool.submit`` returns.
    """
    is_dask = _infer_backend(pool) == "dask"
    if is_dask and "pure" not in kwargs:
        kwargs["pure"] = False
    return pool.submit(fn, *args, **kwargs)
|
||||
|
||||
|
||||
def scatter(pool, data):
    """Interface for maybe turning ``data`` into a remote object or reference.

    For dask and ray pools the result of ``pool.scatter(data)`` is returned;
    for any other pool ``data`` is returned unchanged.
    """
    backend = _infer_backend(pool)
    if backend == "dask" or backend == "ray":
        return pool.scatter(data)
    return data
|
||||
|
||||
|
||||
def can_scatter(pool):
    """Whether ``pool`` can make objects remote (true for dask and ray)."""
    backend = _infer_backend(pool)
    return backend == "dask" or backend == "ray"
|
||||
|
||||
|
||||
def should_nest(pool):
    """Given argument ``pool`` should we try nested parallelism.

    Returns the backend name (``"ray"`` or ``"dask"``) when nesting applies,
    and ``False`` otherwise, including when ``pool`` is ``None``.
    """
    if pool is None:
        return False

    backend = _infer_backend(pool)
    return backend if backend in ("ray", "dask") else False
|
||||
|
||||
|
||||
# ---------------------------------- loky ----------------------------------- #
|
||||
|
||||
|
||||
@functools.lru_cache(1)
def get_loky_get_reusable_executor():
    """Return loky's ``get_reusable_executor`` function (cached).

    The standalone ``loky`` package is tried first; on ``ImportError`` the
    copy vendored in ``joblib.externals.loky`` is used instead.
    """
    try:
        from loky import get_reusable_executor as executor_factory
    except ImportError:
        from joblib.externals.loky import (
            get_reusable_executor as executor_factory,
        )
    return executor_factory
|
||||
|
||||
|
||||
# --------------------------- concurrent.futures ---------------------------- #
|
||||
|
||||
|
||||
class CachedProcessPoolExecutor:
    """Callable holder that caches a single ``ProcessPoolExecutor``.

    Calling the instance with ``n_workers`` returns the cached pool when one
    exists with the same worker count; otherwise any current pool is shut
    down and a new one is created. ``shutdown`` is registered with
    ``atexit`` on construction.
    """

    def __init__(self):
        self._pool = None
        # -1 never equals a valid request, so the first call creates a pool
        self._n_workers = -1
        atexit.register(self.shutdown)

    def __call__(self, n_workers=None):
        """Return a process pool with ``n_workers`` workers.

        A new pool is also created when the previous one was shut down;
        comparing only the worker count would hand back ``None`` in that
        case.
        """
        if self._pool is None or n_workers != self._n_workers:
            from concurrent.futures import ProcessPoolExecutor

            self.shutdown()
            self._pool = ProcessPoolExecutor(n_workers)
            self._n_workers = n_workers
        return self._pool

    def is_initialized(self):
        """Whether a pool is currently held."""
        return self._pool is not None

    def shutdown(self):
        """Shut down and forget the held pool, if any."""
        if self._pool is not None:
            self._pool.shutdown()
            self._pool = None

    def __del__(self):
        self.shutdown()
|
||||
|
||||
|
||||
# Module-level handler instance; ``_get_process_pool_cf`` calls it to obtain
# the cached process pool.
ProcessPoolHandler = CachedProcessPoolExecutor()
|
||||
|
||||
|
||||
def _get_process_pool_cf(n_workers=None):
    """Return the process pool held by ``ProcessPoolHandler`` for
    ``n_workers`` workers.
    """
    pool = ProcessPoolHandler(n_workers)
    return pool
|
||||
|
||||
|
||||
class CachedThreadPoolExecutor:
    """Callable holder that caches a single ``ThreadPoolExecutor``.

    Calling the instance with ``n_workers`` returns the cached pool when one
    exists with the same worker count; otherwise any current pool is shut
    down and a new one is created. ``shutdown`` is registered with
    ``atexit`` on construction.
    """

    def __init__(self):
        self._pool = None
        # -1 never equals a valid request, so the first call creates a pool
        self._n_workers = -1
        atexit.register(self.shutdown)

    def __call__(self, n_workers=None):
        """Return a thread pool with ``n_workers`` workers.

        A new pool is also created when the previous one was shut down;
        comparing only the worker count would hand back ``None`` in that
        case.
        """
        if self._pool is None or n_workers != self._n_workers:
            from concurrent.futures import ThreadPoolExecutor

            self.shutdown()
            self._pool = ThreadPoolExecutor(n_workers)
            self._n_workers = n_workers
        return self._pool

    def is_initialized(self):
        """Whether a pool is currently held."""
        return self._pool is not None

    def shutdown(self):
        """Shut down and forget the held pool, if any."""
        if self._pool is not None:
            self._pool.shutdown()
            self._pool = None

    def __del__(self):
        self.shutdown()
|
||||
|
||||
|
||||
# Module-level handler instance; ``_get_thread_pool_cf`` calls it to obtain
# the cached thread pool.
ThreadPoolHandler = CachedThreadPoolExecutor()
|
||||
|
||||
|
||||
def _get_thread_pool_cf(n_workers=None):
    """Return the thread pool held by ``ThreadPoolHandler`` for
    ``n_workers`` workers.
    """
    pool = ThreadPoolHandler(n_workers)
    return pool
|
||||
|
||||
|
||||
# ---------------------------------- DASK ----------------------------------- #
|
||||
|
||||
|
||||
def _get_pool_dask(n_workers=None, maybe_create=False):
|
||||
"""Maybe get an existing or create a new dask.distrbuted client.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n_workers : None or int, optional
|
||||
The number of workers to request if creating a new client.
|
||||
maybe_create : bool, optional
|
||||
Whether to create an new local cluster and client if no existing client
|
||||
is found.
|
||||
|
||||
Returns
|
||||
-------
|
||||
None or dask.distributed.Client
|
||||
"""
|
||||
try:
|
||||
from dask.distributed import get_client
|
||||
except ImportError:
|
||||
if not maybe_create:
|
||||
return None
|
||||
else:
|
||||
raise
|
||||
|
||||
try:
|
||||
client = get_client()
|
||||
except ValueError:
|
||||
if not maybe_create:
|
||||
return None
|
||||
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
from dask.distributed import Client, LocalCluster
|
||||
|
||||
local_directory = tempfile.mkdtemp()
|
||||
lc = LocalCluster(
|
||||
n_workers=n_workers,
|
||||
threads_per_worker=1,
|
||||
local_directory=local_directory,
|
||||
memory_limit=0,
|
||||
)
|
||||
client = Client(lc)
|
||||
|
||||
warnings.warn(
|
||||
"Parallel specified but no existing global dask client found... "
|
||||
"created one (with {} workers).".format(get_n_workers(client))
|
||||
)
|
||||
|
||||
@atexit.register
|
||||
def delete_local_dask_directory():
|
||||
shutil.rmtree(local_directory, ignore_errors=True)
|
||||
|
||||
if n_workers is not None:
|
||||
current_n_workers = get_n_workers(client)
|
||||
if n_workers != current_n_workers:
|
||||
warnings.warn(
|
||||
"Found existing client (with {} workers which) doesn't match "
|
||||
"the requested {}... using it instead.".format(
|
||||
current_n_workers, n_workers
|
||||
)
|
||||
)
|
||||
|
||||
return client
|
||||
|
||||
|
||||
def _maybe_leave_pool_dask():
|
||||
try:
|
||||
from dask.distributed import secede
|
||||
|
||||
secede() # for nested parallelism
|
||||
is_dask_worker = True
|
||||
except (ImportError, ValueError):
|
||||
is_dask_worker = False
|
||||
return is_dask_worker
|
||||
|
||||
|
||||
def _rejoin_pool_dask():
    """Call ``dask.distributed.rejoin()``."""
    from dask import distributed

    distributed.rejoin()
|
||||
|
||||
|
||||
# ----------------------------------- RAY ----------------------------------- #
|
||||
|
||||
|
||||
@functools.lru_cache(None)
def get_ray():
    """Import ``ray`` on first use and return the module (cached)."""
    import ray as ray_module

    return ray_module
|
||||
|
||||
|
||||
class RayFuture:
    """Basic ``concurrent.futures`` like future wrapping a ray ``ObjectRef``.

    Only ``result``, ``done`` and ``cancel`` are provided.
    """

    __slots__ = ("_obj", "_cancelled")

    def __init__(self, obj):
        self._cancelled = False
        self._obj = obj

    def result(self, timeout=None):
        """Fetch the wrapped object's value via ``ray.get``."""
        ray = get_ray()
        return ray.get(self._obj, timeout=timeout)

    def done(self):
        """Whether this future was cancelled or ``ray.wait`` (with zero
        timeout) lists the wrapped object first in its result.
        """
        if self._cancelled:
            return self._cancelled
        ready = get_ray().wait([self._obj], timeout=0)[0]
        return bool(ready)

    def cancel(self):
        """Cancel the wrapped object via ``ray.cancel`` and mark this future
        as cancelled.
        """
        ray = get_ray()
        ray.cancel(self._obj)
        self._cancelled = True
|
||||
|
||||
|
||||
def _unpack_futures_tuple(x):
    """Apply ``_unpack_futures`` to each item of ``x``, returning a tuple."""
    unpacked = map(_unpack_futures, x)
    return tuple(unpacked)
|
||||
|
||||
|
||||
def _unpack_futures_list(x):
    """Apply ``_unpack_futures`` to each item of ``x``, returning a list."""
    unpacked = map(_unpack_futures, x)
    return list(unpacked)
|
||||
|
||||
|
||||
def _unpack_futures_dict(x):
|
||||
return {k: _unpack_futures(v) for k, v in x.items()}
|
||||
|
||||
|
||||
def _unpack_futures_identity(x):
|
||||
return x
|
||||
|
||||
|
||||
# Dispatch table keyed on exact class: ``RayFuture`` yields its ``_obj``,
# tuples/lists/dicts are unpacked by their helpers, and any other class
# falls back to the identity function.
_unpack_dispatch = collections.defaultdict(lambda: _unpack_futures_identity)
_unpack_dispatch[RayFuture] = operator.attrgetter("_obj")
_unpack_dispatch[tuple] = _unpack_futures_tuple
_unpack_dispatch[list] = _unpack_futures_list
_unpack_dispatch[dict] = _unpack_futures_dict
|
||||
|
||||
|
||||
def _unpack_futures(x):
    """Allows passing futures by reference - takes e.g. args and kwargs and
    replaces all ``RayFuture`` objects with their underlying ``ObjectRef``
    within all nested tuples, lists and dicts.

    The handler is looked up in ``_unpack_dispatch`` by the exact class of
    ``x``.

    [Subclassing ``ObjectRef`` might avoid needing this.]
    """
    handler = _unpack_dispatch[x.__class__]
    return handler(x)
|
||||
|
||||
|
||||
@functools.lru_cache(2**14)
def get_remote_fn(fn, **remote_opts):
    """Cached retrieval of the ray remote version of ``fn``.

    With ``remote_opts`` the function is wrapped via
    ``ray.remote(**remote_opts)(fn)``, otherwise via ``ray.remote(fn)``.
    """
    ray = get_ray()
    make_remote = ray.remote(**remote_opts) if remote_opts else ray.remote
    return make_remote(fn)
|
||||
|
||||
|
||||
@functools.lru_cache(2**14)
def get_fn_as_remote_object(fn):
    """Cached ``ray.put`` of the callable ``fn``, returning its reference."""
    return get_ray().put(fn)
|
||||
|
||||
|
||||
@functools.lru_cache(None)
def get_deploy(**remote_opts):
    """Alternative for 'non-function' callables - e.g. partial
    functions - pass the callable object too.

    Returns a ray remote function ``deploy(fn, *args, **kwargs)`` that calls
    ``fn(*args, **kwargs)``; cached per set of ``remote_opts``.
    """
    ray = get_ray()

    def deploy(fn, *args, **kwargs):
        return fn(*args, **kwargs)

    make_remote = ray.remote(**remote_opts) if remote_opts else ray.remote
    return make_remote(deploy)
|
||||
|
||||
|
||||
class RayExecutor:
    """Basic ``concurrent.futures`` like interface using ``ray``.

    Parameters
    ----------
    *args, **kwargs
        Forwarded to ``ray.init`` if ray is not yet initialized.
    default_remote_opts : None or mapping, optional
        Default remote options applied to every ``submit`` / ``map`` call;
        copied into a new dict on construction.
    """

    def __init__(self, *args, default_remote_opts=None, **kwargs):
        ray = get_ray()
        if not ray.is_initialized():
            ray.init(*args, **kwargs)

        if default_remote_opts is None:
            self.default_remote_opts = {}
        else:
            self.default_remote_opts = dict(default_remote_opts)

    def _maybe_inject_remote_opts(self, remote_opts=None):
        """Return the default remote options, possibly overriding some with
        those supplied by a ``submit call``.
        """
        if remote_opts is None:
            return self.default_remote_opts
        return {**self.default_remote_opts, **remote_opts}

    def submit(self, fn, *args, pure=False, remote_opts=None, **kwargs):
        """Remotely run ``fn(*args, **kwargs)``, returning a ``RayFuture``.

        ``pure`` is accepted but not used here.
        """
        # want to pass futures by reference
        call_args = _unpack_futures_tuple(args)
        call_kwargs = _unpack_futures_dict(kwargs)

        ropts = self._maybe_inject_remote_opts(remote_opts)

        # this is the same test ray uses to accept functions
        if inspect.isfunction(fn):
            # can use the faster cached remote function
            remote_fn = get_remote_fn(fn, **ropts)
            obj = remote_fn.remote(*call_args, **call_kwargs)
        else:
            # other callables are shipped as an object to a generic deployer
            fn_obj = get_fn_as_remote_object(fn)
            deployer = get_deploy(**ropts)
            obj = deployer.remote(fn_obj, *call_args, **call_kwargs)

        return RayFuture(obj)

    def map(self, func, *iterables, remote_opts=None):
        """Remote map ``func`` over arguments ``iterables``.

        All calls are dispatched up front; the returned iterator fetches
        each result with ``ray.get`` as it is consumed.
        """
        ropts = self._maybe_inject_remote_opts(remote_opts)
        dispatch = get_remote_fn(func, **ropts).remote
        objs = tuple(map(dispatch, *iterables))
        return map(get_ray().get, objs)

    def scatter(self, data):
        """Push ``data`` into the distributed store, returning an ``ObjectRef``
        that can be supplied to ``submit`` calls for example.
        """
        return get_ray().put(data)

    def shutdown(self):
        """Shutdown the parent ray cluster, this ``RayExecutor`` instance
        itself does not need any cleanup.
        """
        get_ray().shutdown()
|
||||
|
||||
|
||||
# Module-level cache of the ``RayExecutor``; assigned and reused by
# ``_get_pool_ray``.
_RAY_EXECUTOR = None
|
||||
|
||||
|
||||
def _get_pool_ray(n_workers=None, maybe_create=False):
|
||||
"""Maybe get an existing or create a new RayExecutor, thus initializing,
|
||||
ray.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n_workers : None or int, optional
|
||||
The number of workers to request if creating a new client.
|
||||
maybe_create : bool, optional
|
||||
Whether to create initialize ray and return a RayExecutor if not
|
||||
initialized already.
|
||||
|
||||
Returns
|
||||
-------
|
||||
None or RayExecutor
|
||||
"""
|
||||
try:
|
||||
import ray
|
||||
except ImportError:
|
||||
if not maybe_create:
|
||||
return None
|
||||
else:
|
||||
raise
|
||||
|
||||
global _RAY_EXECUTOR
|
||||
|
||||
if (_RAY_EXECUTOR is None) or (not ray.is_initialized()):
|
||||
if not maybe_create:
|
||||
return None
|
||||
_RAY_EXECUTOR = RayExecutor(num_cpus=n_workers)
|
||||
|
||||
if n_workers is not None:
|
||||
current_n_workers = get_n_workers(_RAY_EXECUTOR)
|
||||
if n_workers != current_n_workers:
|
||||
warnings.warn(
|
||||
"Found initialized ray (with {} workers which) doesn't match "
|
||||
"the requested {}... sticking with old number.".format(
|
||||
current_n_workers, n_workers
|
||||
)
|
||||
)
|
||||
|
||||
return _RAY_EXECUTOR
|
||||
Reference in New Issue
Block a user