update prepare and compute output for using given inputs
This commit is contained in:
parent
a3011adf31
commit
d0fc2bd3e9
40
mlp.py
40
mlp.py
|
@ -229,7 +229,7 @@ class MultiLayerPerceptron(object):
|
|||
assert(len(X) == m)
|
||||
self._X = np.matrix(X).T
|
||||
else:
|
||||
print(X.shape, self._n[0], m)
|
||||
#print(X.shape, self._n[0], m)
|
||||
assert(X.shape == (self._n[0], m))
|
||||
self._X = X
|
||||
self._m = m
|
||||
|
@ -268,14 +268,37 @@ class MultiLayerPerceptron(object):
|
|||
self.set_all_input_examples(X, m)
|
||||
self.set_all_expected_output_examples(Y, m)
|
||||
|
||||
def prepare(self):
|
||||
def prepare(self, X=None, m=None, force=False):
|
||||
"""Prepare network for propagation"""
|
||||
if self._prepared == False:
|
||||
if X is not None:
|
||||
force = True
|
||||
#print("sforce,prep =", force, self._prepared)
|
||||
if force or self._prepared == False:
|
||||
if not force:
|
||||
self._prepared = True
|
||||
if X is None:
|
||||
assert(self._X is not None)
|
||||
X = self._X
|
||||
if m is None:
|
||||
m = self._m
|
||||
if m == 0:
|
||||
if type(X) is list:
|
||||
m = 1
|
||||
else:
|
||||
m = X.shape[1]
|
||||
else:
|
||||
if m is None:
|
||||
if type(X) is list:
|
||||
m = 1
|
||||
else:
|
||||
m = X.shape[1]
|
||||
if type(X) is list:
|
||||
X = np.array(X).reshape(len(X), 1)
|
||||
if m is None:
|
||||
assert(self._m > 0)
|
||||
m = self._m
|
||||
self._A = [self._X]
|
||||
#print("m prep =", m)
|
||||
self._A = [X]
|
||||
self._A += [np.empty((self._n[l+1], m)) for l in range(self._L)]
|
||||
self._Z = [None] + [np.empty((self._n[l+1], m)) for l in range(self._L)]
|
||||
|
||||
|
@ -291,7 +314,7 @@ class MultiLayerPerceptron(object):
|
|||
self._A[l] = self._g[l]["function"](self._Z[l])
|
||||
return self._A[self._L]
|
||||
|
||||
def compute_outputs(self, X=None):
|
||||
def compute_outputs(self, X=None, m=None):
|
||||
"""Compute outputs with forward propagation.
|
||||
Note: if no input provided, then the input should have been set using
|
||||
either `set_all_input_examples()` or `set_all_training_examples()`.
|
||||
|
@ -301,11 +324,14 @@ class MultiLayerPerceptron(object):
|
|||
|
||||
"""
|
||||
if X is not None:
|
||||
if m is None:
|
||||
if type(X) is list:
|
||||
m = len(X)
|
||||
m = 1
|
||||
else:
|
||||
m = X.shape[1]
|
||||
self.set_all_input_examples(X, m)
|
||||
print("lenX,m",len(X),m)
|
||||
self.prepare(X, m)
|
||||
else:
|
||||
self.prepare()
|
||||
self.propagate()
|
||||
return self._A[self._L]
|
||||
|
|
Loading…
Reference in New Issue