add function to load weights in the network
This commit is contained in:
parent
1be5ba9137
commit
d16d56ff13
35
mlp.py
35
mlp.py
|
@ -123,11 +123,6 @@ class MultiLayerPerceptron(object):
|
|||
self._softmax = False
|
||||
if g[-1]["name"] == "softmax":
|
||||
self._softmax = True
|
||||
self._W = [None] * (L+1)
|
||||
self._b = [None] + [np.zeros((n[l+1], 1)) for l in range(L)]
|
||||
assert(len(self._g) == len(self._W))
|
||||
assert(len(self._g) == len(self._b))
|
||||
assert(len(self._g) == len(self._n))
|
||||
self._A = None
|
||||
self._X = None
|
||||
self._Y = None
|
||||
|
@ -141,8 +136,13 @@ class MultiLayerPerceptron(object):
|
|||
self._rmsprop = False
|
||||
self._adam = False
|
||||
# initialise weights
|
||||
self._b = [None] + [np.zeros((n[l+1], 1)) for l in range(L)]
|
||||
self._W = [None] + [np.zeros((n[l+1], n[l])) for l in range(L)]
|
||||
if set_random_w:
|
||||
self.init_random_weights(use_formula_w, w_rand_factor)
|
||||
assert(len(self._g) == len(self._W))
|
||||
assert(len(self._g) == len(self._b))
|
||||
assert(len(self._g) == len(self._n))
|
||||
|
||||
def init_random_weights(self, use_formula=False, w_rand_factor=1):
|
||||
"""Initialize randomly weights using a factor or using some formula
|
||||
|
@ -229,12 +229,13 @@ class MultiLayerPerceptron(object):
|
|||
assert(len(X) == m)
|
||||
self._X = np.matrix(X).T
|
||||
else:
|
||||
#print(X.shape, self._n[0], m)
|
||||
print(X.shape, self._n[0], m)
|
||||
assert(X.shape == (self._n[0], m))
|
||||
self._X = X
|
||||
self._m = m
|
||||
assert((self._m == m) or (self._m == 0))
|
||||
self._m = m
|
||||
self._prepared = False
|
||||
|
||||
def set_all_expected_output_examples(self, Y, m=None):
|
||||
"""Set the output examples
|
||||
|
@ -324,6 +325,28 @@ class MultiLayerPerceptron(object):
|
|||
def get_bias(self):
|
||||
return self._b[1:]
|
||||
|
||||
def set_flatten_weights(self, W):
|
||||
"""Set weights from a flatten list"""
|
||||
shapes = [w.shape for w in self._W[1:]]
|
||||
sizes = [w.size for w in self._W[1:]]
|
||||
flat_size = sum(sizes)
|
||||
assert(len(W) == flat_size)
|
||||
ini = 0
|
||||
for i, (size, shape) in enumerate(zip(sizes, shapes)):
|
||||
self._W[1+i] = np.reshape(W[ini:ini+size], shape)
|
||||
ini += size
|
||||
|
||||
def set_flatten_bias(self, B):
|
||||
"""Set bias from a flatten list"""
|
||||
shapes = [b.shape for b in self._b[1:]]
|
||||
sizes = [b.size for b in self._b[1:]]
|
||||
flat_size = sum(sizes)
|
||||
assert(len(B) == flat_size)
|
||||
ini = 0
|
||||
for i, (size, shape) in enumerate(zip(sizes, shapes)):
|
||||
self._b[1+i] = np.reshape(B[ini:ini+size], shape)
|
||||
ini += size
|
||||
|
||||
def back_propagation(self, get_cost_function=False):
|
||||
"""Back propagation
|
||||
|
||||
|
|
Loading…
Reference in New Issue