简化代码;加入.venv下内容
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled

This commit is contained in:
jaunatisblue
2026-05-18 02:47:40 +08:00
parent ef3d7e9ee6
commit 28080dff1d
15 changed files with 17145 additions and 118 deletions

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,583 @@
"""Interface for parallelism."""
import atexit
import collections
import functools
import importlib
import inspect
import numbers
import operator
import warnings
_AUTO_BACKEND = None
# check for loky, joblib (vendors loky), then default to concurrent.futures
have_loky = importlib.util.find_spec("loky") is not None
have_joblib = importlib.util.find_spec("joblib") is not None
if have_loky or have_joblib:
_DEFAULT_BACKEND = "loky"
else:
_DEFAULT_BACKEND = "concurrent.futures"
@functools.lru_cache(None)
def choose_default_num_workers():
import os
if "COTENGRA_NUM_WORKERS" in os.environ:
return int(os.environ["COTENGRA_NUM_WORKERS"])
if "OMP_NUM_THREADS" in os.environ:
return int(os.environ["OMP_NUM_THREADS"])
return os.cpu_count()
def get_pool(n_workers=None, maybe_create=False, backend=None):
"""Get a parallel pool."""
if backend is None:
backend = _DEFAULT_BACKEND
if backend == "dask":
return _get_pool_dask(n_workers=n_workers, maybe_create=maybe_create)
if backend == "ray":
return _get_pool_ray(n_workers=n_workers, maybe_create=maybe_create)
# above backends are distributed, don't specify n_workers
if n_workers is None:
n_workers = choose_default_num_workers()
if backend == "loky":
get_reusable_executor = get_loky_get_reusable_executor()
return get_reusable_executor(max_workers=n_workers)
if backend == "concurrent.futures":
return _get_process_pool_cf(n_workers=n_workers)
if backend == "threads":
return _get_thread_pool_cf(n_workers=n_workers)
@functools.lru_cache(None)
def _infer_backed_cached(pool_class):
if pool_class.__name__ == "RayExecutor":
return "ray"
path = pool_class.__module__.split(".")
if path[0] == "concurrent":
return "concurrent.futures"
if path[0] == "joblib":
return "loky"
if path[0] == "distributed":
return "dask"
return path[0]
def _infer_backend(pool):
"""Return the backend type of ``pool`` - cached for speed."""
return _infer_backed_cached(pool.__class__)
def get_n_workers(pool=None):
"""Extract how many workers our pool has (mostly for working out how many
tasks to pre-dispatch).
"""
if pool is None:
pool = get_pool()
try:
return pool._max_workers
except AttributeError:
pass
backend = _infer_backend(pool)
if backend == "dask":
workers = pool.scheduler_info(n_workers=-1)["workers"]
return sum(int(w.get("nthreads", 1) or 1) for w in workers.values())
if backend == "ray":
while True:
try:
return int(get_ray().available_resources()["CPU"])
except KeyError:
import time
time.sleep(1e-3)
if backend == "mpi4py":
from mpi4py import MPI
return MPI.COMM_WORLD.size
raise ValueError(f"Can't find number of workers in pool {pool}.")
def parse_parallel_arg(parallel):
""" """
global _AUTO_BACKEND
if parallel == "auto":
return get_pool(maybe_create=False, backend=_AUTO_BACKEND)
if parallel is False:
return None
if parallel is True:
if _AUTO_BACKEND is None:
_AUTO_BACKEND = _DEFAULT_BACKEND
parallel = _AUTO_BACKEND
if isinstance(parallel, numbers.Integral):
_AUTO_BACKEND = _DEFAULT_BACKEND
return get_pool(
n_workers=parallel, maybe_create=True, backend=_DEFAULT_BACKEND
)
if parallel == "loky":
return get_pool(maybe_create=True, backend="loky")
if parallel == "concurrent.futures":
return get_pool(maybe_create=True, backend="concurrent.futures")
if parallel == "threads":
return get_pool(maybe_create=True, backend="threads")
if parallel == "dask":
_AUTO_BACKEND = "dask"
return get_pool(maybe_create=True, backend="dask")
if parallel == "ray":
_AUTO_BACKEND = "ray"
return get_pool(maybe_create=True, backend="ray")
return parallel
def set_parallel_backend(backend):
"""Create a parallel pool of type ``backend`` which registers it as the
default for ``'auto'`` parallel.
"""
return parse_parallel_arg(backend)
def maybe_leave_pool(pool):
"""Logic required for nested parallelism in dask.distributed."""
if _infer_backend(pool) == "dask":
return _maybe_leave_pool_dask()
def maybe_rejoin_pool(is_worker, pool):
"""Logic required for nested parallelism in dask.distributed."""
if is_worker and _infer_backend(pool) == "dask":
_rejoin_pool_dask()
def submit(pool, fn, *args, **kwargs):
"""Interface for submitting ``fn(*args, **kwargs)`` to ``pool``."""
if _infer_backend(pool) == "dask":
kwargs.setdefault("pure", False)
return pool.submit(fn, *args, **kwargs)
def scatter(pool, data):
"""Interface for maybe turning ``data`` into a remote object or reference."""
if _infer_backend(pool) in ("dask", "ray"):
return pool.scatter(data)
return data
def can_scatter(pool):
"""Whether ``pool`` can make objects remote."""
return _infer_backend(pool) in ("dask", "ray")
def should_nest(pool):
"""Given argument ``pool`` should we try nested parallelism."""
if pool is None:
return False
backend = _infer_backend(pool)
if backend in ("ray", "dask"):
return backend
return False
# ---------------------------------- loky ----------------------------------- #
@functools.lru_cache(1)
def get_loky_get_reusable_executor():
try:
from loky import get_reusable_executor
except ImportError:
from joblib.externals.loky import get_reusable_executor
return get_reusable_executor
# --------------------------- concurrent.futures ---------------------------- #
class CachedProcessPoolExecutor:
def __init__(self):
self._pool = None
self._n_workers = -1
atexit.register(self.shutdown)
def __call__(self, n_workers=None):
if n_workers != self._n_workers:
from concurrent.futures import ProcessPoolExecutor
self.shutdown()
self._pool = ProcessPoolExecutor(n_workers)
self._n_workers = n_workers
return self._pool
def is_initialized(self):
return self._pool is not None
def shutdown(self):
if self._pool is not None:
self._pool.shutdown()
self._pool = None
def __del__(self):
self.shutdown()
ProcessPoolHandler = CachedProcessPoolExecutor()
def _get_process_pool_cf(n_workers=None):
return ProcessPoolHandler(n_workers)
class CachedThreadPoolExecutor:
def __init__(self):
self._pool = None
self._n_workers = -1
atexit.register(self.shutdown)
def __call__(self, n_workers=None):
if n_workers != self._n_workers:
from concurrent.futures import ThreadPoolExecutor
self.shutdown()
self._pool = ThreadPoolExecutor(n_workers)
self._n_workers = n_workers
return self._pool
def is_initialized(self):
return self._pool is not None
def shutdown(self):
if self._pool is not None:
self._pool.shutdown()
self._pool = None
def __del__(self):
self.shutdown()
ThreadPoolHandler = CachedThreadPoolExecutor()
def _get_thread_pool_cf(n_workers=None):
return ThreadPoolHandler(n_workers)
# ---------------------------------- DASK ----------------------------------- #
def _get_pool_dask(n_workers=None, maybe_create=False):
"""Maybe get an existing or create a new dask.distrbuted client.
Parameters
----------
n_workers : None or int, optional
The number of workers to request if creating a new client.
maybe_create : bool, optional
Whether to create an new local cluster and client if no existing client
is found.
Returns
-------
None or dask.distributed.Client
"""
try:
from dask.distributed import get_client
except ImportError:
if not maybe_create:
return None
else:
raise
try:
client = get_client()
except ValueError:
if not maybe_create:
return None
import shutil
import tempfile
from dask.distributed import Client, LocalCluster
local_directory = tempfile.mkdtemp()
lc = LocalCluster(
n_workers=n_workers,
threads_per_worker=1,
local_directory=local_directory,
memory_limit=0,
)
client = Client(lc)
warnings.warn(
"Parallel specified but no existing global dask client found... "
"created one (with {} workers).".format(get_n_workers(client))
)
@atexit.register
def delete_local_dask_directory():
shutil.rmtree(local_directory, ignore_errors=True)
if n_workers is not None:
current_n_workers = get_n_workers(client)
if n_workers != current_n_workers:
warnings.warn(
"Found existing client (with {} workers which) doesn't match "
"the requested {}... using it instead.".format(
current_n_workers, n_workers
)
)
return client
def _maybe_leave_pool_dask():
try:
from dask.distributed import secede
secede() # for nested parallelism
is_dask_worker = True
except (ImportError, ValueError):
is_dask_worker = False
return is_dask_worker
def _rejoin_pool_dask():
from dask.distributed import rejoin
rejoin()
# ----------------------------------- RAY ----------------------------------- #
@functools.lru_cache(None)
def get_ray():
""" """
import ray
return ray
class RayFuture:
"""Basic ``concurrent.futures`` like future wrapping a ray ``ObjectRef``."""
__slots__ = ("_obj", "_cancelled")
def __init__(self, obj):
self._obj = obj
self._cancelled = False
def result(self, timeout=None):
return get_ray().get(self._obj, timeout=timeout)
def done(self):
return self._cancelled or bool(
get_ray().wait([self._obj], timeout=0)[0]
)
def cancel(self):
get_ray().cancel(self._obj)
self._cancelled = True
def _unpack_futures_tuple(x):
return tuple(map(_unpack_futures, x))
def _unpack_futures_list(x):
return list(map(_unpack_futures, x))
def _unpack_futures_dict(x):
return {k: _unpack_futures(v) for k, v in x.items()}
def _unpack_futures_identity(x):
return x
_unpack_dispatch = collections.defaultdict(
lambda: _unpack_futures_identity,
{
RayFuture: operator.attrgetter("_obj"),
tuple: _unpack_futures_tuple,
list: _unpack_futures_list,
dict: _unpack_futures_dict,
},
)
def _unpack_futures(x):
"""Allows passing futures by reference - takes e.g. args and kwargs and
replaces all ``RayFuture`` objects with their underyling ``ObjectRef``
within all nested tuples, lists and dicts.
[Subclassing ``ObjectRef`` might avoid needing this.]
"""
return _unpack_dispatch[x.__class__](x)
@functools.lru_cache(2**14)
def get_remote_fn(fn, **remote_opts):
"""Cached retrieval of remote function."""
ray = get_ray()
if remote_opts:
return ray.remote(**remote_opts)(fn)
return ray.remote(fn)
@functools.lru_cache(2**14)
def get_fn_as_remote_object(fn):
ray = get_ray()
return ray.put(fn)
@functools.lru_cache(None)
def get_deploy(**remote_opts):
"""Alternative for 'non-function' callables - e.g. partial
functions - pass the callable object too.
"""
ray = get_ray()
def deploy(fn, *args, **kwargs):
return fn(*args, **kwargs)
if remote_opts:
return ray.remote(**remote_opts)(deploy)
return ray.remote(deploy)
class RayExecutor:
"""Basic ``concurrent.futures`` like interface using ``ray``."""
def __init__(self, *args, default_remote_opts=None, **kwargs):
ray = get_ray()
if not ray.is_initialized():
ray.init(*args, **kwargs)
self.default_remote_opts = (
{} if default_remote_opts is None else dict(default_remote_opts)
)
def _maybe_inject_remote_opts(self, remote_opts=None):
"""Return the default remote options, possibly overriding some with
those supplied by a ``submit call``.
"""
ropts = self.default_remote_opts
if remote_opts is not None:
ropts = {**ropts, **remote_opts}
return ropts
def submit(self, fn, *args, pure=False, remote_opts=None, **kwargs):
"""Remotely run ``fn(*args, **kwargs)``, returning a ``RayFuture``."""
# want to pass futures by reference
args = _unpack_futures_tuple(args)
kwargs = _unpack_futures_dict(kwargs)
ropts = self._maybe_inject_remote_opts(remote_opts)
# this is the same test ray uses to accept functions
if inspect.isfunction(fn):
# can use the faster cached remote function
obj = get_remote_fn(fn, **ropts).remote(*args, **kwargs)
else:
fn_obj = get_fn_as_remote_object(fn)
obj = get_deploy(**ropts).remote(fn_obj, *args, **kwargs)
return RayFuture(obj)
def map(self, func, *iterables, remote_opts=None):
"""Remote map ``func`` over arguments ``iterables``."""
ropts = self._maybe_inject_remote_opts(remote_opts)
remote_fn = get_remote_fn(func, **ropts)
objs = tuple(map(remote_fn.remote, *iterables))
ray = get_ray()
return map(ray.get, objs)
def scatter(self, data):
"""Push ``data`` into the distributed store, returning an ``ObjectRef``
that can be supplied to ``submit`` calls for example.
"""
ray = get_ray()
return ray.put(data)
def shutdown(self):
"""Shutdown the parent ray cluster, this ``RayExecutor`` instance
itself does not need any cleanup.
"""
get_ray().shutdown()
_RAY_EXECUTOR = None
def _get_pool_ray(n_workers=None, maybe_create=False):
"""Maybe get an existing or create a new RayExecutor, thus initializing,
ray.
Parameters
----------
n_workers : None or int, optional
The number of workers to request if creating a new client.
maybe_create : bool, optional
Whether to create initialize ray and return a RayExecutor if not
initialized already.
Returns
-------
None or RayExecutor
"""
try:
import ray
except ImportError:
if not maybe_create:
return None
else:
raise
global _RAY_EXECUTOR
if (_RAY_EXECUTOR is None) or (not ray.is_initialized()):
if not maybe_create:
return None
_RAY_EXECUTOR = RayExecutor(num_cpus=n_workers)
if n_workers is not None:
current_n_workers = get_n_workers(_RAY_EXECUTOR)
if n_workers != current_n_workers:
warnings.warn(
"Found initialized ray (with {} workers which) doesn't match "
"the requested {}... sticking with old number.".format(
current_n_workers, n_workers
)
)
return _RAY_EXECUTOR