-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrainable-parameter.lisp
More file actions
39 lines (32 loc) · 1.22 KB
/
trainable-parameter.lisp
File metadata and controls
39 lines (32 loc) · 1.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
(in-package #:lispnet)
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;;
;;; Trainable Parameters
(defclass trainable-parameter ()
  ((weights
    ;; No initarg and no initform: the slot starts out unbound and is only
    ;; reachable through the WEIGHTS accessor.
    :documentation "The parameter's weights object, read and written via WEIGHTS."
    :accessor weights)
   (weights-value
    :documentation "Initially NIL. Presumably holds concrete weight values
once some exist -- confirm against the code that writes this slot."
    :initform nil
    :accessor weights-value)
   (shape
    ;; No initform: the slot is unbound unless :SHAPE is passed to MAKE-INSTANCE.
    :documentation "Shape of the weights, supplied with the :SHAPE initarg and
read via WEIGHTS-SHAPE."
    :initarg :shape
    :accessor weights-shape)
   (trainable
    :documentation "Flag supplied with :TRAINABLE, defaulting to T. Presumably
tells training code whether this parameter may be updated -- confirm with callers."
    :initarg :trainable
    :initform t
    :accessor trainable))
  (:documentation "A network parameter described by its weights, an optional
weights value, a shape, and a trainable flag."))
(defmethod initialize-instance :after ((parameter trainable-parameter) &rest initargs)
  "Fill the WEIGHTS slot of a freshly created PARAMETER.

The slot is set to the result of MAKE-UNKNOWN, called with the parameter's
shape (read via WEIGHTS-SHAPE) and *NETWORK-PRECISION* as the element type.

Because the SHAPE slot has no initform, creating an instance without a :SHAPE
initarg makes the WEIGHTS-SHAPE read below hit an unbound slot."
  ;; This method never looks at the initargs itself; declaring them ignored
  ;; avoids an unused-variable style warning while keeping the lambda list
  ;; congruent with the INITIALIZE-INSTANCE generic function.
  (declare (ignore initargs))
  (setf (weights parameter)
        (make-unknown :shape (weights-shape parameter)
                      :element-type *network-precision*)))
(defun make-trainable-parameter (model &key shape (trainable t))
  "Return a trainable parameter for MODEL.

While MODEL's backend is not yet compiled, a new TRAINABLE-PARAMETER is created
from SHAPE and TRAINABLE, pushed onto the backend's PARAMETERS list, and
returned.

Once the backend is compiled, SHAPE and TRAINABLE are not used: the element of
(MODEL-WEIGHTS MODEL) at the backend's PARAMETER-POINTER index is returned, and
the pointer is then decremented by one."
  (let ((backend (model-backend model)))
    (cond ((compiled backend)
           ;; Look up the existing entry first, then step the pointer back so
           ;; the next call picks the preceding element of the list.
           (prog1 (nth (parameter-pointer backend) (model-weights model))
             (decf (parameter-pointer backend))))
          (t
           ;; PUSH prepends, so the most recently created parameter sits at
           ;; the head of the backend's parameter list.
           (let ((fresh (make-instance 'trainable-parameter
                                       :shape shape
                                       :trainable trainable)))
             (push fresh (parameters backend))
             fresh)))))
;;; Type predicate for TRAINABLE-PARAMETER.
;;;
;;; Returns true when OBJECT is of type TRAINABLE-PARAMETER (which includes
;;; instances of any subclass) and false otherwise.  The function is
;;; proclaimed inline, so the compiler is allowed to expand calls in place;
;;; the proclamation must precede the definition to take effect.
(declaim (inline trainable-parameter-p))
(defun trainable-parameter-p (object)
(typep object 'trainable-parameter))