Skip to content

Commit e5b220b

Browse files
committed
update week 11
1 parent b10b8b0 commit e5b220b

15 files changed

+1052
-336
lines changed
0 Bytes
Binary file not shown.

doc/pub/week11/ipynb/week11.ipynb

Lines changed: 479 additions & 332 deletions
Large diffs are not rendered by default.

doc/pub/week11/pdf/week11.pdf

7.78 KB
Binary file not shown.
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
\begin{MintedVerbatim}[commandchars=\\\{\},codes={\catcode`\$=3\catcode`\^=7\catcode`\_=8\relax}]
2+
\PYG{c+c1}{\PYGZsh{} Set up the conditional probability distribution for each dimension}
3+
\PYG{c+c1}{\PYGZsh{} For example, I can sample p(a | b) using sample\PYGZus{}for\PYGZus{}dim[0].}
4+
5+
\PYG{n}{univariate\PYGZus{}conditionals} \PYG{o}{=} \PYG{p}{[}
6+
\PYG{n}{get\PYGZus{}conditional\PYGZus{}dist}\PYG{p}{(}\PYG{n}{joint\PYGZus{}mu}\PYG{p}{,} \PYG{n}{joint\PYGZus{}cov}\PYG{p}{,} \PYG{n}{d}\PYG{p}{)}
7+
\PYG{k}{for} \PYG{n}{d} \PYG{o+ow}{in} \PYG{n+nb}{range}\PYG{p}{(}\PYG{n}{D}\PYG{p}{)}
8+
\PYG{p}{]}
9+
10+
\end{MintedVerbatim}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
\begin{MintedVerbatim}[commandchars=\\\{\},codes={\catcode`\$=3\catcode`\^=7\catcode`\_=8\relax}]
2+
\PYG{k}{def}\PYG{+w}{ }\PYG{n+nf}{get\PYGZus{}conditional\PYGZus{}dist}\PYG{p}{(}\PYG{n}{joint\PYGZus{}mu}\PYG{p}{,} \PYG{n}{joint\PYGZus{}cov}\PYG{p}{,} \PYG{n}{var\PYGZus{}index}\PYG{p}{)}\PYG{p}{:}
3+
\PYG{+w}{ }\PYG{l+s+sd}{\PYGZsq{}\PYGZsq{}\PYGZsq{}Returns the conditional distribution given the joint distribution and which variable}
4+
\PYG{l+s+sd}{ the conditional probability should use.}
5+
\PYG{l+s+sd}{ Right now this only works for 2\PYGZhy{}variable joint distributions.}
6+
7+
\PYG{l+s+sd}{ joint\PYGZus{}mu: joint distribution\PYGZsq{}s mu}
8+
\PYG{l+s+sd}{ joint\PYGZus{}cov: joint distribution\PYGZsq{}s covariance}
9+
\PYG{l+s+sd}{ var\PYGZus{}index: index of the variable in the joint distribution. Everything else will be}
10+
\PYG{l+s+sd}{ conditioned on. For example, if the joint distribution p(a, b, c) has mu [mu\PYGZus{}a, mu\PYGZus{}b, mu\PYGZus{}c],}
11+
\PYG{l+s+sd}{ to get p(c | a, b), use var\PYGZus{}index = 2.}
12+
13+
\PYG{l+s+sd}{ returns:}
14+
\PYG{l+s+sd}{ a function that can sample from the univariate conditional distribution}
15+
\PYG{l+s+sd}{ \PYGZsq{}\PYGZsq{}\PYGZsq{}}
16+
\PYG{k}{assert} \PYG{n}{joint\PYGZus{}mu}\PYG{o}{.}\PYG{n}{shape}\PYG{p}{[}\PYG{l+m+mi}{0}\PYG{p}{]} \PYG{o}{==} \PYG{l+m+mi}{2}\PYG{p}{,} \PYG{l+s+s1}{\PYGZsq{}}\PYG{l+s+s1}{Sorry, this function only works for 2\PYGZhy{}dimensional joint distributions right now}\PYG{l+s+s1}{\PYGZsq{}}
17+
\PYG{n}{a} \PYG{o}{=} \PYG{n}{joint\PYGZus{}mu}\PYG{p}{[}\PYG{n}{var\PYGZus{}index}\PYG{p}{]}
18+
\PYG{n}{b} \PYG{o}{=} \PYG{n}{joint\PYGZus{}mu}\PYG{p}{[}\PYG{o}{\PYGZti{}}\PYG{n}{var\PYGZus{}index}\PYG{p}{]}
19+
20+
\PYG{n}{A} \PYG{o}{=} \PYG{n}{joint\PYGZus{}cov}\PYG{p}{[}\PYG{n}{var\PYGZus{}index}\PYG{p}{,} \PYG{n}{var\PYGZus{}index}\PYG{p}{]}
21+
\PYG{n}{B} \PYG{o}{=} \PYG{n}{joint\PYGZus{}cov}\PYG{p}{[}\PYG{o}{\PYGZti{}}\PYG{n}{var\PYGZus{}index}\PYG{p}{,} \PYG{o}{\PYGZti{}}\PYG{n}{var\PYGZus{}index}\PYG{p}{]}
22+
\PYG{n}{C} \PYG{o}{=} \PYG{n}{joint\PYGZus{}cov}\PYG{p}{[}\PYG{n}{var\PYGZus{}index}\PYG{p}{,} \PYG{o}{\PYGZti{}}\PYG{n}{var\PYGZus{}index}\PYG{p}{]}
23+
24+
\PYG{c+c1}{\PYGZsh{} we\PYGZsq{}re dealing with one dimension so}
25+
\PYG{n}{B\PYGZus{}inv} \PYG{o}{=} \PYG{l+m+mi}{1}\PYG{o}{/}\PYG{n}{B}
26+
27+
\PYG{c+c1}{\PYGZsh{} Return a function that can sample given a value of g}
28+
\PYG{k}{def}\PYG{+w}{ }\PYG{n+nf}{dist}\PYG{p}{(}\PYG{n}{g}\PYG{p}{)}\PYG{p}{:}
29+
\PYG{c+c1}{\PYGZsh{} a + C*B\PYGZca{}\PYGZob{}\PYGZhy{}1\PYGZcb{}(g \PYGZhy{} b)}
30+
\PYG{n}{mu} \PYG{o}{=} \PYG{n}{a} \PYG{o}{+} \PYG{n}{C} \PYG{o}{*} \PYG{n}{B\PYGZus{}inv} \PYG{o}{*} \PYG{p}{(}\PYG{n}{g} \PYG{o}{\PYGZhy{}} \PYG{n}{b}\PYG{p}{)}
31+
\PYG{c+c1}{\PYGZsh{} A \PYGZhy{} C * B\PYGZca{}\PYGZob{}\PYGZhy{}1\PYGZcb{} * C\PYGZca{}T}
32+
\PYG{n}{cov} \PYG{o}{=} \PYG{n}{A} \PYG{o}{\PYGZhy{}} \PYG{n}{B\PYGZus{}inv} \PYG{o}{*} \PYG{n}{C} \PYG{o}{*} \PYG{n}{C}
33+
\PYG{k}{return} \PYG{n}{np}\PYG{o}{.}\PYG{n}{sqrt}\PYG{p}{(}\PYG{n}{cov}\PYG{p}{)} \PYG{o}{*} \PYG{n}{np}\PYG{o}{.}\PYG{n}{random}\PYG{o}{.}\PYG{n}{randn}\PYG{p}{(}\PYG{l+m+mi}{1}\PYG{p}{)} \PYG{o}{+} \PYG{n}{mu}
34+
35+
\PYG{k}{return} \PYG{n}{dist}
36+
37+
38+
\end{MintedVerbatim}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
\begin{MintedVerbatim}[commandchars=\\\{\},codes={\catcode`\$=3\catcode`\^=7\catcode`\_=8\relax}]
2+
\PYG{n}{N} \PYG{o}{=} \PYG{l+m+mi}{10000}
3+
\PYG{n}{L} \PYG{o}{=} \PYG{n}{np}\PYG{o}{.}\PYG{n}{linalg}\PYG{o}{.}\PYG{n}{cholesky}\PYG{p}{(}\PYG{n}{joint\PYGZus{}cov}\PYG{p}{)}
4+
\PYG{n}{samples\PYGZus{}from\PYGZus{}true\PYGZus{}distribution} \PYG{o}{=} \PYG{n}{L} \PYG{o}{@} \PYG{n}{np}\PYG{o}{.}\PYG{n}{random}\PYG{o}{.}\PYG{n}{randn}\PYG{p}{(}\PYG{n}{D}\PYG{p}{,} \PYG{n}{N}\PYG{p}{)} \PYG{o}{+} \PYG{n}{joint\PYGZus{}mu}
5+
\PYG{n}{plt}\PYG{o}{.}\PYG{n}{plot}\PYG{p}{(}\PYG{o}{*}\PYG{n}{samples\PYGZus{}from\PYGZus{}true\PYGZus{}distribution}\PYG{p}{,} \PYG{l+s+s1}{\PYGZsq{}}\PYG{l+s+s1}{.}\PYG{l+s+s1}{\PYGZsq{}}\PYG{p}{,} \PYG{n}{alpha}\PYG{o}{=}\PYG{l+m+mf}{0.1}\PYG{p}{)}
6+
\PYG{n}{plt}\PYG{o}{.}\PYG{n}{axis}\PYG{p}{(}\PYG{p}{[}\PYG{o}{\PYGZhy{}}\PYG{l+m+mi}{4}\PYG{p}{,} \PYG{l+m+mi}{4}\PYG{p}{,} \PYG{o}{\PYGZhy{}}\PYG{l+m+mi}{4}\PYG{p}{,} \PYG{l+m+mi}{4}\PYG{p}{]}\PYG{p}{)}
7+
\PYG{n}{plt}\PYG{o}{.}\PYG{n}{show}\PYG{p}{(}\PYG{p}{)}
8+
9+
\end{MintedVerbatim}
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
\begin{MintedVerbatim}[commandchars=\\\{\},codes={\catcode`\$=3\catcode`\^=7\catcode`\_=8\relax}]
2+
\PYG{k+kn}{import}\PYG{+w}{ }\PYG{n+nn}{numpy}\PYG{+w}{ }\PYG{k}{as}\PYG{+w}{ }\PYG{n+nn}{np}
3+
\PYG{k+kn}{import}\PYG{+w}{ }\PYG{n+nn}{torch}
4+
\PYG{k+kn}{import}\PYG{+w}{ }\PYG{n+nn}{torch}\PYG{n+nn}{.}\PYG{n+nn}{utils}\PYG{n+nn}{.}\PYG{n+nn}{data}
5+
\PYG{k+kn}{import}\PYG{+w}{ }\PYG{n+nn}{torch}\PYG{n+nn}{.}\PYG{n+nn}{nn}\PYG{+w}{ }\PYG{k}{as}\PYG{+w}{ }\PYG{n+nn}{nn}
6+
\PYG{k+kn}{import}\PYG{+w}{ }\PYG{n+nn}{torch}\PYG{n+nn}{.}\PYG{n+nn}{nn}\PYG{n+nn}{.}\PYG{n+nn}{functional}\PYG{+w}{ }\PYG{k}{as}\PYG{+w}{ }\PYG{n+nn}{F}
7+
\PYG{k+kn}{import}\PYG{+w}{ }\PYG{n+nn}{torch}\PYG{n+nn}{.}\PYG{n+nn}{optim}\PYG{+w}{ }\PYG{k}{as}\PYG{+w}{ }\PYG{n+nn}{optim}
8+
\PYG{k+kn}{from}\PYG{+w}{ }\PYG{n+nn}{torch}\PYG{n+nn}{.}\PYG{n+nn}{autograd}\PYG{+w}{ }\PYG{k+kn}{import} \PYG{n}{Variable}
9+
\PYG{k+kn}{from}\PYG{+w}{ }\PYG{n+nn}{torchvision}\PYG{+w}{ }\PYG{k+kn}{import} \PYG{n}{datasets}\PYG{p}{,} \PYG{n}{transforms}
10+
\PYG{k+kn}{from}\PYG{+w}{ }\PYG{n+nn}{torchvision}\PYG{n+nn}{.}\PYG{n+nn}{utils}\PYG{+w}{ }\PYG{k+kn}{import} \PYG{n}{make\PYGZus{}grid} \PYG{p}{,} \PYG{n}{save\PYGZus{}image}
11+
\PYG{k+kn}{import}\PYG{+w}{ }\PYG{n+nn}{matplotlib}\PYG{n+nn}{.}\PYG{n+nn}{pyplot}\PYG{+w}{ }\PYG{k}{as}\PYG{+w}{ }\PYG{n+nn}{plt}
12+
13+
14+
\PYG{n}{batch\PYGZus{}size} \PYG{o}{=} \PYG{l+m+mi}{64}
15+
\PYG{n}{train\PYGZus{}loader} \PYG{o}{=} \PYG{n}{torch}\PYG{o}{.}\PYG{n}{utils}\PYG{o}{.}\PYG{n}{data}\PYG{o}{.}\PYG{n}{DataLoader}\PYG{p}{(}
16+
\PYG{n}{datasets}\PYG{o}{.}\PYG{n}{MNIST}\PYG{p}{(}\PYG{l+s+s1}{\PYGZsq{}}\PYG{l+s+s1}{./data}\PYG{l+s+s1}{\PYGZsq{}}\PYG{p}{,}
17+
\PYG{n}{train}\PYG{o}{=}\PYG{k+kc}{True}\PYG{p}{,}
18+
\PYG{n}{download} \PYG{o}{=} \PYG{k+kc}{True}\PYG{p}{,}
19+
\PYG{n}{transform} \PYG{o}{=} \PYG{n}{transforms}\PYG{o}{.}\PYG{n}{Compose}\PYG{p}{(}
20+
\PYG{p}{[}\PYG{n}{transforms}\PYG{o}{.}\PYG{n}{ToTensor}\PYG{p}{(}\PYG{p}{)}\PYG{p}{]}\PYG{p}{)}
21+
\PYG{p}{)}\PYG{p}{,}
22+
\PYG{n}{batch\PYGZus{}size}\PYG{o}{=}\PYG{n}{batch\PYGZus{}size}
23+
\PYG{p}{)}
24+
25+
\PYG{n}{test\PYGZus{}loader} \PYG{o}{=} \PYG{n}{torch}\PYG{o}{.}\PYG{n}{utils}\PYG{o}{.}\PYG{n}{data}\PYG{o}{.}\PYG{n}{DataLoader}\PYG{p}{(}
26+
\PYG{n}{datasets}\PYG{o}{.}\PYG{n}{MNIST}\PYG{p}{(}\PYG{l+s+s1}{\PYGZsq{}}\PYG{l+s+s1}{./data}\PYG{l+s+s1}{\PYGZsq{}}\PYG{p}{,}
27+
\PYG{n}{train}\PYG{o}{=}\PYG{k+kc}{False}\PYG{p}{,}
28+
\PYG{n}{transform}\PYG{o}{=}\PYG{n}{transforms}\PYG{o}{.}\PYG{n}{Compose}\PYG{p}{(}
29+
\PYG{p}{[}\PYG{n}{transforms}\PYG{o}{.}\PYG{n}{ToTensor}\PYG{p}{(}\PYG{p}{)}\PYG{p}{]}\PYG{p}{)}
30+
\PYG{p}{)}\PYG{p}{,}
31+
\PYG{n}{batch\PYGZus{}size}\PYG{o}{=}\PYG{n}{batch\PYGZus{}size}\PYG{p}{)}
32+
33+
34+
\PYG{k}{class}\PYG{+w}{ }\PYG{n+nc}{RBM}\PYG{p}{(}\PYG{n}{nn}\PYG{o}{.}\PYG{n}{Module}\PYG{p}{)}\PYG{p}{:}
35+
\PYG{k}{def}\PYG{+w}{ }\PYG{n+nf+fm}{\PYGZus{}\PYGZus{}init\PYGZus{}\PYGZus{}}\PYG{p}{(}\PYG{n+nb+bp}{self}\PYG{p}{,}
36+
\PYG{n}{n\PYGZus{}vis}\PYG{o}{=}\PYG{l+m+mi}{784}\PYG{p}{,}
37+
\PYG{n}{n\PYGZus{}hin}\PYG{o}{=}\PYG{l+m+mi}{500}\PYG{p}{,}
38+
\PYG{n}{k}\PYG{o}{=}\PYG{l+m+mi}{5}\PYG{p}{)}\PYG{p}{:}
39+
\PYG{n+nb}{super}\PYG{p}{(}\PYG{n}{RBM}\PYG{p}{,} \PYG{n+nb+bp}{self}\PYG{p}{)}\PYG{o}{.}\PYG{n+nf+fm}{\PYGZus{}\PYGZus{}init\PYGZus{}\PYGZus{}}\PYG{p}{(}\PYG{p}{)}
40+
\PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{W} \PYG{o}{=} \PYG{n}{nn}\PYG{o}{.}\PYG{n}{Parameter}\PYG{p}{(}\PYG{n}{torch}\PYG{o}{.}\PYG{n}{randn}\PYG{p}{(}\PYG{n}{n\PYGZus{}hin}\PYG{p}{,}\PYG{n}{n\PYGZus{}vis}\PYG{p}{)}\PYG{o}{*}\PYG{l+m+mf}{1e\PYGZhy{}2}\PYG{p}{)}
41+
\PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{v\PYGZus{}bias} \PYG{o}{=} \PYG{n}{nn}\PYG{o}{.}\PYG{n}{Parameter}\PYG{p}{(}\PYG{n}{torch}\PYG{o}{.}\PYG{n}{zeros}\PYG{p}{(}\PYG{n}{n\PYGZus{}vis}\PYG{p}{)}\PYG{p}{)}
42+
\PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{h\PYGZus{}bias} \PYG{o}{=} \PYG{n}{nn}\PYG{o}{.}\PYG{n}{Parameter}\PYG{p}{(}\PYG{n}{torch}\PYG{o}{.}\PYG{n}{zeros}\PYG{p}{(}\PYG{n}{n\PYGZus{}hin}\PYG{p}{)}\PYG{p}{)}
43+
\PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{k} \PYG{o}{=} \PYG{n}{k}
44+
45+
\PYG{k}{def}\PYG{+w}{ }\PYG{n+nf}{sample\PYGZus{}from\PYGZus{}p}\PYG{p}{(}\PYG{n+nb+bp}{self}\PYG{p}{,}\PYG{n}{p}\PYG{p}{)}\PYG{p}{:}
46+
\PYG{k}{return} \PYG{n}{F}\PYG{o}{.}\PYG{n}{relu}\PYG{p}{(}\PYG{n}{torch}\PYG{o}{.}\PYG{n}{sign}\PYG{p}{(}\PYG{n}{p} \PYG{o}{\PYGZhy{}} \PYG{n}{Variable}\PYG{p}{(}\PYG{n}{torch}\PYG{o}{.}\PYG{n}{rand}\PYG{p}{(}\PYG{n}{p}\PYG{o}{.}\PYG{n}{size}\PYG{p}{(}\PYG{p}{)}\PYG{p}{)}\PYG{p}{)}\PYG{p}{)}\PYG{p}{)}
47+
48+
\PYG{k}{def}\PYG{+w}{ }\PYG{n+nf}{v\PYGZus{}to\PYGZus{}h}\PYG{p}{(}\PYG{n+nb+bp}{self}\PYG{p}{,}\PYG{n}{v}\PYG{p}{)}\PYG{p}{:}
49+
\PYG{n}{p\PYGZus{}h} \PYG{o}{=} \PYG{n}{F}\PYG{o}{.}\PYG{n}{sigmoid}\PYG{p}{(}\PYG{n}{F}\PYG{o}{.}\PYG{n}{linear}\PYG{p}{(}\PYG{n}{v}\PYG{p}{,}\PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{W}\PYG{p}{,}\PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{h\PYGZus{}bias}\PYG{p}{)}\PYG{p}{)}
50+
\PYG{n}{sample\PYGZus{}h} \PYG{o}{=} \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{sample\PYGZus{}from\PYGZus{}p}\PYG{p}{(}\PYG{n}{p\PYGZus{}h}\PYG{p}{)}
51+
\PYG{k}{return} \PYG{n}{p\PYGZus{}h}\PYG{p}{,}\PYG{n}{sample\PYGZus{}h}
52+
53+
\PYG{k}{def}\PYG{+w}{ }\PYG{n+nf}{h\PYGZus{}to\PYGZus{}v}\PYG{p}{(}\PYG{n+nb+bp}{self}\PYG{p}{,}\PYG{n}{h}\PYG{p}{)}\PYG{p}{:}
54+
\PYG{n}{p\PYGZus{}v} \PYG{o}{=} \PYG{n}{F}\PYG{o}{.}\PYG{n}{sigmoid}\PYG{p}{(}\PYG{n}{F}\PYG{o}{.}\PYG{n}{linear}\PYG{p}{(}\PYG{n}{h}\PYG{p}{,}\PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{W}\PYG{o}{.}\PYG{n}{t}\PYG{p}{(}\PYG{p}{)}\PYG{p}{,}\PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{v\PYGZus{}bias}\PYG{p}{)}\PYG{p}{)}
55+
\PYG{n}{sample\PYGZus{}v} \PYG{o}{=} \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{sample\PYGZus{}from\PYGZus{}p}\PYG{p}{(}\PYG{n}{p\PYGZus{}v}\PYG{p}{)}
56+
\PYG{k}{return} \PYG{n}{p\PYGZus{}v}\PYG{p}{,}\PYG{n}{sample\PYGZus{}v}
57+
58+
\PYG{k}{def}\PYG{+w}{ }\PYG{n+nf}{forward}\PYG{p}{(}\PYG{n+nb+bp}{self}\PYG{p}{,}\PYG{n}{v}\PYG{p}{)}\PYG{p}{:}
59+
\PYG{n}{pre\PYGZus{}h1}\PYG{p}{,}\PYG{n}{h1} \PYG{o}{=} \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{v\PYGZus{}to\PYGZus{}h}\PYG{p}{(}\PYG{n}{v}\PYG{p}{)}
60+
61+
\PYG{n}{h\PYGZus{}} \PYG{o}{=} \PYG{n}{h1}
62+
\PYG{k}{for} \PYG{n}{\PYGZus{}} \PYG{o+ow}{in} \PYG{n+nb}{range}\PYG{p}{(}\PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{k}\PYG{p}{)}\PYG{p}{:}
63+
\PYG{n}{pre\PYGZus{}v\PYGZus{}}\PYG{p}{,}\PYG{n}{v\PYGZus{}} \PYG{o}{=} \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{h\PYGZus{}to\PYGZus{}v}\PYG{p}{(}\PYG{n}{h\PYGZus{}}\PYG{p}{)}
64+
\PYG{n}{pre\PYGZus{}h\PYGZus{}}\PYG{p}{,}\PYG{n}{h\PYGZus{}} \PYG{o}{=} \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{v\PYGZus{}to\PYGZus{}h}\PYG{p}{(}\PYG{n}{v\PYGZus{}}\PYG{p}{)}
65+
66+
\PYG{k}{return} \PYG{n}{v}\PYG{p}{,}\PYG{n}{v\PYGZus{}}
67+
68+
\PYG{k}{def}\PYG{+w}{ }\PYG{n+nf}{free\PYGZus{}energy}\PYG{p}{(}\PYG{n+nb+bp}{self}\PYG{p}{,}\PYG{n}{v}\PYG{p}{)}\PYG{p}{:}
69+
\PYG{n}{vbias\PYGZus{}term} \PYG{o}{=} \PYG{n}{v}\PYG{o}{.}\PYG{n}{mv}\PYG{p}{(}\PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{v\PYGZus{}bias}\PYG{p}{)}
70+
\PYG{n}{wx\PYGZus{}b} \PYG{o}{=} \PYG{n}{F}\PYG{o}{.}\PYG{n}{linear}\PYG{p}{(}\PYG{n}{v}\PYG{p}{,}\PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{W}\PYG{p}{,}\PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{h\PYGZus{}bias}\PYG{p}{)}
71+
\PYG{n}{hidden\PYGZus{}term} \PYG{o}{=} \PYG{n}{wx\PYGZus{}b}\PYG{o}{.}\PYG{n}{exp}\PYG{p}{(}\PYG{p}{)}\PYG{o}{.}\PYG{n}{add}\PYG{p}{(}\PYG{l+m+mi}{1}\PYG{p}{)}\PYG{o}{.}\PYG{n}{log}\PYG{p}{(}\PYG{p}{)}\PYG{o}{.}\PYG{n}{sum}\PYG{p}{(}\PYG{l+m+mi}{1}\PYG{p}{)}
72+
\PYG{k}{return} \PYG{p}{(}\PYG{o}{\PYGZhy{}}\PYG{n}{hidden\PYGZus{}term} \PYG{o}{\PYGZhy{}} \PYG{n}{vbias\PYGZus{}term}\PYG{p}{)}\PYG{o}{.}\PYG{n}{mean}\PYG{p}{(}\PYG{p}{)}
73+
74+
75+
76+
77+
\PYG{n}{rbm} \PYG{o}{=} \PYG{n}{RBM}\PYG{p}{(}\PYG{n}{k}\PYG{o}{=}\PYG{l+m+mi}{1}\PYG{p}{)}
78+
\PYG{n}{train\PYGZus{}op} \PYG{o}{=} \PYG{n}{optim}\PYG{o}{.}\PYG{n}{SGD}\PYG{p}{(}\PYG{n}{rbm}\PYG{o}{.}\PYG{n}{parameters}\PYG{p}{(}\PYG{p}{)}\PYG{p}{,}\PYG{l+m+mf}{0.1}\PYG{p}{)}
79+
80+
\PYG{k}{for} \PYG{n}{epoch} \PYG{o+ow}{in} \PYG{n+nb}{range}\PYG{p}{(}\PYG{l+m+mi}{10}\PYG{p}{)}\PYG{p}{:}
81+
\PYG{n}{loss\PYGZus{}} \PYG{o}{=} \PYG{p}{[}\PYG{p}{]}
82+
\PYG{k}{for} \PYG{n}{\PYGZus{}}\PYG{p}{,} \PYG{p}{(}\PYG{n}{data}\PYG{p}{,}\PYG{n}{target}\PYG{p}{)} \PYG{o+ow}{in} \PYG{n+nb}{enumerate}\PYG{p}{(}\PYG{n}{train\PYGZus{}loader}\PYG{p}{)}\PYG{p}{:}
83+
\PYG{n}{data} \PYG{o}{=} \PYG{n}{Variable}\PYG{p}{(}\PYG{n}{data}\PYG{o}{.}\PYG{n}{view}\PYG{p}{(}\PYG{o}{\PYGZhy{}}\PYG{l+m+mi}{1}\PYG{p}{,}\PYG{l+m+mi}{784}\PYG{p}{)}\PYG{p}{)}
84+
\PYG{n}{sample\PYGZus{}data} \PYG{o}{=} \PYG{n}{data}\PYG{o}{.}\PYG{n}{bernoulli}\PYG{p}{(}\PYG{p}{)}
85+
86+
\PYG{n}{v}\PYG{p}{,}\PYG{n}{v1} \PYG{o}{=} \PYG{n}{rbm}\PYG{p}{(}\PYG{n}{sample\PYGZus{}data}\PYG{p}{)}
87+
\PYG{n}{loss} \PYG{o}{=} \PYG{n}{rbm}\PYG{o}{.}\PYG{n}{free\PYGZus{}energy}\PYG{p}{(}\PYG{n}{v}\PYG{p}{)} \PYG{o}{\PYGZhy{}} \PYG{n}{rbm}\PYG{o}{.}\PYG{n}{free\PYGZus{}energy}\PYG{p}{(}\PYG{n}{v1}\PYG{p}{)}
88+
\PYG{n}{loss\PYGZus{}}\PYG{o}{.}\PYG{n}{append}\PYG{p}{(}\PYG{n}{loss}\PYG{o}{.}\PYG{n}{data}\PYG{p}{)}
89+
\PYG{n}{train\PYGZus{}op}\PYG{o}{.}\PYG{n}{zero\PYGZus{}grad}\PYG{p}{(}\PYG{p}{)}
90+
\PYG{n}{loss}\PYG{o}{.}\PYG{n}{backward}\PYG{p}{(}\PYG{p}{)}
91+
\PYG{n}{train\PYGZus{}op}\PYG{o}{.}\PYG{n}{step}\PYG{p}{(}\PYG{p}{)}
92+
93+
\PYG{n+nb}{print}\PYG{p}{(}\PYG{l+s+s2}{\PYGZdq{}}\PYG{l+s+s2}{Training loss for }\PYG{l+s+si}{\PYGZob{}\PYGZcb{}}\PYG{l+s+s2}{ epoch: }\PYG{l+s+si}{\PYGZob{}\PYGZcb{}}\PYG{l+s+s2}{\PYGZdq{}}\PYG{o}{.}\PYG{n}{format}\PYG{p}{(}\PYG{n}{epoch}\PYG{p}{,} \PYG{n}{np}\PYG{o}{.}\PYG{n}{mean}\PYG{p}{(}\PYG{n}{loss\PYGZus{}}\PYG{p}{)}\PYG{p}{)}\PYG{p}{)}
94+
95+
96+
\PYG{k}{def}\PYG{+w}{ }\PYG{n+nf}{show\PYGZus{}adn\PYGZus{}save}\PYG{p}{(}\PYG{n}{file\PYGZus{}name}\PYG{p}{,}\PYG{n}{img}\PYG{p}{)}\PYG{p}{:}
97+
\PYG{n}{npimg} \PYG{o}{=} \PYG{n}{np}\PYG{o}{.}\PYG{n}{transpose}\PYG{p}{(}\PYG{n}{img}\PYG{o}{.}\PYG{n}{numpy}\PYG{p}{(}\PYG{p}{)}\PYG{p}{,}\PYG{p}{(}\PYG{l+m+mi}{1}\PYG{p}{,}\PYG{l+m+mi}{2}\PYG{p}{,}\PYG{l+m+mi}{0}\PYG{p}{)}\PYG{p}{)}
98+
\PYG{n}{f} \PYG{o}{=} \PYG{l+s+s2}{\PYGZdq{}}\PYG{l+s+s2}{./}\PYG{l+s+si}{\PYGZpc{}s}\PYG{l+s+s2}{.png}\PYG{l+s+s2}{\PYGZdq{}} \PYG{o}{\PYGZpc{}} \PYG{n}{file\PYGZus{}name}
99+
\PYG{n}{plt}\PYG{o}{.}\PYG{n}{imshow}\PYG{p}{(}\PYG{n}{npimg}\PYG{p}{)}
100+
\PYG{n}{plt}\PYG{o}{.}\PYG{n}{imsave}\PYG{p}{(}\PYG{n}{f}\PYG{p}{,}\PYG{n}{npimg}\PYG{p}{)}
101+
102+
\PYG{n}{show\PYGZus{}adn\PYGZus{}save}\PYG{p}{(}\PYG{l+s+s2}{\PYGZdq{}}\PYG{l+s+s2}{real}\PYG{l+s+s2}{\PYGZdq{}}\PYG{p}{,}\PYG{n}{make\PYGZus{}grid}\PYG{p}{(}\PYG{n}{v}\PYG{o}{.}\PYG{n}{view}\PYG{p}{(}\PYG{l+m+mi}{32}\PYG{p}{,}\PYG{l+m+mi}{1}\PYG{p}{,}\PYG{l+m+mi}{28}\PYG{p}{,}\PYG{l+m+mi}{28}\PYG{p}{)}\PYG{o}{.}\PYG{n}{data}\PYG{p}{)}\PYG{p}{)}
103+
\PYG{n}{show\PYGZus{}adn\PYGZus{}save}\PYG{p}{(}\PYG{l+s+s2}{\PYGZdq{}}\PYG{l+s+s2}{generate}\PYG{l+s+s2}{\PYGZdq{}}\PYG{p}{,}\PYG{n}{make\PYGZus{}grid}\PYG{p}{(}\PYG{n}{v1}\PYG{o}{.}\PYG{n}{view}\PYG{p}{(}\PYG{l+m+mi}{32}\PYG{p}{,}\PYG{l+m+mi}{1}\PYG{p}{,}\PYG{l+m+mi}{28}\PYG{p}{,}\PYG{l+m+mi}{28}\PYG{p}{)}\PYG{o}{.}\PYG{n}{data}\PYG{p}{)}\PYG{p}{)}
104+
105+
106+
\end{MintedVerbatim}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
\begin{MintedVerbatim}[commandchars=\\\{\},codes={\catcode`\$=3\catcode`\^=7\catcode`\_=8\relax}]
2+
\PYG{n}{samples} \PYG{o}{=} \PYG{n}{gibbs\PYGZus{}sample}\PYG{p}{(}\PYG{n}{univariate\PYGZus{}conditionals}\PYG{p}{,} \PYG{n}{sample\PYGZus{}count}\PYG{o}{=}\PYG{l+m+mi}{100}\PYG{p}{)}
3+
\PYG{n}{fig}\PYG{p}{,} \PYG{n}{ax} \PYG{o}{=} \PYG{n}{plt}\PYG{o}{.}\PYG{n}{subplots}\PYG{p}{(}\PYG{p}{)}
4+
5+
\PYG{n}{ax}\PYG{o}{.}\PYG{n}{plot}\PYG{p}{(}\PYG{o}{*}\PYG{n}{samples\PYGZus{}from\PYGZus{}true\PYGZus{}distribution}\PYG{p}{,} \PYG{l+s+s1}{\PYGZsq{}}\PYG{l+s+s1}{.}\PYG{l+s+s1}{\PYGZsq{}}\PYG{p}{,} \PYG{n}{alpha}\PYG{o}{=}\PYG{l+m+mf}{0.1}\PYG{p}{)}
6+
\PYG{n}{ax}\PYG{o}{.}\PYG{n}{plot}\PYG{p}{(}\PYG{o}{*}\PYG{n}{samples}\PYG{p}{,} \PYG{l+s+s1}{\PYGZsq{}}\PYG{l+s+s1}{k}\PYG{l+s+s1}{\PYGZsq{}}\PYG{p}{)}
7+
\PYG{n}{ax}\PYG{o}{.}\PYG{n}{plot}\PYG{p}{(}\PYG{o}{*}\PYG{n}{samples}\PYG{p}{,} \PYG{l+s+s1}{\PYGZsq{}}\PYG{l+s+s1}{.r}\PYG{l+s+s1}{\PYGZsq{}}\PYG{p}{)}
8+
\PYG{n}{ax}\PYG{o}{.}\PYG{n}{axis}\PYG{p}{(}\PYG{l+s+s1}{\PYGZsq{}}\PYG{l+s+s1}{square}\PYG{l+s+s1}{\PYGZsq{}}\PYG{p}{)}
9+
\PYG{n}{plt}\PYG{o}{.}\PYG{n}{show}\PYG{p}{(}\PYG{p}{)}
10+
11+
\end{MintedVerbatim}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
\begin{MintedVerbatim}[commandchars=\\\{\},codes={\catcode`\$=3\catcode`\^=7\catcode`\_=8\relax}]
2+
\PYG{k+kn}{import}\PYG{+w}{ }\PYG{n+nn}{numpy}\PYG{+w}{ }\PYG{k}{as}\PYG{+w}{ }\PYG{n+nn}{np}
3+
\PYG{k+kn}{import}\PYG{+w}{ }\PYG{n+nn}{matplotlib}\PYG{n+nn}{.}\PYG{n+nn}{pyplot}\PYG{+w}{ }\PYG{k}{as}\PYG{+w}{ }\PYG{n+nn}{plt}
4+
\PYG{c+c1}{\PYGZsh{} two dimensions}
5+
\PYG{n}{D} \PYG{o}{=} \PYG{l+m+mi}{2}
6+
7+
\PYG{c+c1}{\PYGZsh{} set up the means (standard normal distribution}
8+
\PYG{n}{a\PYGZus{}mu} \PYG{o}{=} \PYG{l+m+mi}{0}
9+
\PYG{n}{b\PYGZus{}mu} \PYG{o}{=} \PYG{l+m+mi}{0}
10+
\PYG{c+c1}{\PYGZsh{} and the variances and covariances}
11+
\PYG{n}{a\PYGZus{}sigma} \PYG{o}{=} \PYG{l+m+mi}{1}
12+
\PYG{n}{b\PYGZus{}sigma} \PYG{o}{=} \PYG{l+m+mi}{1}
13+
\PYG{n}{a\PYGZus{}b\PYGZus{}cov} \PYG{o}{=} \PYG{l+m+mf}{0.5}
14+
15+
\PYG{n}{joint\PYGZus{}cov} \PYG{o}{=} \PYG{n}{np}\PYG{o}{.}\PYG{n}{vstack}\PYG{p}{(}\PYG{p}{(}\PYG{p}{(}\PYG{n}{a\PYGZus{}sigma}\PYG{p}{,} \PYG{n}{a\PYGZus{}b\PYGZus{}cov}\PYG{p}{)}\PYG{p}{,} \PYG{p}{(}\PYG{n}{a\PYGZus{}b\PYGZus{}cov}\PYG{p}{,} \PYG{n}{b\PYGZus{}sigma}\PYG{p}{)}\PYG{p}{)}\PYG{p}{)}
16+
\PYG{n}{joint\PYGZus{}mu} \PYG{o}{=} \PYG{n}{np}\PYG{o}{.}\PYG{n}{vstack}\PYG{p}{(}\PYG{p}{(}\PYG{n}{a\PYGZus{}mu}\PYG{p}{,} \PYG{n}{b\PYGZus{}mu}\PYG{p}{)}\PYG{p}{)}
17+
18+
\end{MintedVerbatim}

0 commit comments

Comments
 (0)