add function to load weights in the network

parent 1be5ba9137
commit d16d56ff13

1 changed file: mlp.py (35 lines changed: 29 additions, 6 deletions)
@@ -123,11 +123,6 @@ class MultiLayerPerceptron(object):
         self._softmax = False
         if g[-1]["name"] == "softmax":
             self._softmax = True
-        self._W = [None] * (L+1)
-        self._b = [None] + [np.zeros((n[l+1], 1)) for l in range(L)]
-        assert(len(self._g) == len(self._W))
-        assert(len(self._g) == len(self._b))
-        assert(len(self._g) == len(self._n))
         self._A = None
         self._X = None
         self._Y = None
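The context above also documents the expected shape of the activation spec g: a list of dicts keyed by "name", whose last entry decides whether the output layer is a softmax. A minimal sketch of that check, with a hypothetical spec (only the "name" key is grounded in the diff; anything else here is an assumption):

g = [{"name": "relu"}, {"name": "relu"}, {"name": "softmax"}]  # hypothetical spec
softmax = (g[-1]["name"] == "softmax")                         # True: softmax output layer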
@@ -141,8 +136,13 @@ class MultiLayerPerceptron(object):
         self._rmsprop = False
         self._adam = False
         # initialise weights
+        self._b = [None] + [np.zeros((n[l+1], 1)) for l in range(L)]
+        self._W = [None] + [np.zeros((n[l+1], n[l])) for l in range(L)]
         if set_random_w:
             self.init_random_weights(use_formula_w, w_rand_factor)
+        assert(len(self._g) == len(self._W))
+        assert(len(self._g) == len(self._b))
+        assert(len(self._g) == len(self._n))

     def init_random_weights(self, use_formula=False, w_rand_factor=1):
         """Initialize randomly weights using a factor or using some formula
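The moved initialisation also fixes the parameter shapes: for layer sizes n = [n_0, ..., n_L], layer l gets a weight matrix of shape (n[l+1], n[l]) and a bias column of shape (n[l+1], 1), with slot 0 left as None so layer indices start at 1. A self-contained sketch of the same construction, using illustrative sizes:

import numpy as np

n = [4, 8, 3]   # illustrative: 4 inputs, one hidden layer of 8, 3 outputs
L = len(n) - 1  # number of weighted layers

# Same construction as the added lines: slot 0 stays None, layers are 1-based.
b = [None] + [np.zeros((n[l+1], 1)) for l in range(L)]
W = [None] + [np.zeros((n[l+1], n[l])) for l in range(L)]

assert W[1].shape == (8, 4) and W[2].shape == (3, 8)
assert b[1].shape == (8, 1) and b[2].shape == (3, 1)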
@@ -229,12 +229,13 @@ class MultiLayerPerceptron(object):
             assert(len(X) == m)
             self._X = np.matrix(X).T
         else:
             #print(X.shape, self._n[0], m)
             print(X.shape, self._n[0], m)
             assert(X.shape == (self._n[0], m))
             self._X = X
-        self._m = m
+        assert((self._m == m) or (self._m == 0))
+        self._m = m
         self._prepared = False

     def set_all_expected_output_examples(self, Y, m=None):
         """Set the output examples
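The inserted assertion makes inconsistent batch sizes fail fast: an example count recorded earlier must match the incoming m, with self._m == 0 acting as the "nothing recorded yet" sentinel. A condensed, runnable sketch of the guard outside the class (names and values are hypothetical):

def check_example_count(recorded_m, m):
    # Mirrors the new guard: accept m if it matches the recorded count,
    # or if no count has been recorded yet (0 is the sentinel).
    assert (recorded_m == m) or (recorded_m == 0)

check_example_count(0, 100)    # passes: the first batch pins m to 100
check_example_count(100, 100)  # passes: a consistent follow-up call
# check_example_count(100, 90)  # would raise AssertionError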
@@ -324,6 +325,28 @@ class MultiLayerPerceptron(object):
     def get_bias(self):
         return self._b[1:]

+    def set_flatten_weights(self, W):
+        """Set weights from a flatten list"""
+        shapes = [w.shape for w in self._W[1:]]
+        sizes = [w.size for w in self._W[1:]]
+        flat_size = sum(sizes)
+        assert(len(W) == flat_size)
+        ini = 0
+        for i, (size, shape) in enumerate(zip(sizes, shapes)):
+            self._W[1+i] = np.reshape(W[ini:ini+size], shape)
+            ini += size
+
+    def set_flatten_bias(self, B):
+        """Set bias from a flatten list"""
+        shapes = [b.shape for b in self._b[1:]]
+        sizes = [b.size for b in self._b[1:]]
+        flat_size = sum(sizes)
+        assert(len(B) == flat_size)
+        ini = 0
+        for i, (size, shape) in enumerate(zip(sizes, shapes)):
+            self._b[1+i] = np.reshape(B[ini:ini+size], shape)
+            ini += size
+
     def back_propagation(self, get_cost_function=False):
         """Back propagation

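The two added setters are what the commit message promises: they reload parameters from one flat vector, slicing and reshaping it layer by layer. np.reshape defaults to C (row-major) order, so the flat vector must be produced by the matching C-order flatten. A self-contained round-trip sketch of the same slicing logic on stand-in arrays (the shapes are illustrative):

import numpy as np

# Stand-in per-layer weights with the shapes a [4, 8, 3] network would use.
W = [np.arange(32, dtype=float).reshape(8, 4),
     np.arange(24, dtype=float).reshape(3, 8)]

# Flatten layer by layer in C order, matching np.reshape's default.
flat = np.concatenate([w.ravel() for w in W])

# Reload exactly as set_flatten_weights does: slice, reshape, advance.
shapes = [w.shape for w in W]
sizes = [w.size for w in W]
assert len(flat) == sum(sizes)
restored, ini = [], 0
for size, shape in zip(sizes, shapes):
    restored.append(np.reshape(flat[ini:ini + size], shape))
    ini += size

assert all(np.array_equal(a, r) for a, r in zip(W, restored))  # exact round trip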