-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference_jittor.py
69 lines (47 loc) · 1.94 KB
/
inference_jittor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import argparse
import os
import jittor as jt
from omegaconf import OmegaConf
from tqdm import tqdm
from model_jittor.dataset import InferenceDataset
from model_jittor.ldm.ddim import DDIMSampler
from model_jittor.ldm.ddpm import LatentDiffusion
from utils import to_pil_image
def init_and_run():
    """Parse command-line options, prepare the output folder and start inference.

    Recognised options: ``-s/--save`` (base output directory), ``-n/--name``
    (sub-directory appended to the base), ``-b/--batch_size`` and ``--seed``.
    The resolved output directory is written back to ``args.save`` and created
    if missing, the Jittor global seed is set, and ``main`` is invoked with the
    parsed options plus the configuration read from
    ``./configs/inference.yaml``.
    """
    option_specs = (
        (('-s', '--save'), {'type': str, 'default': './results/'}),
        (('-n', '--name'), {'type': str, 'default': 'res67'}),
        (('-b', '--batch_size'), {'type': int, 'default': 4}),
        (('--seed',), {'type': int, 'default': 42}),
    )

    cli = argparse.ArgumentParser()
    for flags, spec in option_specs:
        cli.add_argument(*flags, **spec)
    options = cli.parse_args()

    # Results go to <save>/<name>; downstream code reads the final path
    # from options.save.
    output_dir = os.path.join(options.save, options.name)
    options.save = output_dir
    os.makedirs(output_dir, exist_ok=True)

    # Seed first, then load the config and hand over to main().
    jt.set_global_seed(options.seed)
    main(options, OmegaConf.load('./configs/inference.yaml'))
def main(args, cfg):
    """Generate one image per segmentation map and save each as a JPEG.

    Builds a ``LatentDiffusion`` model and a ``DDIMSampler`` around it, then
    iterates the inference dataset batch by batch: the segmentation batch is
    encoded by ``model.cond_stage_model``, the sampler runs 200 steps with that
    condition, the result is decoded with ``model.decode_first_stage`` and each
    image of the batch is written to ``<args.save>/<name>.jpg``.

    Args:
        args: Parsed options; ``args.batch_size`` and ``args.save`` are read.
        cfg: Loaded configuration; ``cfg.model`` is unpacked as the keyword
            arguments of ``LatentDiffusion``.
    """
    # create model
    model = LatentDiffusion(**cfg.model)

    # create ddim sampler
    sampler = DDIMSampler(model, use_ema=False)

    # load dataset
    dataset = InferenceDataset(segmentation_root='/nas/landscape/test_B/labels')
    data_loader = dataset.set_attrs(batch_size=args.batch_size)

    # The batch index is not needed, so the loader is iterated directly; this
    # also avoids the inner loop re-binding the outer loop's index variable.
    for segs, names in tqdm(data_loader):
        # original author's shape note: b 29 384 512 -> b 3 96, 128
        condition = model.cond_stage_model(segs)
        samples_ddim, _ = sampler.sample(
            num_steps=200,
            condition=condition,
            verbose=False,
        )

        # original author's shape note: b 3 96, 128 -> b 3 384 512
        # NOTE(review): a quantize step before decoding
        # (model.first_stage_model.quantize(samples_ddim)) was left disabled
        # by the author.
        x_samples_ddim = model.decode_first_stage(samples_ddim)

        # Shift/scale the decoder output; presumably maps a nominal [-1, 1]
        # range onto [0, 1] -- TODO confirm.
        predicted_image = (x_samples_ddim + 1.0) / 2.0
        # NOTE(review): the author disabled
        # jt.clamp((x_samples_ddim+1.0)/2.0, min_v=0.0, max_v=1.0) here;
        # confirm that to_pil_image copes with out-of-range values.

        for idx, name in enumerate(names):
            to_pil_image(predicted_image[idx]).save(f'{args.save}/{name}.jpg')
if __name__ == '__main__':
    # Switch on Jittor's CUDA flag before any model is constructed, then
    # run the command-line entry point.
    jt.flags.use_cuda = True
    init_and_run()