update prepare and compute output for using given inputs

This commit is contained in:
dadel 2018-01-19 22:50:52 +01:00
parent a3011adf31
commit d0fc2bd3e9
1 changed files with 41 additions and 15 deletions

56
mlp.py
View File

@ -229,7 +229,7 @@ class MultiLayerPerceptron(object):
assert(len(X) == m)
self._X = np.matrix(X).T
else:
print(X.shape, self._n[0], m)
#print(X.shape, self._n[0], m)
assert(X.shape == (self._n[0], m))
self._X = X
self._m = m
@ -268,14 +268,37 @@ class MultiLayerPerceptron(object):
self.set_all_input_examples(X, m)
self.set_all_expected_output_examples(Y, m)
def prepare(self):
def prepare(self, X=None, m=None, force=False):
"""Prepare network for propagation"""
if self._prepared == False:
self._prepared = True
assert(self._X is not None)
assert(self._m > 0)
m = self._m
self._A = [self._X]
if X is not None:
force = True
#print("sforce,prep =", force, self._prepared)
if force or self._prepared == False:
if not force:
self._prepared = True
if X is None:
assert(self._X is not None)
X = self._X
if m is None:
m = self._m
if m == 0:
if type(X) is list:
m = 1
else:
m = X.shape[1]
else:
if m is None:
if type(X) is list:
m = 1
else:
m = X.shape[1]
if type(X) is list:
X = np.array(X).reshape(len(X), 1)
if m is None:
assert(self._m > 0)
m = self._m
#print("m prep =", m)
self._A = [X]
self._A += [np.empty((self._n[l+1], m)) for l in range(self._L)]
self._Z = [None] + [np.empty((self._n[l+1], m)) for l in range(self._L)]
@ -291,7 +314,7 @@ class MultiLayerPerceptron(object):
self._A[l] = self._g[l]["function"](self._Z[l])
return self._A[self._L]
def compute_outputs(self, X=None):
def compute_outputs(self, X=None, m=None):
"""Compute outputs with forward propagation.
Note: if no input provided, then the input should have been set using
either `set_all_input_examples()` or `set_all_training_examples()`.
@ -301,12 +324,15 @@ class MultiLayerPerceptron(object):
"""
if X is not None:
if type(X) is list:
m = len(X)
else:
m = X.shape[1]
self.set_all_input_examples(X, m)
self.prepare()
if m is None:
if type(X) is list:
m = 1
else:
m = X.shape[1]
print("lenX,m",len(X),m)
self.prepare(X, m)
else:
self.prepare()
self.propagate()
return self._A[self._L]