From d16d56ff13f3124a775fe7c68921247c197f99b0 Mon Sep 17 00:00:00 2001 From: Daouzli A Date: Tue, 16 Jan 2018 22:44:26 +0100 Subject: [PATCH] add function to load weights in the network --- mlp.py | 35 +++++++++++++++++++++++++++++------ 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/mlp.py b/mlp.py index 08b173f..34a62d6 100644 --- a/mlp.py +++ b/mlp.py @@ -123,11 +123,6 @@ class MultiLayerPerceptron(object): self._softmax = False if g[-1]["name"] == "softmax": self._softmax = True - self._W = [None] * (L+1) - self._b = [None] + [np.zeros((n[l+1], 1)) for l in range(L)] - assert(len(self._g) == len(self._W)) - assert(len(self._g) == len(self._b)) - assert(len(self._g) == len(self._n)) self._A = None self._X = None self._Y = None @@ -141,8 +136,13 @@ class MultiLayerPerceptron(object): self._rmsprop = False self._adam = False # initialise weights + self._b = [None] + [np.zeros((n[l+1], 1)) for l in range(L)] + self._W = [None] + [np.zeros((n[l+1], n[l])) for l in range(L)] if set_random_w: self.init_random_weights(use_formula_w, w_rand_factor) + assert(len(self._g) == len(self._W)) + assert(len(self._g) == len(self._b)) + assert(len(self._g) == len(self._n)) def init_random_weights(self, use_formula=False, w_rand_factor=1): """Initialize randomly weights using a factor or using some formula @@ -229,12 +229,13 @@ class MultiLayerPerceptron(object): assert(len(X) == m) self._X = np.matrix(X).T else: - #print(X.shape, self._n[0], m) + print(X.shape, self._n[0], m) assert(X.shape == (self._n[0], m)) self._X = X self._m = m assert((self._m == m) or (self._m == 0)) self._m = m + self._prepared = False def set_all_expected_output_examples(self, Y, m=None): """Set the output examples @@ -324,6 +325,28 @@ class MultiLayerPerceptron(object): def get_bias(self): return self._b[1:] + def set_flatten_weights(self, W): + """Set weights from a flatten list""" + shapes = [w.shape for w in self._W[1:]] + sizes = [w.size for w in self._W[1:]] + flat_size = sum(sizes) + assert(len(W) == flat_size) + ini = 0 + for i, (size, shape) in enumerate(zip(sizes, shapes)): + self._W[1+i] = np.reshape(W[ini:ini+size], shape) + ini += size + + def set_flatten_bias(self, B): + """Set bias from a flatten list""" + shapes = [b.shape for b in self._b[1:]] + sizes = [b.size for b in self._b[1:]] + flat_size = sum(sizes) + assert(len(B) == flat_size) + ini = 0 + for i, (size, shape) in enumerate(zip(sizes, shapes)): + self._b[1+i] = np.reshape(B[ini:ini+size], shape) + ini += size + def back_propagation(self, get_cost_function=False): """Back propagation