Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
60 lines
1.8 KiB
Python
60 lines
1.8 KiB
Python
"""Slice an existing saved cotengra tree without re-running path search."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import pickle
|
|
from pathlib import Path
|
|
|
|
from qibotn.parallel import contraction_tree_costs
|
|
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("input", help="Input pickle saved by --tn-save-tree.")
|
|
parser.add_argument("output", help="Output pickle path.")
|
|
parser.add_argument("--term", type=int, default=0)
|
|
parser.add_argument("--target-slices", type=int, default=2)
|
|
parser.add_argument("--max-repeats", type=int, default=64)
|
|
parser.add_argument("--seed", type=int, default=42)
|
|
args = parser.parse_args()
|
|
|
|
input_path = Path(args.input)
|
|
output_path = Path(args.output)
|
|
with input_path.open("rb") as f:
|
|
payload = pickle.load(f)
|
|
|
|
trees = payload["trees"] if isinstance(payload, dict) else payload
|
|
if not isinstance(trees, (list, tuple)):
|
|
trees = [trees]
|
|
tree = trees[args.term]
|
|
|
|
print("original", contraction_tree_costs(tree), flush=True)
|
|
sliced = tree.slice(
|
|
target_slices=args.target_slices,
|
|
max_repeats=args.max_repeats,
|
|
seed=args.seed,
|
|
)
|
|
print("sliced", contraction_tree_costs(sliced), flush=True)
|
|
print(f"sliced_inds={sliced.sliced_inds}", flush=True)
|
|
|
|
new_trees = list(trees)
|
|
new_trees[args.term] = sliced
|
|
|
|
if isinstance(payload, dict):
|
|
out_payload = dict(payload)
|
|
out_payload["trees"] = new_trees
|
|
out_payload["costs"] = [contraction_tree_costs(t) for t in new_trees]
|
|
out_payload["nterms"] = len(new_trees)
|
|
else:
|
|
out_payload = new_trees
|
|
|
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
|
with output_path.open("wb") as f:
|
|
pickle.dump(out_payload, f)
|
|
print(f"saved {output_path}", flush=True)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: main() returns None, so SystemExit(None) exits with status 0.
    raise SystemExit(main())
|