tf32推理
This commit is contained in:
4
.gitignore
vendored
4
.gitignore
vendored
@@ -129,4 +129,6 @@ Data/utils.py
|
||||
Experiment/checkpoint
|
||||
Experiment/log
|
||||
|
||||
*.ckpt
|
||||
*.ckpt
|
||||
|
||||
*.0
|
||||
@@ -222,7 +222,7 @@ data:
|
||||
test:
|
||||
target: unifolm_wma.data.wma_data.WMAData
|
||||
params:
|
||||
data_dir: '/mnt/ASC1637/unifolm-world-model-action/examples/world_model_interaction_prompts'
|
||||
data_dir: '/home/qhy/unifolm-world-model-action/examples/world_model_interaction_prompts'
|
||||
video_length: ${model.params.wma_config.params.temporal_length}
|
||||
frame_stride: 2
|
||||
load_raw_resolution: True
|
||||
|
||||
@@ -16,6 +16,9 @@ from collections import OrderedDict
|
||||
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
|
||||
def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
|
||||
"""
|
||||
|
||||
@@ -19,6 +19,9 @@ from fastapi.responses import JSONResponse
|
||||
from typing import Any, Dict, Optional, Tuple, List
|
||||
from datetime import datetime
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
||||
|
||||
|
||||
@@ -18,6 +18,9 @@ from collections import OrderedDict
|
||||
from torch import nn
|
||||
from eval_utils import populate_queues, log_to_tensorboard
|
||||
from collections import deque
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
from torch import Tensor
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from PIL import Image
|
||||
|
||||
@@ -11,6 +11,9 @@ from unifolm_wma.utils.utils import instantiate_from_config
|
||||
from unifolm_wma.utils.train import get_trainer_callbacks, get_trainer_logger, get_trainer_strategy
|
||||
from unifolm_wma.utils.train import set_logger, init_workspace, load_checkpoints, get_num_parameters
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
|
||||
def get_parser(**parser_kwargs):
|
||||
parser = argparse.ArgumentParser(**parser_kwargs)
|
||||
|
||||
Reference in New Issue
Block a user