add function to load weights in the network

This commit is contained in:
dadel 2018-01-16 22:44:26 +01:00
parent 1be5ba9137
commit d16d56ff13
1 changed files with 29 additions and 6 deletions

35
mlp.py
View File

@ -123,11 +123,6 @@ class MultiLayerPerceptron(object):
self._softmax = False
if g[-1]["name"] == "softmax":
self._softmax = True
self._W = [None] * (L+1)
self._b = [None] + [np.zeros((n[l+1], 1)) for l in range(L)]
assert(len(self._g) == len(self._W))
assert(len(self._g) == len(self._b))
assert(len(self._g) == len(self._n))
self._A = None
self._X = None
self._Y = None
@ -141,8 +136,13 @@ class MultiLayerPerceptron(object):
self._rmsprop = False
self._adam = False
# initialise weights
self._b = [None] + [np.zeros((n[l+1], 1)) for l in range(L)]
self._W = [None] + [np.zeros((n[l+1], n[l])) for l in range(L)]
if set_random_w:
self.init_random_weights(use_formula_w, w_rand_factor)
assert(len(self._g) == len(self._W))
assert(len(self._g) == len(self._b))
assert(len(self._g) == len(self._n))
def init_random_weights(self, use_formula=False, w_rand_factor=1):
"""Initialize randomly weights using a factor or using some formula
@ -229,12 +229,13 @@ class MultiLayerPerceptron(object):
assert(len(X) == m)
self._X = np.matrix(X).T
else:
#print(X.shape, self._n[0], m)
print(X.shape, self._n[0], m)
assert(X.shape == (self._n[0], m))
self._X = X
self._m = m
assert((self._m == m) or (self._m == 0))
self._m = m
self._prepared = False
def set_all_expected_output_examples(self, Y, m=None):
"""Set the output examples
@ -324,6 +325,28 @@ class MultiLayerPerceptron(object):
def get_bias(self):
return self._b[1:]
def set_flatten_weights(self, W):
"""Set weights from a flatten list"""
shapes = [w.shape for w in self._W[1:]]
sizes = [w.size for w in self._W[1:]]
flat_size = sum(sizes)
assert(len(W) == flat_size)
ini = 0
for i, (size, shape) in enumerate(zip(sizes, shapes)):
self._W[1+i] = np.reshape(W[ini:ini+size], shape)
ini += size
def set_flatten_bias(self, B):
"""Set bias from a flatten list"""
shapes = [b.shape for b in self._b[1:]]
sizes = [b.size for b in self._b[1:]]
flat_size = sum(sizes)
assert(len(B) == flat_size)
ini = 0
for i, (size, shape) in enumerate(zip(sizes, shapes)):
self._b[1+i] = np.reshape(B[ini:ini+size], shape)
ini += size
def back_propagation(self, get_cost_function=False):
"""Back propagation