diff --git a/.gitignore b/.gitignore index 2ae054d..781de6e 100644 --- a/.gitignore +++ b/.gitignore @@ -129,4 +129,6 @@ Data/utils.py Experiment/checkpoint Experiment/log -*.ckpt \ No newline at end of file +*.ckpt + +*.0 \ No newline at end of file diff --git a/configs/inference/world_model_interaction.yaml b/configs/inference/world_model_interaction.yaml index da709e0..a1e115a 100644 --- a/configs/inference/world_model_interaction.yaml +++ b/configs/inference/world_model_interaction.yaml @@ -222,7 +222,7 @@ data: test: target: unifolm_wma.data.wma_data.WMAData params: - data_dir: '/mnt/ASC1637/unifolm-world-model-action/examples/world_model_interaction_prompts' + data_dir: '/home/qhy/unifolm-world-model-action/examples/world_model_interaction_prompts' video_length: ${model.params.wma_config.params.temporal_length} frame_stride: 2 load_raw_resolution: True diff --git a/scripts/evaluation/base_model_inference.py b/scripts/evaluation/base_model_inference.py index 42945a7..4ef619a 100644 --- a/scripts/evaluation/base_model_inference.py +++ b/scripts/evaluation/base_model_inference.py @@ -16,6 +16,9 @@ from collections import OrderedDict from unifolm_wma.models.samplers.ddim import DDIMSampler from unifolm_wma.utils.utils import instantiate_from_config +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True + def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]: """ diff --git a/scripts/evaluation/real_eval_server.py b/scripts/evaluation/real_eval_server.py index d780b5d..4ae5a09 100644 --- a/scripts/evaluation/real_eval_server.py +++ b/scripts/evaluation/real_eval_server.py @@ -19,6 +19,9 @@ from fastapi.responses import JSONResponse from typing import Any, Dict, Optional, Tuple, List from datetime import datetime +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True + from unifolm_wma.utils.utils import instantiate_from_config from unifolm_wma.models.samplers.ddim import DDIMSampler diff --git a/scripts/evaluation/world_model_interaction.py b/scripts/evaluation/world_model_interaction.py index ccc9747..562d357 100644 --- a/scripts/evaluation/world_model_interaction.py +++ b/scripts/evaluation/world_model_interaction.py @@ -18,6 +18,9 @@ from collections import OrderedDict from torch import nn from eval_utils import populate_queues, log_to_tensorboard from collections import deque + +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True from torch import Tensor from torch.utils.tensorboard import SummaryWriter from PIL import Image diff --git a/scripts/trainer.py b/scripts/trainer.py index 87c6820..2c8e19f 100644 --- a/scripts/trainer.py +++ b/scripts/trainer.py @@ -11,6 +11,9 @@ from unifolm_wma.utils.utils import instantiate_from_config from unifolm_wma.utils.train import get_trainer_callbacks, get_trainer_logger, get_trainer_strategy from unifolm_wma.utils.train import set_logger, init_workspace, load_checkpoints, get_num_parameters +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True + def get_parser(**parser_kwargs): parser = argparse.ArgumentParser(**parser_kwargs)