forked from GBATZOLIS/conditional_score_diffusion
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataset_test.py
More file actions
29 lines (22 loc) · 736 Bytes
/
dataset_test.py
File metadata and controls
29 lines (22 loc) · 736 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from absl import app
from absl import flags
from ml_collections.config_flags import config_flags
from SyntheticDataset import SyntheticDataModule
from SyntheticDataset import scatter_plot
from model_lightning import SdeGenerativeModel
# Global absl flag registry; values become available once app.run() has
# parsed the command line.
FLAGS = flags.FLAGS

# Register --config, which points at an ml_collections config file. The
# parsed config is locked (lock_config=True).
config_flags.DEFINE_config_file(
    "config",
    default=None,
    help_string="Training configuration.",
    lock_config=True,
)

# There is no default config, so the flag must be supplied explicitly.
flags.mark_flags_as_required(["config"])
def main(argv):
    """Plot one training batch and one batch of model samples.

    Reads the configuration from ``FLAGS.config``, builds the synthetic data
    module, pulls the first batch from its training dataloader and passes it
    to ``scatter_plot`` with ``save=True``. It then constructs an
    ``SdeGenerativeModel`` from the same configuration, calls ``sample()``
    on it and plots the result the same way.

    Args:
        argv: Positional command-line arguments forwarded by ``app.run``;
            not used.
    """
    del argv  # Unused.
    cfg = FLAGS.config

    # Data side: set up the data module and visualise the first batch.
    datamodule = SyntheticDataModule(cfg)
    datamodule.setup()
    first_batch = next(iter(datamodule.train_dataloader()))
    scatter_plot(first_batch, save=True)

    # Model side: draw samples from a freshly constructed model.
    # NOTE(review): no checkpoint is loaded here, so these presumably come
    # from an untrained model -- confirm that this is intended.
    samples = SdeGenerativeModel(cfg).sample()
    scatter_plot(samples, save=True)
# Script entry point: app.run parses the absl flags (including the required
# --config) and then invokes main.
if __name__ == "__main__":
    app.run(main)