# Files: unifolm-world-model-action/psnr_score_for_challenge.py
# Snapshot: 2026-02-08 03:29:15 +00:00 (90 lines, 2.8 KiB, Python)

import os
import glob
import numpy as np
import json
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
from tqdm import tqdm
from moviepy.video.io.VideoFileClip import VideoFileClip
import PIL.Image
def calculate_psnr(img1, img2, max_pixel=255.0):
    """Return the peak signal-to-noise ratio (in dB) between two images.

    Args:
        img1: Reference image as an array-like of pixel values.
        img2: Image to compare; must have exactly the same shape as ``img1``.
        max_pixel: Largest possible pixel value. Defaults to 255.0 (8-bit
            images), which was previously hard-coded.

    Returns:
        PSNR in decibels, or ``float('inf')`` when the images are identical
        (zero mean-squared error).

    Raises:
        ValueError: If the two images differ in shape. Without this check
            NumPy broadcasting would silently compare mismatched arrays and
            produce a meaningless score.
    """
    # Promote to float64 before subtracting so uint8 inputs cannot wrap
    # around (e.g. 0 - 255 would otherwise become 1).
    a = np.asarray(img1, dtype=np.float64)
    b = np.asarray(img2, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"Image shapes differ: {a.shape} vs {b.shape}")
    mse = np.mean((a - b) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * np.log10(max_pixel / np.sqrt(mse))


def process_video_psnr(gt_path, pred_path, *, size=(256, 256)):
    """Return the mean per-frame PSNR (dB) between two videos.

    Both clips are sampled at the same timestamps. The step is ``1 / fps``,
    where ``fps`` is the lower of the two clips' frame rates. Sampling stops
    at the shorter clip's duration. Each sampled frame is resized to ``size``
    with bilinear resampling before ``calculate_psnr`` is applied, so the two
    videos need not share a resolution.

    Args:
        gt_path: Path to the ground-truth (reference) video.
        pred_path: Path to the predicted video.
        size: ``(width, height)`` every frame is resized to before
            comparison. Defaults to ``(256, 256)``, which was previously
            hard-coded.

    Returns:
        The mean PSNR over the sampled frames.
        ``0.0`` if no frames were sampled.
        ``None`` if any error occurred; the error is printed.
        NOTE(review): if any sampled frame pair is identical, that frame's
        PSNR is ``inf`` and so is the mean.
    """
    try:
        # Context managers close both readers even when opening the second
        # clip or decoding a frame fails. Previously an exception skipped
        # the close() calls and leaked the open clips.
        with VideoFileClip(gt_path) as clip_gt, VideoFileClip(pred_path) as clip_pred:
            fps = min(clip_gt.fps, clip_pred.fps)
            duration = min(clip_gt.duration, clip_pred.duration)
            time_points = np.arange(0, duration, 1.0 / fps)
            video_psnrs = [
                calculate_psnr(
                    _resized_frame(clip_gt, t, size),
                    _resized_frame(clip_pred, t, size),
                )
                for t in time_points
            ]
        return np.mean(video_psnrs) if video_psnrs else 0.0
    except Exception as e:
        print(f"Error processing {os.path.basename(gt_path)}: {e}")
        return None


def _resized_frame(clip, t, size):
    """Return the frame of ``clip`` at time ``t``, bilinearly resized to ``size``, as an array."""
    image = PIL.Image.fromarray(clip.get_frame(t))
    return np.array(image.resize(size, PIL.Image.Resampling.BILINEAR))


def main():
    """Command-line entry point: compute and report the PSNR of two videos.

    Parses ``--gt_video``, ``--pred_video`` and the optional
    ``--output_file``. It prints the mean per-frame PSNR returned by
    ``process_video_psnr``. When ``--output_file`` is given, it also writes
    ``gt_video``, ``pred_video`` and ``psnr`` to that file as indented JSON.

    Exits with status 1 (``SystemExit``) when an input file is missing or
    the PSNR could not be computed. This lets callers such as evaluation
    pipelines detect the failure; previously the script exited with 0.
    """
    import sys  # local import: needed only for stderr and the exit status

    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument('--gt_video', type=str, required=True, help='path to the reference video')
    parser.add_argument('--pred_video', type=str, required=True, help='path to the predicted video')
    parser.add_argument('--output_file', type=str, default=None, help='optional path for a JSON result file')
    args = parser.parse_args()

    # isfile (not exists) so that a directory is rejected up front instead of
    # failing later inside the video decoder.
    for label, path in (("GT", args.gt_video), ("Pred", args.pred_video)):
        if not os.path.isfile(path):
            print(f"Error: {label} video not found at {path}", file=sys.stderr)
            sys.exit(1)

    print(f"Comparing:\nRef: {args.gt_video}\nPred: {args.pred_video}")
    v_psnr = process_video_psnr(args.gt_video, args.pred_video)
    if v_psnr is None:
        print("Failed to calculate PSNR.", file=sys.stderr)
        sys.exit(1)

    print("-" * 30)
    print(f"Video PSNR: {v_psnr:.4f} dB")
    print("-" * 30)

    if args.output_file:
        result = {
            "gt_video": args.gt_video,
            "pred_video": args.pred_video,
            # Plain Python float rather than a NumPy scalar.
            # NOTE(review): an infinite PSNR (identical videos) is written as
            # the non-standard JSON token `Infinity`.
            "psnr": float(v_psnr),
        }
        with open(args.output_file, 'w', encoding='utf-8') as f:
            json.dump(result, f, indent=4)
        print(f"Result saved to {args.output_file}")


# Run the command-line entry point only when executed as a script,
# not when this module is imported (e.g. to reuse calculate_psnr).
if __name__ == "__main__":
    main()