跳过模型加载
This commit is contained in:
@@ -462,24 +462,43 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
|
||||
df = pd.read_csv(csv_path)
|
||||
|
||||
# Load config
|
||||
# Load config (always needed for data setup)
|
||||
config = OmegaConf.load(args.config)
|
||||
config['model']['params']['wma_config']['params'][
|
||||
'use_checkpoint'] = False
|
||||
model = instantiate_from_config(config.model)
|
||||
model.perframe_ae = args.perframe_ae
|
||||
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
|
||||
model = load_model_checkpoint(model, args.ckpt_path)
|
||||
model.eval()
|
||||
print(f'>>> Load pre-trained model ...')
|
||||
|
||||
# Build unnomalizer
|
||||
prepared_path = args.ckpt_path + ".prepared.pt"
|
||||
if os.path.exists(prepared_path):
|
||||
# ---- Fast path: load the fully-prepared model ----
|
||||
print(f">>> Loading prepared model from {prepared_path} ...")
|
||||
model = torch.load(prepared_path,
|
||||
map_location=f"cuda:{gpu_no}",
|
||||
weights_only=False,
|
||||
mmap=True)
|
||||
model.eval()
|
||||
print(f">>> Prepared model loaded.")
|
||||
else:
|
||||
# ---- Normal path: construct + load checkpoint ----
|
||||
config['model']['params']['wma_config']['params'][
|
||||
'use_checkpoint'] = False
|
||||
model = instantiate_from_config(config.model)
|
||||
model.perframe_ae = args.perframe_ae
|
||||
|
||||
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
|
||||
model = load_model_checkpoint(model, args.ckpt_path)
|
||||
model.eval()
|
||||
model = model.cuda(gpu_no)
|
||||
print(f'>>> Load pre-trained model ...')
|
||||
|
||||
# Save prepared model for fast loading next time
|
||||
print(f">>> Saving prepared model to {prepared_path} ...")
|
||||
torch.save(model, prepared_path)
|
||||
print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).")
|
||||
|
||||
# Build normalizer (always needed, independent of model loading path)
|
||||
logging.info("***** Configing Data *****")
|
||||
data = instantiate_from_config(config.data)
|
||||
data.setup()
|
||||
print(">>> Dataset is successfully loaded ...")
|
||||
|
||||
model = model.cuda(gpu_no)
|
||||
device = get_device_from_parameters(model)
|
||||
|
||||
# Run over data
|
||||
|
||||
Reference in New Issue
Block a user