forked from seidj/LieAD
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodels.py
More file actions
89 lines (75 loc) · 2.44 KB
/
models.py
File metadata and controls
89 lines (75 loc) · 2.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import jax
import os
import sys
from functools import partial
from tqdm import trange
import torch.utils.data as data
import matplotlib.pyplot as plt
import jax.numpy as jnp
from jax.experimental.jet import jet
from jax import random, vmap, jit, jacrev, grad
from jax.lax import scan
import flax
from flax import linen as nn
from flax.training import train_state
import os
import numpy as np
import optax
from pyDOE import lhs
from lie_derivs import make_derivs_func, iterated_brackets
# sample from multivariate uniform distribution on hyper-rectangle
def multivariate_uniform(key, dim, N, bounds):
    """Draw ``N`` uniform samples from the axis-aligned box given by ``bounds``.

    Args:
        key: JAX PRNG key consumed by ``random.uniform``.
        dim: Number of coordinates per sample.
        N: Number of samples to draw.
        bounds: Sequence of intervals; element 0 of each interval is used as
            the lower limit and element 1 as the upper limit of that
            coordinate.

    Returns:
        Array with one sample per row (``N`` rows, ``dim`` columns).
    """
    lower = np.asarray([lo_hi[0] for lo_hi in bounds])
    width = np.asarray([lo_hi[1] - lo_hi[0] for lo_hi in bounds])

    # Unit-cube draw laid out as (dim, N) so the per-coordinate width and
    # offset broadcast along the sample axis.
    unit = random.uniform(key, shape=(dim, N), minval=0, maxval=1)
    stretched = width[:, None] * unit
    shifted = stretched + lower[:, None]

    # Transpose so that each row is one sample.
    return shifted.T
class MLP(nn.Module):
    """Fully connected network: GELU hidden layers followed by a linear head."""

    num_hidden: int  # Width of every hidden Dense layer
    num_layers: int  # How many Dense + GELU stages are stacked
    num_outputs: int  # Width of the final Dense layer

    @nn.compact
    def __call__(self, x):
        """Run ``x`` through the hidden stack and the output layer."""
        activations = x
        # Dense modules are created in call order, one per hidden stage.
        for _stage in range(self.num_layers):
            activations = nn.gelu(nn.Dense(features=self.num_hidden)(activations))
        # The output layer is linear: no activation is applied after it.
        return nn.Dense(features=self.num_outputs)(activations)
class FL_PINN(data.Dataset):
    """Dataset of collocation points paired with precomputed Lie brackets.

    On construction the dataset draws ``size`` Latin-hypercube points, maps
    them onto the box described by ``bounds``, and evaluates
    ``iterated_brackets(f, g, x, order, dim)`` at every point. Indexing the
    dataset ignores the index and returns a fresh random mini-batch.

    Args:
        f, g: Callables forwarded unchanged to ``iterated_brackets``
            (presumably the drift and input vector fields -- confirm against
            ``lie_derivs``).
        dim: Number of coordinates of each collocation point.
        order: Forwarded to ``iterated_brackets``.
        size: Number of collocation points to generate.
        bounds: Sequence of ``(lower, upper)`` intervals, one per coordinate,
            or ``None`` to keep the points in the unit hypercube.
        key: JAX PRNG key used for mini-batch selection.
        batch_size: Number of points returned per ``__getitem__`` call.

    Raises:
        ValueError: If ``batch_size`` is larger than ``size``.
    """

    def __init__(self, f, g, dim, order, size, bounds, key, batch_size=64):
        super().__init__()
        # Batches are sampled without replacement, so a batch larger than the
        # point pool can never be served. Reject it before the bracket
        # precomputation instead of failing on the first batch request.
        if batch_size > size:
            raise ValueError(
                f"batch_size ({batch_size}) cannot exceed size ({size}): "
                "batches are sampled without replacement"
            )
        self.f = f
        self.g = g
        self.dim = dim
        self.bounds = bounds
        self.size = size
        self.key = key
        self.batch_size = batch_size
        self.order = order
        print('Generating collocation points')
        self.generate_inputs()
        print('Generating brackets at points')
        self.generate_lie_brackets()
        print('Brackets done')

    @staticmethod
    def _scale_to_bounds(unit_points, bounds):
        """Affinely map unit-hypercube points (one per row) onto ``bounds``."""
        if bounds is None:
            return unit_points
        lower = jnp.asarray(
            [interval[0] for interval in bounds], dtype=unit_points.dtype
        )
        upper = jnp.asarray(
            [interval[1] for interval in bounds], dtype=unit_points.dtype
        )
        # Per-coordinate offset and width broadcast across the rows.
        return lower + (upper - lower) * unit_points

    def generate_inputs(self):
        """Sample the collocation points and store them in ``self.data``."""
        # ``lhs`` does not take a JAX key; the split is kept only so that the
        # key stream later used for batch selection stays unchanged.
        self.key, _ = random.split(self.key)
        unit_points = jnp.asarray(lhs(self.dim, self.size))
        # ``lhs`` samples the unit hypercube; move the points into the
        # requested domain so that ``bounds`` is actually honoured.
        self.data = self._scale_to_bounds(unit_points, self.bounds)

    def generate_lie_brackets(self):
        """Evaluate the iterated brackets at every point of ``self.data``."""

        def brackets_at(x):
            return iterated_brackets(self.f, self.g, x, self.order, self.dim)

        # vmap maps the per-point function over the leading axis of the data.
        self.brackets = vmap(brackets_at)(self.data)

    def __len__(self):
        """Return the number of collocation points."""
        return self.size

    def __getitem__(self, idx):
        """Return a random ``(points, brackets)`` mini-batch; ``idx`` is unused."""
        self.key, subkey = random.split(self.key)
        return self.__select_batch(subkey)

    # ``self`` is a static argument: the arrays read from it during tracing
    # are captured by the compiled function, which is cached per instance.
    @partial(jit, static_argnums=(0,))
    def __select_batch(self, key):
        """Pick ``batch_size`` distinct rows and their matching brackets."""
        idx = random.choice(key, self.size, (self.batch_size,), replace=False)
        return (self.data[idx], self.brackets[idx])