[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-08-26 09:23:25 +00:00
parent 65a04c32fa
commit f2f84e79fd

View File

@@ -220,10 +220,10 @@ def reduce_result(result, comm, method="MPI", root=0):
elif method == "NCCL":
stream_ptr = cp.cuda.get_current_stream().ptr
if result.dtype == cp.complex128:
count = result.size * 2 # complex128 has 2 float64 numbers
count = result.size * 2 # complex128 has 2 float64 numbers
nccl_type = nccl.NCCL_FLOAT64
elif result.dtype == cp.complex64:
count = result.size * 2 # complex64 has 2 float32 numbers
count = result.size * 2 # complex64 has 2 float32 numbers
nccl_type = nccl.NCCL_FLOAT32
else:
raise TypeError(f"Unsupported dtype for NCCL reduce: {result.dtype}")