-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathMatNNGradTypes.hs
More file actions
53 lines (42 loc) · 1.79 KB
/
MatNNGradTypes.hs
File metadata and controls
53 lines (42 loc) · 1.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
-- | Type definitions for a matrix-based neural network trainer with a
-- variational autoencoder ('VAE'): layers, gradients, Adam optimiser
-- state, and plain records describing the data set, the training
-- configuration, and collected training statistics.
--
-- All newtypes derive 'NFData' so values can be fully forced with
-- "Control.DeepSeq" (useful to avoid thunk build-up in training loops).
module MatNNGradTypes
( Layer(..)
, NN(..)
, Grads(..)
, Batch(..)
, VAE(..)
, AdamOptim(..)
, NNAdamOptim(..)
, VAEAdamOptim(..)
, DataInfo(..)
, TrainInfo(..)
, TrainStats(..)
) where
import Numeric.LinearAlgebra
import Data.Typeable
import Control.DeepSeq
newtype Layer = Layer {getLayer :: (Matrix R)} deriving (Show, NFData)
newtype NN = WeightMatList {getNN :: [Layer]} deriving (Show, NFData)
newtype Grads = GradMatList {getGrads :: [Matrix R]} deriving (Show, NFData)
newtype AdamOptim = AdamOptimParams {getAdamOptim :: (Matrix R, Matrix R, Int, Double, Double)} deriving (Show, NFData)
newtype NNAdamOptim = NNAdamOptimParams {getNNAdamOptim :: [AdamOptim]} deriving (Show, NFData)
newtype VAEAdamOptim = VAEAdamOptimParams {getVAEAdamOptim :: (NNAdamOptim, NNAdamOptim, Double)} deriving (Show, NFData)
newtype Batch = Batch {getBatch :: Matrix R} deriving (Show, NFData)
newtype VAE = VAE {getVAE :: (NN, NN)} deriving (Show, NFData)
-- | Static description of the data set being trained on.
data DataInfo = DataInfo
  { n_input :: Int        -- ^ size (dimensionality) of one input sample
  , n_tot_samples :: Int  -- ^ total number of samples available
  , prefix :: String      -- ^ presumably a filename prefix — confirm at use site
  , data_dir :: String    -- ^ directory holding the data files
  } deriving (Show)
-- | Hyperparameters controlling a single training run.
data TrainInfo = TrainInfo
  { batch_size :: Int        -- ^ samples per mini-batch
  , batches_per_epoch :: Int -- ^ mini-batches processed per epoch
  , n_epochs :: Int          -- ^ number of epochs to run
  , lr :: Double             -- ^ learning rate
  , beta_KL_max :: Double    -- ^ maximum weight on the KL term
  , beta_KL_method :: String -- ^ name of the KL-weight schedule; semantics
                             --   defined by the trainer, not this module
  } deriving (Show)
-- | Loss statistics accumulated over the course of training.
data TrainStats = TrainStats
  { beta_KL :: [Double]      -- ^ KL weight in effect at each recorded step
  , losses_kl :: [Double]    -- ^ KL-divergence loss history
  , losses_recon :: [Double] -- ^ reconstruction loss history
  , losses_total :: [Double] -- ^ combined loss history
  } deriving (Show)