-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsample.py
More file actions
45 lines (36 loc) · 1.08 KB
/
sample.py
File metadata and controls
45 lines (36 loc) · 1.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import torch
from model.SimpleDiffusion import SimpleDiffusion
from dataset import save_image, label2num
import settings
import os
def main():
    """Sample a batch of class-conditional images from a trained checkpoint.

    Loads the ``SimpleDiffusion`` checkpoint of the ``MASKINGUNetNoAtt`` run,
    draws one sample per entry in ``prompt`` and writes each one to
    ``<settings.save_dir>/samples/<idx>.png``.
    """
    # One class name per sample: 64 samples, all conditioned on "frog".
    prompt = ["frog"] * 64
    checkpoint_path = os.path.join(
        settings.save_dir,
        "MASKINGUNetNoAtt",
        "lightning_logs",
        "version_0",
        "checkpoints",
        "epoch=99-step=78200.ckpt",
    )

    # NOTE(review): torch.Tensor(...) produces a floating-point tensor even
    # though label2num presumably returns integer class ids; confirm that
    # model.sample accepts (or casts) float labels.
    labels = torch.Tensor([label2num(name) for name in prompt]).to(settings.device)

    model = (
        SimpleDiffusion.load_from_checkpoint(
            checkpoint_path,
            timesteps=1000,
            class_rate=0.9,
            MASKING=True,
            ATTENTION=False,
        )
        .to(settings.device)
        .eval()
    )

    save_dir = os.path.join(settings.save_dir, "samples")
    # save_image is defined elsewhere and may not create missing parent
    # directories, so make sure the output folder exists up front.
    os.makedirs(save_dir, exist_ok=True)

    # .eval() does not disable autograd; without no_grad the sampler may build
    # a graph across every denoising step (unless sample() already disables it).
    with torch.no_grad():
        # NOTE(review): 32 and 3 look like image size and channel count;
        # confirm against SimpleDiffusion.sample's signature.
        results = model.sample(labels, 32, 3)
    results = results.detach().cpu().numpy()

    # Clamp to [0, 1] before saving. If the model's output range is [-1, 1]
    # instead, rescale with (results.clip(-1, 1) + 1) / 2.
    results = results.clip(0, 1)

    for idx, image in enumerate(results):
        # Each sample is moved from axis order (0, 1, 2) to (1, 2, 0);
        # presumably (C, H, W) -> (H, W, C) for the image writer. TODO confirm.
        save_image(os.path.join(save_dir, f"{idx}.png"), image.transpose(1, 2, 0))
# Entry point: generate and save samples only when run as a script, never on import.
if __name__ == "__main__":
    main()