|
| 1 | +\begin{MintedVerbatim}[commandchars=\\\{\},codes={\catcode`\$=3\catcode`\^=7\catcode`\_=8\relax}] |
| 2 | +\PYG{l+s+sd}{\PYGZdq{}\PYGZdq{}\PYGZdq{}} |
| 3 | +\PYG{l+s+sd}{Key components:} |
| 4 | +\PYG{l+s+sd}{1. **Data Handling**: Uses PyTorch DataLoader with MNIST dataset} |
| 5 | +\PYG{l+s+sd}{2. **LSTM Architecture**:} |
| 6 | +\PYG{l+s+sd}{ \PYGZhy{} Input sequence of 28 timesteps (image rows)} |
| 7 | +\PYG{l+s+sd}{ \PYGZhy{} 128 hidden units in LSTM layer} |
| 8 | +\PYG{l+s+sd}{ \PYGZhy{} Fully connected layer for classification} |
| 9 | +\PYG{l+s+sd}{3. **Training**:} |
| 10 | +\PYG{l+s+sd}{ \PYGZhy{} Cross\PYGZhy{}entropy loss} |
| 11 | +\PYG{l+s+sd}{ \PYGZhy{} Adam optimizer} |
| 12 | +\PYG{l+s+sd}{ \PYGZhy{} Automatic GPU utilization if available} |
| 13 | + |
| 14 | +\PYG{l+s+sd}{This implementation typically achieves **97\PYGZhy{}98\PYGZpc{} accuracy** after 10 epochs. The main differences from the TensorFlow/Keras version:} |
| 15 | +\PYG{l+s+sd}{\PYGZhy{} Explicit device management (CPU/GPU)} |
| 16 | +\PYG{l+s+sd}{\PYGZhy{} Manual training loop} |
| 17 | +\PYG{l+s+sd}{\PYGZhy{} Different data loading pipeline} |
| 18 | +\PYG{l+s+sd}{\PYGZhy{} More explicit tensor reshaping} |
| 19 | + |
| 20 | +\PYG{l+s+sd}{To improve performance, you could:} |
| 21 | +\PYG{l+s+sd}{1. Add dropout regularization} |
| 22 | +\PYG{l+s+sd}{2. Use bidirectional LSTM} |
| 23 | +\PYG{l+s+sd}{3. Implement learning rate scheduling} |
| 24 | +\PYG{l+s+sd}{4. Add batch normalization} |
| 25 | +\PYG{l+s+sd}{5. Increase model capacity (more layers/units)} |
| 26 | +\PYG{l+s+sd}{\PYGZdq{}\PYGZdq{}\PYGZdq{}} |
| 27 | + |
| 28 | +\PYG{k+kn}{import}\PYG{+w}{ }\PYG{n+nn}{torch} |
| 29 | +\PYG{k+kn}{import}\PYG{+w}{ }\PYG{n+nn}{torch}\PYG{n+nn}{.}\PYG{n+nn}{nn}\PYG{+w}{ }\PYG{k}{as}\PYG{+w}{ }\PYG{n+nn}{nn} |
| 30 | +\PYG{k+kn}{import}\PYG{+w}{ }\PYG{n+nn}{torch}\PYG{n+nn}{.}\PYG{n+nn}{optim}\PYG{+w}{ }\PYG{k}{as}\PYG{+w}{ }\PYG{n+nn}{optim} |
| 31 | +\PYG{k+kn}{from}\PYG{+w}{ }\PYG{n+nn}{torchvision}\PYG{+w}{ }\PYG{k+kn}{import} \PYG{n}{datasets}\PYG{p}{,} \PYG{n}{transforms} |
| 32 | +\PYG{k+kn}{from}\PYG{+w}{ }\PYG{n+nn}{torch}\PYG{n+nn}{.}\PYG{n+nn}{utils}\PYG{n+nn}{.}\PYG{n+nn}{data}\PYG{+w}{ }\PYG{k+kn}{import} \PYG{n}{DataLoader} |
| 33 | + |
| 34 | +\PYG{c+c1}{\PYGZsh{} Hyperparameters} |
| 35 | +\PYG{n}{input\PYGZus{}size} \PYG{o}{=} \PYG{l+m+mi}{28} \PYG{c+c1}{\PYGZsh{} Number of features (pixels per row)} |
| 36 | +\PYG{n}{hidden\PYGZus{}size} \PYG{o}{=} \PYG{l+m+mi}{128} \PYG{c+c1}{\PYGZsh{} LSTM hidden state size} |
| 37 | +\PYG{n}{num\PYGZus{}classes} \PYG{o}{=} \PYG{l+m+mi}{10} \PYG{c+c1}{\PYGZsh{} Digits 0\PYGZhy{}9} |
| 38 | +\PYG{n}{num\PYGZus{}epochs} \PYG{o}{=} \PYG{l+m+mi}{10} \PYG{c+c1}{\PYGZsh{} Training iterations} |
| 39 | +\PYG{n}{batch\PYGZus{}size} \PYG{o}{=} \PYG{l+m+mi}{64} \PYG{c+c1}{\PYGZsh{} Batch size} |
| 40 | +\PYG{n}{learning\PYGZus{}rate} \PYG{o}{=} \PYG{l+m+mf}{0.001} |
| 41 | + |
| 42 | +\PYG{c+c1}{\PYGZsh{} Device configuration} |
| 43 | +\PYG{n}{device} \PYG{o}{=} \PYG{n}{torch}\PYG{o}{.}\PYG{n}{device}\PYG{p}{(}\PYG{l+s+s1}{\PYGZsq{}}\PYG{l+s+s1}{cuda}\PYG{l+s+s1}{\PYGZsq{}} \PYG{k}{if} \PYG{n}{torch}\PYG{o}{.}\PYG{n}{cuda}\PYG{o}{.}\PYG{n}{is\PYGZus{}available}\PYG{p}{(}\PYG{p}{)} \PYG{k}{else} \PYG{l+s+s1}{\PYGZsq{}}\PYG{l+s+s1}{cpu}\PYG{l+s+s1}{\PYGZsq{}}\PYG{p}{)} |
| 44 | + |
| 45 | +\PYG{c+c1}{\PYGZsh{} MNIST dataset} |
| 46 | +\PYG{n}{transform} \PYG{o}{=} \PYG{n}{transforms}\PYG{o}{.}\PYG{n}{Compose}\PYG{p}{(}\PYG{p}{[} |
| 47 | + \PYG{n}{transforms}\PYG{o}{.}\PYG{n}{ToTensor}\PYG{p}{(}\PYG{p}{)}\PYG{p}{,} |
| 48 | + \PYG{n}{transforms}\PYG{o}{.}\PYG{n}{Normalize}\PYG{p}{(}\PYG{p}{(}\PYG{l+m+mf}{0.1307}\PYG{p}{,}\PYG{p}{)}\PYG{p}{,} \PYG{p}{(}\PYG{l+m+mf}{0.3081}\PYG{p}{,}\PYG{p}{)}\PYG{p}{)} \PYG{c+c1}{\PYGZsh{} MNIST mean and std} |
| 49 | +\PYG{p}{]}\PYG{p}{)} |
| 50 | + |
| 51 | +\PYG{n}{train\PYGZus{}dataset} \PYG{o}{=} \PYG{n}{datasets}\PYG{o}{.}\PYG{n}{MNIST}\PYG{p}{(}\PYG{n}{root}\PYG{o}{=}\PYG{l+s+s1}{\PYGZsq{}}\PYG{l+s+s1}{./data}\PYG{l+s+s1}{\PYGZsq{}}\PYG{p}{,} |
| 52 | + \PYG{n}{train}\PYG{o}{=}\PYG{k+kc}{True}\PYG{p}{,} |
| 53 | + \PYG{n}{transform}\PYG{o}{=}\PYG{n}{transform}\PYG{p}{,} |
| 54 | + \PYG{n}{download}\PYG{o}{=}\PYG{k+kc}{True}\PYG{p}{)} |
| 55 | + |
| 56 | +\PYG{n}{test\PYGZus{}dataset} \PYG{o}{=} \PYG{n}{datasets}\PYG{o}{.}\PYG{n}{MNIST}\PYG{p}{(}\PYG{n}{root}\PYG{o}{=}\PYG{l+s+s1}{\PYGZsq{}}\PYG{l+s+s1}{./data}\PYG{l+s+s1}{\PYGZsq{}}\PYG{p}{,} |
| 57 | + \PYG{n}{train}\PYG{o}{=}\PYG{k+kc}{False}\PYG{p}{,} |
| 58 | + \PYG{n}{transform}\PYG{o}{=}\PYG{n}{transform}\PYG{p}{)} |
| 59 | + |
| 60 | +\PYG{n}{train\PYGZus{}loader} \PYG{o}{=} \PYG{n}{DataLoader}\PYG{p}{(}\PYG{n}{dataset}\PYG{o}{=}\PYG{n}{train\PYGZus{}dataset}\PYG{p}{,} |
| 61 | + \PYG{n}{batch\PYGZus{}size}\PYG{o}{=}\PYG{n}{batch\PYGZus{}size}\PYG{p}{,} |
| 62 | + \PYG{n}{shuffle}\PYG{o}{=}\PYG{k+kc}{True}\PYG{p}{)} |
| 63 | + |
| 64 | +\PYG{n}{test\PYGZus{}loader} \PYG{o}{=} \PYG{n}{DataLoader}\PYG{p}{(}\PYG{n}{dataset}\PYG{o}{=}\PYG{n}{test\PYGZus{}dataset}\PYG{p}{,} |
| 65 | + \PYG{n}{batch\PYGZus{}size}\PYG{o}{=}\PYG{n}{batch\PYGZus{}size}\PYG{p}{,} |
| 66 | + \PYG{n}{shuffle}\PYG{o}{=}\PYG{k+kc}{False}\PYG{p}{)} |
| 67 | + |
| 68 | +\PYG{c+c1}{\PYGZsh{} LSTM model} |
| 69 | +\PYG{k}{class}\PYG{+w}{ }\PYG{n+nc}{LSTMModel}\PYG{p}{(}\PYG{n}{nn}\PYG{o}{.}\PYG{n}{Module}\PYG{p}{)}\PYG{p}{:} |
| 70 | + \PYG{k}{def}\PYG{+w}{ }\PYG{n+nf+fm}{\PYGZus{}\PYGZus{}init\PYGZus{}\PYGZus{}}\PYG{p}{(}\PYG{n+nb+bp}{self}\PYG{p}{,} \PYG{n}{input\PYGZus{}size}\PYG{p}{,} \PYG{n}{hidden\PYGZus{}size}\PYG{p}{,} \PYG{n}{num\PYGZus{}classes}\PYG{p}{)}\PYG{p}{:} |
| 71 | + \PYG{n+nb}{super}\PYG{p}{(}\PYG{n}{LSTMModel}\PYG{p}{,} \PYG{n+nb+bp}{self}\PYG{p}{)}\PYG{o}{.}\PYG{n+nf+fm}{\PYGZus{}\PYGZus{}init\PYGZus{}\PYGZus{}}\PYG{p}{(}\PYG{p}{)} |
| 72 | + \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{hidden\PYGZus{}size} \PYG{o}{=} \PYG{n}{hidden\PYGZus{}size} |
| 73 | + \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{lstm} \PYG{o}{=} \PYG{n}{nn}\PYG{o}{.}\PYG{n}{LSTM}\PYG{p}{(}\PYG{n}{input\PYGZus{}size}\PYG{p}{,} \PYG{n}{hidden\PYGZus{}size}\PYG{p}{,} \PYG{n}{batch\PYGZus{}first}\PYG{o}{=}\PYG{k+kc}{True}\PYG{p}{)} |
| 74 | + \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{fc} \PYG{o}{=} \PYG{n}{nn}\PYG{o}{.}\PYG{n}{Linear}\PYG{p}{(}\PYG{n}{hidden\PYGZus{}size}\PYG{p}{,} \PYG{n}{num\PYGZus{}classes}\PYG{p}{)} |
| 75 | + |
| 76 | + \PYG{k}{def}\PYG{+w}{ }\PYG{n+nf}{forward}\PYG{p}{(}\PYG{n+nb+bp}{self}\PYG{p}{,} \PYG{n}{x}\PYG{p}{)}\PYG{p}{:} |
| 77 | + \PYG{c+c1}{\PYGZsh{} Reshape input to (batch\PYGZus{}size, sequence\PYGZus{}length, input\PYGZus{}size)} |
| 78 | + \PYG{n}{x} \PYG{o}{=} \PYG{n}{x}\PYG{o}{.}\PYG{n}{reshape}\PYG{p}{(}\PYG{o}{\PYGZhy{}}\PYG{l+m+mi}{1}\PYG{p}{,} \PYG{l+m+mi}{28}\PYG{p}{,} \PYG{l+m+mi}{28}\PYG{p}{)} |
| 79 | + |
| 80 | + \PYG{c+c1}{\PYGZsh{} Forward propagate LSTM} |
| 81 | + \PYG{n}{out}\PYG{p}{,} \PYG{n}{\PYGZus{}} \PYG{o}{=} \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{lstm}\PYG{p}{(}\PYG{n}{x}\PYG{p}{)} \PYG{c+c1}{\PYGZsh{} out: (batch\PYGZus{}size, seq\PYGZus{}length, hidden\PYGZus{}size)} |
| 82 | + |
| 83 | + \PYG{c+c1}{\PYGZsh{} Decode the hidden state of the last time step} |
| 84 | + \PYG{n}{out} \PYG{o}{=} \PYG{n}{out}\PYG{p}{[}\PYG{p}{:}\PYG{p}{,} \PYG{o}{\PYGZhy{}}\PYG{l+m+mi}{1}\PYG{p}{,} \PYG{p}{:}\PYG{p}{]} |
| 85 | + \PYG{n}{out} \PYG{o}{=} \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{fc}\PYG{p}{(}\PYG{n}{out}\PYG{p}{)} |
| 86 | + \PYG{k}{return} \PYG{n}{out} |
| 87 | + |
| 88 | +\PYG{c+c1}{\PYGZsh{} Initialize model} |
| 89 | +\PYG{n}{model} \PYG{o}{=} \PYG{n}{LSTMModel}\PYG{p}{(}\PYG{n}{input\PYGZus{}size}\PYG{p}{,} \PYG{n}{hidden\PYGZus{}size}\PYG{p}{,} \PYG{n}{num\PYGZus{}classes}\PYG{p}{)}\PYG{o}{.}\PYG{n}{to}\PYG{p}{(}\PYG{n}{device}\PYG{p}{)} |
| 90 | + |
| 91 | +\PYG{c+c1}{\PYGZsh{} Loss and optimizer} |
| 92 | +\PYG{n}{criterion} \PYG{o}{=} \PYG{n}{nn}\PYG{o}{.}\PYG{n}{CrossEntropyLoss}\PYG{p}{(}\PYG{p}{)} |
| 93 | +\PYG{n}{optimizer} \PYG{o}{=} \PYG{n}{optim}\PYG{o}{.}\PYG{n}{Adam}\PYG{p}{(}\PYG{n}{model}\PYG{o}{.}\PYG{n}{parameters}\PYG{p}{(}\PYG{p}{)}\PYG{p}{,} \PYG{n}{lr}\PYG{o}{=}\PYG{n}{learning\PYGZus{}rate}\PYG{p}{)} |
| 94 | + |
| 95 | +\PYG{c+c1}{\PYGZsh{} Training loop} |
| 96 | +\PYG{n}{total\PYGZus{}step} \PYG{o}{=} \PYG{n+nb}{len}\PYG{p}{(}\PYG{n}{train\PYGZus{}loader}\PYG{p}{)} |
| 97 | +\PYG{k}{for} \PYG{n}{epoch} \PYG{o+ow}{in} \PYG{n+nb}{range}\PYG{p}{(}\PYG{n}{num\PYGZus{}epochs}\PYG{p}{)}\PYG{p}{:} |
| 98 | + \PYG{n}{model}\PYG{o}{.}\PYG{n}{train}\PYG{p}{(}\PYG{p}{)} |
| 99 | + \PYG{k}{for} \PYG{n}{i}\PYG{p}{,} \PYG{p}{(}\PYG{n}{images}\PYG{p}{,} \PYG{n}{labels}\PYG{p}{)} \PYG{o+ow}{in} \PYG{n+nb}{enumerate}\PYG{p}{(}\PYG{n}{train\PYGZus{}loader}\PYG{p}{)}\PYG{p}{:} |
| 100 | + \PYG{n}{images} \PYG{o}{=} \PYG{n}{images}\PYG{o}{.}\PYG{n}{to}\PYG{p}{(}\PYG{n}{device}\PYG{p}{)} |
| 101 | + \PYG{n}{labels} \PYG{o}{=} \PYG{n}{labels}\PYG{o}{.}\PYG{n}{to}\PYG{p}{(}\PYG{n}{device}\PYG{p}{)} |
| 102 | + |
| 103 | + \PYG{c+c1}{\PYGZsh{} Forward pass} |
| 104 | + \PYG{n}{outputs} \PYG{o}{=} \PYG{n}{model}\PYG{p}{(}\PYG{n}{images}\PYG{p}{)} |
| 105 | + \PYG{n}{loss} \PYG{o}{=} \PYG{n}{criterion}\PYG{p}{(}\PYG{n}{outputs}\PYG{p}{,} \PYG{n}{labels}\PYG{p}{)} |
| 106 | + |
| 107 | + \PYG{c+c1}{\PYGZsh{} Backward and optimize} |
| 108 | + \PYG{n}{optimizer}\PYG{o}{.}\PYG{n}{zero\PYGZus{}grad}\PYG{p}{(}\PYG{p}{)} |
| 109 | + \PYG{n}{loss}\PYG{o}{.}\PYG{n}{backward}\PYG{p}{(}\PYG{p}{)} |
| 110 | + \PYG{n}{optimizer}\PYG{o}{.}\PYG{n}{step}\PYG{p}{(}\PYG{p}{)} |
| 111 | + |
| 112 | + \PYG{k}{if} \PYG{p}{(}\PYG{n}{i}\PYG{o}{+}\PYG{l+m+mi}{1}\PYG{p}{)} \PYG{o}{\PYGZpc{}} \PYG{l+m+mi}{100} \PYG{o}{==} \PYG{l+m+mi}{0}\PYG{p}{:} |
| 113 | + \PYG{n+nb}{print}\PYG{p}{(}\PYG{l+s+sa}{f}\PYG{l+s+s1}{\PYGZsq{}}\PYG{l+s+s1}{Epoch [}\PYG{l+s+si}{\PYGZob{}}\PYG{n}{epoch}\PYG{o}{+}\PYG{l+m+mi}{1}\PYG{l+s+si}{\PYGZcb{}}\PYG{l+s+s1}{/}\PYG{l+s+si}{\PYGZob{}}\PYG{n}{num\PYGZus{}epochs}\PYG{l+s+si}{\PYGZcb{}}\PYG{l+s+s1}{], Step [}\PYG{l+s+si}{\PYGZob{}}\PYG{n}{i}\PYG{o}{+}\PYG{l+m+mi}{1}\PYG{l+s+si}{\PYGZcb{}}\PYG{l+s+s1}{/}\PYG{l+s+si}{\PYGZob{}}\PYG{n}{total\PYGZus{}step}\PYG{l+s+si}{\PYGZcb{}}\PYG{l+s+s1}{], Loss: }\PYG{l+s+si}{\PYGZob{}}\PYG{n}{loss}\PYG{o}{.}\PYG{n}{item}\PYG{p}{(}\PYG{p}{)}\PYG{l+s+si}{:}\PYG{l+s+s1}{.4f}\PYG{l+s+si}{\PYGZcb{}}\PYG{l+s+s1}{\PYGZsq{}}\PYG{p}{)} |
| 114 | + |
| 115 | + \PYG{c+c1}{\PYGZsh{} Test the model} |
| 116 | + \PYG{n}{model}\PYG{o}{.}\PYG{n}{eval}\PYG{p}{(}\PYG{p}{)} |
| 117 | + \PYG{k}{with} \PYG{n}{torch}\PYG{o}{.}\PYG{n}{no\PYGZus{}grad}\PYG{p}{(}\PYG{p}{)}\PYG{p}{:} |
| 118 | + \PYG{n}{correct} \PYG{o}{=} \PYG{l+m+mi}{0} |
| 119 | + \PYG{n}{total} \PYG{o}{=} \PYG{l+m+mi}{0} |
| 120 | + \PYG{k}{for} \PYG{n}{images}\PYG{p}{,} \PYG{n}{labels} \PYG{o+ow}{in} \PYG{n}{test\PYGZus{}loader}\PYG{p}{:} |
| 121 | + \PYG{n}{images} \PYG{o}{=} \PYG{n}{images}\PYG{o}{.}\PYG{n}{to}\PYG{p}{(}\PYG{n}{device}\PYG{p}{)} |
| 122 | + \PYG{n}{labels} \PYG{o}{=} \PYG{n}{labels}\PYG{o}{.}\PYG{n}{to}\PYG{p}{(}\PYG{n}{device}\PYG{p}{)} |
| 123 | + \PYG{n}{outputs} \PYG{o}{=} \PYG{n}{model}\PYG{p}{(}\PYG{n}{images}\PYG{p}{)} |
| 124 | + \PYG{n}{\PYGZus{}}\PYG{p}{,} \PYG{n}{predicted} \PYG{o}{=} \PYG{n}{torch}\PYG{o}{.}\PYG{n}{max}\PYG{p}{(}\PYG{n}{outputs}\PYG{o}{.}\PYG{n}{data}\PYG{p}{,} \PYG{l+m+mi}{1}\PYG{p}{)} |
| 125 | + \PYG{n}{total} \PYG{o}{+}\PYG{o}{=} \PYG{n}{labels}\PYG{o}{.}\PYG{n}{size}\PYG{p}{(}\PYG{l+m+mi}{0}\PYG{p}{)} |
| 126 | + \PYG{n}{correct} \PYG{o}{+}\PYG{o}{=} \PYG{p}{(}\PYG{n}{predicted} \PYG{o}{==} \PYG{n}{labels}\PYG{p}{)}\PYG{o}{.}\PYG{n}{sum}\PYG{p}{(}\PYG{p}{)}\PYG{o}{.}\PYG{n}{item}\PYG{p}{(}\PYG{p}{)} |
| 127 | + |
| 128 | + \PYG{n+nb}{print}\PYG{p}{(}\PYG{l+s+sa}{f}\PYG{l+s+s1}{\PYGZsq{}}\PYG{l+s+s1}{Test Accuracy: }\PYG{l+s+si}{\PYGZob{}}\PYG{l+m+mi}{100}\PYG{+w}{ }\PYG{o}{*}\PYG{+w}{ }\PYG{n}{correct}\PYG{+w}{ }\PYG{o}{/}\PYG{+w}{ }\PYG{n}{total}\PYG{l+s+si}{:}\PYG{l+s+s1}{.2f}\PYG{l+s+si}{\PYGZcb{}}\PYG{l+s+s1}{\PYGZpc{}}\PYG{l+s+s1}{\PYGZsq{}}\PYG{p}{)} |
| 129 | + |
| 130 | +\PYG{n+nb}{print}\PYG{p}{(}\PYG{l+s+s1}{\PYGZsq{}}\PYG{l+s+s1}{Training finished.}\PYG{l+s+s1}{\PYGZsq{}}\PYG{p}{)} |
| 131 | + |
| 132 | +\end{MintedVerbatim} |
0 commit comments