From 9495c7db094125735a7eca3c325d29ce56fe163b Mon Sep 17 00:00:00 2001 From: Daouzli A Date: Wed, 3 Jan 2018 14:43:31 +0100 Subject: [PATCH] add regularization and graph --- mlp.ipynb | 122 +++++++++++++++++++++++++++++++++++------------------- mlp.py | 107 +++++++++++++++++++++++++++++++---------------- 2 files changed, 151 insertions(+), 78 deletions(-) mode change 100644 => 100755 mlp.py diff --git a/mlp.ipynb b/mlp.ipynb index 20bba3a..5714a59 100644 --- a/mlp.ipynb +++ b/mlp.ipynb @@ -2,65 +2,72 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 108, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", + "try:\n", + " import matplotlib.pyplot as mp\n", + "except:\n", + " mp = None\n", + "\n", "\n", "def sigmoid(x):\n", " return 1/(1+np.exp(-x))\n", "\n", + "\n", "def deriv_sigmoid(x):\n", " a = sigmoid(x)\n", " return a * (1 - a)\n", "\n", + "\n", "def tanh(x):\n", " ep = np.exp(x)\n", " en = np.exp(-x)\n", " return (ep - en)/(ep + en)\n", "\n", + "\n", "def deriv_tanh(x):\n", " a = tanh(x)\n", " return 1 - (a * a)\n", "\n", + "\n", + "def relu(x):\n", + " ret = 0\n", + " #fixme should map to compare\n", + " if x > 0:\n", + " ret = x\n", + " elif type(x) is np.ndarray:\n", + " ret = np.zeros(x.shape)\n", + " return ret\n", + "\n", + "def deriv_relu(x):\n", + " ret = 0\n", + " if z < 0:\n", + " ret = 0.01\n", + " else:\n", + " ret = 1\n", + "\n", + "def leaky_relu(x):\n", + " ret = 0.01 * x\n", + " #fixme should map to compare\n", + " if x > 0:\n", + " ret = x\n", + " elif type(x) is np.ndarray:\n", + " ret = np.ones(x.shape)*0.01\n", + " return ret\n", + "\n", + "\n", "class MultiLayerPerceptron(object):\n", "\n", - " @staticmethod\n", - " def relu(x):\n", - " ret = 0\n", - " #fixme should map to compare\n", - " if x > 0:\n", - " ret = x\n", - " elif type(x) is np.ndarray:\n", - " ret = np.zeros(x.shape)\n", - " return ret\n", - "\n", - " @staticmethod\n", - " def deriv_relu(x):\n", - " ret = 0\n", - " if z < 0:\n", - " ret = 0.01\n", - " else:\n", - " ret = 1\n", - "\n", - " @staticmethod\n", - " def leaky_relu(x):\n", - " ret = 0.01 * x\n", - " #fixme should map to compare\n", - " if x > 0:\n", - " ret = x\n", - " elif type(x) is np.ndarray:\n", - " ret = np.ones(x.shape)*0.01\n", - " return ret\n", - "\n", " functions = {\n", " \"sigmoid\": {\"function\": sigmoid, \"derivative\": deriv_sigmoid},\n", " \"tanh\": {\"function\": tanh, \"derivative\": deriv_tanh},\n", " \"relu\": {\"function\": relu, \"derivative\": deriv_relu},\n", " }\n", "\n", - " def __init__(self, L=1, n=None, g=None, alpha=0.01):\n", + " def __init__(self, L=1, n=None, g=None, alpha=0.01, lambd=0):\n", " \"\"\"Initializes network geometry and parameters\n", " :param L: number of layers including output and excluding input. Defaut 1.\n", " :type L: int\n", @@ -92,6 +99,7 @@ " self._Z = None\n", " self._m = 0\n", " self._alpha = alpha\n", + " self._lambda = lambd\n", "\n", " def set_all_input_examples(self, X, m=1):\n", " \"\"\"Set the input examples.\n", @@ -161,6 +169,9 @@ " def get_output(self):\n", " return self._A[self._L]\n", "\n", + " def get_weights(self):\n", + " return self._W[1:]\n", + "\n", " def back_propagation(self, get_cost_function=False):\n", " \"\"\"Back propagation\n", "\n", @@ -177,8 +188,12 @@ " dA = [None] + [None] * self._L\n", " dA[l] = -self._Y/self._A[l] + ((1-self._Y)/(1-self._A[l]))\n", " if get_cost_function:\n", + " wnorms = 0\n", + " for w in self._W[1:]:\n", + " wnorms += np.linalg.norm(w)\n", " J = -1/m * ( np.dot(self._Y, np.log(self._A[l]).T) + \\\n", - " np.dot((1 - self._Y), np.log(1-self._A[l]).T) )\n", + " np.dot((1 - self._Y), np.log(1-self._A[l]).T) ) + \\\n", + " self._lambda/(2*m) * wnorms # regularization\n", "\n", " #dZ = self._A[l] - self._Y\n", " for l in range(self._L, 0, -1):\n", @@ -190,12 +205,13 @@ "# dW[l] = 1/m * np.dot(dZ, self._A[l-1].T)\n", "# db[l] = 1/m * np.sum(dZ, axis=1, keepdims=True)\n", " for l in range(self._L, 0, -1):\n", - " self._W[l] = self._W[l] - self._alpha * dW[l]\n", + " self._W[l] = self._W[l] - self._alpha * dW[l] - \\\n", + " (self._alpha*self._lambda/m * self._W[l]) # regularization\n", " self._b[l] = self._b[l] - self._alpha * db[l]\n", "\n", " return J\n", "\n", - " def minimize_cost(self, min_cost, max_iter=100000, alpha=None):\n", + " def minimize_cost(self, min_cost, max_iter=100000, alpha=None, plot=False):\n", " \"\"\"Propagate forward then backward in loop while minimizing the cost function.\n", "\n", " :param min_cost: cost function value to reach in order to stop algo.\n", @@ -207,15 +223,24 @@ " if alpha is None:\n", " alpha = self._alpha\n", " self.propagate()\n", + " if plot:\n", + " y=[]\n", + " x=[]\n", " for i in range(max_iter):\n", " J = self.back_propagation(True)\n", + " if plot:\n", + " y.append(J[0][0])\n", + " x.append(nb_iter)\n", " self.propagate()\n", " nb_iter = i + 1\n", " if J <= min_cost:\n", " break\n", + " if mp and plot:\n", + " mp.plot(x,y)\n", + " mp.show()\n", " return {\"iterations\": nb_iter, \"cost_function\": J}\n", "\n", - " def learning(self, X, Y, m, min_cost=0.05, max_iter=100000, alpha=None):\n", + " def learning(self, X, Y, m, min_cost=0.05, max_iter=100000, alpha=None, plot=False):\n", " \"\"\"Tune parameters in order to learn examples by propagate and backpropagate.\n", "\n", " :param X: the inputs training examples\n", @@ -228,27 +253,39 @@ " \"\"\"\n", " self.set_all_training_examples(X, Y, m)\n", " self.prepare()\n", - " res = self.minimize_cost(min_cost, max_iter, alpha)\n", - " return res\n", - " " + " res = self.minimize_cost(min_cost, max_iter, alpha, plot)\n", + " return res\n" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 109, "metadata": {}, "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXgAAAD8CAYAAAB9y7/cAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAAIABJREFUeJzt3Xl4lfWd9/H3N/u+EAKBEPYdRaQR17pb0VZs6zwt2s2pj3ZmpNpx2o480+nj2OnYcWbaOjO0U2vttLXKKFWLS6V1X6pIQERJCEbWAIEACQlkT77zxznQGLMc4CRnyed1XedK7vv8cs4nXMePd+7tZ+6OiIjEn4RIBxARkcGhghcRiVMqeBGROKWCFxGJUyp4EZE4pYIXEYlTKngRkTgVUsGb2UIzqzSzKjO7vZfnJ5jZc2a2wcxeNLNx4Y8qIiLHwwa60MnMEoHNwGVANbAGuNbdy7uNeQR40t1/YWYXA3/u7l8YvNgiIjKQpBDGLACq3H0LgJktB64GyruNmQ3cFvz+BeDxgV505MiRPnHixOMKKyIy3K1du3a/uxeGMjaUgi8GdnZbrgbO7DHmbeDTwD3Ap4BsMytw9wN9vejEiRMpKysLJaOIiASZ2fZQx4brIOvXgQvM7C3gAmAX0NlLsJvMrMzMympra8P01iIi0ptQCn4XUNJteVxw3THuvtvdP+3upwN/F1xX3/OF3P1edy9199LCwpD+whARkRMUSsGvAaaZ2SQzSwEWAyu7DzCzkWZ29LWWAveHN6aIiByvAQve3TuAJcAqoAJ42N03mtmdZrYoOOxCoNLMNgOjge8OUl4REQnRgKdJDpbS0lLXQVYRkeNjZmvdvTSUsbqSVUQkTqngRUTiVMwV/JptB7n7mU10dWmqQRGR/sRcwb+9s54fvfg+ja0dkY4iIhLVYq7gc9KTATjU1B7hJCIi0S3mCj7vaME3q+BFRPoTcwWfq4IXEQlJzBV8XkYKAPXNbRFOIiIS3WKu4LUFLyISmpgt+HodZBUR6VfMFXxacgIpSQk0aAteRKRfMVfwZkZuerK24EVEBhBzBQ+BUyW1D15EpH8xWfC5KngRkQHFZMHnZSRTr4IXEelXSAVvZgvNrNLMqszs9l6eH29mL5jZW2a2wcyuDH/UP8lJT9ZBVhGRAQxY8GaWCCwDrgBmA9ea2ewew75FYKan0wlM6fejcAftLnCQVRc6iYj0J5Qt+AVAlbtvcfc2YDlwdY8xDuQEv88Fdocv4oflpadwpK2T9s6uwXwbEZGYFkrBFwM7uy1XB9d1dwfweTOrBp4GvhqWdH3ITU8CdDWriEh/wnWQ9Vrgv919HHAl8Csz+9Brm9lNZlZmZmW1tbUn/GZH70ejghcR6VsoBb8LKOm2PC64rrsbgIcB3P11IA0Y2fOF3P1edy9199LCwsITS4zuRyMiEopQCn4NMM3MJplZCoGDqCt7jNkBXAJgZrMIFPyJb6IPIDdDk36IiAxkwIJ39w5gCbAKqCBwtsxGM7vTzBYFh/0NcKOZvQ08BFzv7oM2aaq24EVEBpYUyiB3f5rAwdPu677d7fty4NzwRutb3rE7SupUSRGRvsTklazH5mVt1sTbIiJ9icmCT05MIDMlUbM6iYj0IyYLHgKnSmofvIhI32K24HU/GhGR/sVswedp0g8RkX7FbMHrnvAiIv2L2YLXPeFFRPoXswWvLXgRkf7FbMHnpCfT1tFFS3tnpKOIiESlmC34vIyjV7NqK15EpDcxW/C6H42ISP9ituDz0gP3hNf9aEREehezBa8teBGR/sVswR/bB6+CFxHpVcwW/NE7Sup2BSIivQup4M1soZlVmlmVmd3ey/M/MLP1wcdmM6sPf9QPyk5Nwky7aERE+jLghB9mlggsAy4DqoE1ZrYyOMkHAO7+193GfxU4fRCyfkBCgpGr+9GIiPQplC34BUCVu29x9zZgOXB1P+OvJTBt36DT1awiIn0LpeCLgZ3dlquD6z7EzCYAk4DnTz7awPLSdT8aEZG+hPsg62Jghbv3ev8AM7vJzMrMrKy2tvak3yxHW/AiIn0KpeB3ASXdlscF1/VmMf3snnH3e9291N1LCwsLQ0/Zh7yMFA7pQicRkV6FUvBrgGlmNsnMUgiU+Mqeg8xsJpAPvB7eiH3LTU/SFryISB8GLHh37wCWAKuACuBhd99oZnea2aJuQxcDy93dByfqh+WlB+Zl7eoasrcUEYkZA54mCeDuTwNP91j37R7Ld4QvVmhy05Ppcjjc1kFOWvJQv72ISFSL2StZodv9aHQuvIjIh8R2wWfohmMiIn2J7YLXHSVFRPoU0wWvWZ1ERPoW0wWvLXgRkb7FdMEfndVJBS8i8mExXfBpyQmkJCZQ36yrWUVEeorpgjczctKTNemHiEgvYrrgIXCgVQdZRUQ+LOYLXveEFxHpXcwXfJ5mdRIR6VXMF7y24EVEehf7BZ+hg6wiIr2J/YJPT6axtYOOzq5IRxERiSpxUfAADS0dEU4iIhJdYr7g/3Q/Gl3sJCLSXUgFb2YLzazSzKrM7PY+xnzGzMrNbKOZPRjemH3T/WhERHo34IxOZpYILAMuA6qBNWa20t3Lu42ZBiwFznX3OjMbNViBe8rV/WhERHoVyhb8AqDK3be4exuwHLi6x5gbgWXuXgfg7vvCG7Nv2oIXEeldKAVfDOzstlwdXNfddGC6mb1mZm+Y2cLeXsjMbjKzMjMrq62tPbHEPeRpVicRkV6F6yBrEjANuBC4FvipmeX1HOTu97p7qbuXFhYWhuWNj27B62pWEZEPCqXgdwEl3ZbHBdd1Vw2sdPd2d98KbCZQ+IMuOTGBzJREahtbh+LtRERiRigFvwaYZmaTzCwFWAys7DHmcQJb75jZSAK7bLaEMWe/zp06kt+sq2ZXffNQvaWISNQbsODdvQNYAqwCKoCH3X2jmd1pZouCw1YBB8ysHHgB+Ia7Hxis0D19+6rZuMPfP/4u7j5UbysiEtUsUoVYWlrqZWVlYXu9+17Zwj8+VcGy6+bz8bljwva6IiLRxMzWuntpKGNj/krWo64/ZyKnFOdwxxMbdUaNiAhxVPBJiQl879NzOXC4lX9+ZlOk44iIRFzcFDzAKcW5fPncSTy4egdrth2MdBwRkYiKq4IH+OvLplOcl87SR9+hoUW7akRk+Iq7gs9MTeKuT5/Ktv1HuOZHf2TnwaZIRxIRiYi4K3iA86cX8ssvL2BvQwufXPYaa7fXRTqSiMiQi8uCBzhn6kgeu/lcstKSuPanb/Db9T0vvhURiW9xW/AAUwqzePyvzmVeSR63Ll/P3c9s0tR+IjJsxHXBA+RnpvCrGxaw+IwSfvTi+3z23jeortN+eRGJf3Ff8ACpSYl875q53LN4HpU1jVxxzys8tWFPpGOJiAyqYVHwR109r5inb/kokwuzuPnBdSx9dANHWjVZt4jEp2FV8ADjCzJY8Rdn85cXTmH5mp187Acv8/Lm8Ew+IiISTYZdwUPgHvJ/u3Amj3zlbNKSE/ji/W/y9Ufepr6pLdLRRETCZlgW/FGlE0fw1C0fZclFU3nsrV1c+v2XeXLDbt1yWETiwrAueIC05ES+fvkMVi45l6LcVJY8+BbX/XQ1lTWNkY4mInJSQip4M1toZpVmVmVmt/fy/PVmVmtm64OP/xv+qINrzthcHv+rc/nOJ0+hoqaBK//9Fe5YuZFDmutVRGLUgBN+mFkigTlWLyMw9+oa4Fp3L+825nqg1N2XhPrG4Z7wI5zqjrTxb3+o5MHVO8jLSOGWi6dy3ZkTSEka9n/wiEiEhXvCjwVAlbtvcfc2YDlw9ckEjHb5mSn84ydP5Ymvnsf00Vnc8UQ5l3z/RR5/axddXdo/LyKxIZSCLwZ2dluuDq7r6Roz22BmK8ysJCzpImzO2FweuvEs/vvPzyA7NZmv/c96rvz3V3iuYq8OxIpI1AvXPocngInuPhf4A/CL3gaZ2U1mVmZmZbW1sXHuuZlx4YxRPPnV87hn8Tya2jq54RdlXPWfr/LMuzXaoheRqBXKPvizgTvc/fLg8lIAd7+rj/GJwEF3z+3vdaN5H3x/2ju7eGzdLpa9WMX2A03MGJ3NkouncuWpY0hMsEjHE5E4F+598GuAaWY2ycxSgMXAyh5vOKbb4iKgItSwsSY5MYHPnFHCc7ddwA8+exodXV189aG3uPBfX+Dnr23VrQ9EJGoMuAUPYGZXAj8EEoH73f27ZnYnUObuK83sLgLF3gEcBP7S3fud+TpWt+B76uxyfr+xhvte3cra7XVkpyVx3Znjuf6ciYzJTY90PBGJM8ezBR9SwQ+GeCn47tbtqONnr2zld+/uwcy4dNYoPn/WBM6dMpIE7b4RkTA4noJPGuwww8n88fnM/1w+Ow828cDq7TxSVs2qjXuZWJDBdWeO55r54yjISo10TBEZJrQFP4haOzp55t0aHnhjO2u21ZGUYFw8cxR/9pFxXDRzFMmJunBKRI6PtuCjRGpSIlfPK+bqecVs3tvIirXVPLpuF78v30tBZgpXzytm0byxnDYuFzPtwhGR8NIW/BDr6Ozipc21rFhbzXMV+2jr7GL8iAyuOm0MV502lplFOZGOKCJRTAdZY8Sh5nZWbazhibd381rVfrocJhdmsnBOEZfPKWKutuxFpAcVfAzaf7iV372zh2c21vDGloN0djljc9O4bPZoLp41mjMnjSAtOTHSMUUkwlTwMa7uSBvPbdrHqo01vLy5ltaOLtKTEzl36kgunjmKC2YUUpync+xFhiMVfBxpae/k9fcP8PymfTy/aR+76psBmDwyk/OmjeS8qSM5e0oB2WnJEU4qIkNBBR+n3J339h3m5c21vFq1n9VbDtLc3klignHK2BzOmlzAWZMLKJ2Yr8IXiVMq+GGitaOTddvrea1qP6u3HmD9znraO50EC9zq+CMT8imdmE/phBEU5aZFOq6IhIEKfphqbuvkrR11vLHlAGu21bF+Zz3N7Z0AFOelc/r4POaV5HH6+DzmjM3VQVuRGKQLnYap9JREzpk6knOmjgQCtzYu391A2fY61m2v460d9Ty5YQ8ASQnGzDHZnFqcy6nFeZxanMuMomxNSygSR7QFP8zsa2hh/c561u+sZ0P1ITZU19PQErjFcXKiMX10NrPH5DBnbA6zx+Yya0y29ueLRBHtopGQuTs7Dzbzzq5DbNhVT/nuBsp3N3DgSNuxMSUj0plZlMOsomxmjclhRlE2EwoyNcGJSARoF42EzMwYX5DB+IIMPj43MG+Lu7OvsZWNuw9RvruBTTWNVOxp4LmKvRydoTA1KYFpo7OYMTqHGUVZzCjKYcbobEbnpOrqW5EoEVLBm9lC4B4CE37c5+7f62PcNcAK4Ax31+Z5jDIzRuekMTonjYtnjj62vqW9k/f2HmZTTQOVNY1U7m3klfdq+c266mNjctKSmFGUHXwEtvqnF2WTo908IkNuwIIPzrG6DLgMqAbWmNlKdy/vMS4buBVYPRhBJfLSkhM5dVwup4774HS7dUfa2Lw3UPiVNYHHb9fvprFlx7ExxXnpzB4b2Lc/Z2wuc8bmMCY3TVv7IoMolC34BUCVu28BMLPlwNVAeY9x3wH+GfhGWBNK1MvPTOHMyQWcObng2Dp3Z/ehFiprGqjYE9jFU76ngWcr9nL0sM/IrFTmj89j/oR8Ti/JY+64PNJTdOqmSLiEUvDFwM5uy9XAmd0HmNl8oMTdnzIzFbxgZhTnpVOcl/6B3TxHWjvYVNPAxt0NrN9Rz7oddfy+fC8QOIvnjIkjuHBGIRfNGMXUUVnawhc5CSd9kNXMEoDvA9eHMPYm4CaA8ePHn+xbSwzKTE3iIxNG8JEJI/ji2YF1B4+08daOOt7cepAXK2v5p6c38U9Pb6I4L50rTy3iC2dNZHxBRmSDi8SgAU+TNLOzgTvc/fLg8lIAd78ruJwLvA8cDv5IEXAQWNTfgVadJil92V3fzIuVtTy/aR8vVu6j052LZozii2dP4PxphZrAXIa1sJ4Hb2ZJwGbgEmAXsAa4zt039jH+ReDrA51Fo4KXUNQcauHBN3fw4Ood7D/cyuSRmfzjp07hnCkjIx1NJCKOp+AHvC7d3TuAJcAqoAJ42N03mtmdZrbo5KKK9K8oN43bLpvOH2+/mHsWzwPgc/et5q6nK2jr6IpwOpHopitZJaY0tXXw3acq+PXqHcwZm8M9i+cxdVR2pGOJDJmwbsGLRJOMlCS++6lT+ekXS9lzqIVP/MerrFhbPfAPigxDKniJSZfNHs0zX/so88fn880Vb/Pqe/sjHUkk6qjgJWaNyk7jp18sZeqoLJY8tI6dB5siHUkkqqjgJaZlpibxky+U0tnlfOVXa2lu64x0JJGooYKXmDdpZCb3LJ5HRU0DSx/dQKROHBCJNip4iQsXzxzNbZdO5/H1u/n5a9siHUckKqjgJW7cfNFUPjZ7NN99uoK3dtRFOo5IxKngJW4kJBj/9pnTyM9I4a7fbdKuGhn2VPASV7LTkrnlkqm8ufUgL22ujXQckYhSwUvcWXzGeEpGpHP3M5V0dWkrXoYvFbzEnZSkBP7mshmU72ngyXf2RDqOSMSo4CUuLTptLDOLsvn+7ytp79RNyWR4UsFLXEpIML5x+Qy2HWji4bKdA/+ASBxSwUvcunjmKEon5HPPs+/pClcZllTwErfMjG8unMm+xlZ+8fq2SMcRGXIhFbyZLTSzSjOrMrPbe3n+L8zsHTNbb2avmtns8EcVOX4LJo3gohmF/PjF92lsaY90HJEhNWDBm1kisAy4ApgNXNtLgT/o7qe6+zzgbgKTcItEha9dOp1Dze388vXtkY4iMqRC2YJfAFS5+xZ3bwOWA1d3H+DuDd0WMwGdfCxR47SSPC6cUch9r2zhSGtHpOOIDJlQCr4Y6H4aQnVw3QeY2c1m9j6BLfhbwhNPJDxuuWQadU3tPPCGtuJl+AjbQVZ3X+buU4C/Bb7V2xgzu8nMysysrLZWl5HL0Jk/Pp+PThvJvS9v0Rk1MmyEUvC7gJJuy+OC6/qyHPhkb0+4+73uXurupYWFhaGnFAmDWy+ZxoEjbfx6tbbiZXgIpeDXANPMbJKZpQCLgZXdB5jZtG6LHwfeC19EkfAonTiCc6YU8F8vbaGlXVvxEv8GLHh37wCWAKuACuBhd99oZnea2aLgsCVmttHM1gO3AV8atMQiJ+GWS6ax/3ArD725I9JRRAadReqe2aWlpV5WVhaR95bh7bM/eZ1tB47w0jcuIi05MdJxRI6Lma1199JQxupKVhl2br1kGnsbWvn1am3FS3xTwcuwc/aUAs6bOpJ/f+496pvaIh1HZNCo4GXYMTO+9YlZNLa088NndT6AxC8VvAxLM4ty+OwZ43ngje28X3s40nFEBoUKXoat2y6bTlpyInc9XRHpKCKDQgUvw1Zhdio3XzSVZyv28VrV/kjHEQk7FbwMa39+7kTG5afznSfL6dQE3RJnVPAyrKUlJ7L0illsqmnU1H4Sd1TwMuxdeWoRZ0zM519XVXLgcGuk44iEjQpehj0z4zufPIXGlg6+9fi7ROrqbpFwU8GLEDht8raPTed379bw2/W7Ix1HJCxU8CJBN350MqUT8vn7377LnkPNkY4jctJU8CJBiQnGv33mNDq7nG+u2KBdNRLzVPAi3UwoyOT/XTmLV97bzwO6GZnEOBW8SA+fO3M8508v5J+eqmDr/iORjiNywlTwIj2YGXdfM5eUpAT+4ldrOdzaEelIIickpII3s4VmVmlmVWZ2ey/P32Zm5Wa2wcyeM7MJ4Y8qMnSKctNYdt18qmoPc+tDb+kqV4lJAxa8mSUCy4ArgNnAtWY2u8ewt4BSd58LrADuDndQkaF23rSR3HHVbJ7btI+7V22KdByR4xbKFvwCoMrdt7h7G7AcuLr7AHd/wd2bgotvAOPCG1MkMr5w9kQ+f9Z4fvLSFn6ztjrScUSOSygFXwx0v0lHdXBdX24AftfbE2Z2k5mVmVlZbW1t6ClFIuj/XzWHc6YUsPTRd1i7/WCk44iELKwHWc3s80Ap8C+9Pe/u97p7qbuXFhYWhvOtRQZNcmICP/rcfMbkpXHTL9dSta8x0pFEQhJKwe8CSrotjwuu+wAzuxT4O2CRu+uOTRJX8jJSuP/6MzAzFt+7WiUvMSGUgl8DTDOzSWaWAiwGVnYfYGanAz8hUO77wh9TJPKmFGax/KazAFTyEhMGLHh37wCWAKuACuBhd99oZnea2aLgsH8BsoBHzGy9ma3s4+VEYtrUUSp5iR0WqfttlJaWellZWUTeW+RkVe07zOJ73wDgoRvPZNro7AgnkuHCzNa6e2koY3Ulq8gJOLolbwaf/vEfeeU9nRUm0UcFL3KCpo7K4rG/OofivHSu//kafvn6tkhHEvkAFbzISRiXn8GKvzyHi2YU8u3fbuRbj79De2dXpGOJACp4kZOWlZrET75QylcumMwDb+zgS/e/SW2jzhSWyFPBi4RBYoKx9IpZ/MufzWXt9joW/vBl/lC+N9KxZJhTwYuE0f8pLeGJr57H6Jw0bvxlGUsffYemNt1uWCJDBS8SZtNHZ/PYzefwlQsms3zNDq685xXdw0YiQgUvMghSkxJZesUsHrrxLNo7nWt+/DrfeORt9h/WvnkZOip4kUF01uQCVv31+Xzlgsk8vn4XF/3ri/z8ta106EwbGQIqeJFBlpWaxNIrZvG7W89nXkke//BEOZ/4j1d5rmIvkbqSXIYHFbzIEJk6KotffnkB//X5j9Dc3skNvyjjk8te48XKfSp6GRQqeJEhZGYsPKWIZ2+7gLuvmcv+w21c//M1XPPjP6roJex0szGRCGrr6OKRtTv5z+er2HOohamjsvjyuZP41OnFpKckRjqeRKHjudmYCl4kCrR1dPHkht387NWtbNzdQH5GMtcuGM+1C8ZTMiIj0vEkiqjgRWKUu/Pm1oPc/9pWfl++F3c4Z0oBnyktYeEpRaQla6t+uAt7wZvZQuAeIBG4z92/1+P584EfAnOBxe6+YqDXVMGL9G9XfTO/WVvNirXV7DjYRHZqEh+fO4aPzx3D2ZMLSErUIbThKKwFb2aJwGbgMqCawBR+17p7ebcxE4Ec4OvAShW8SPh0dTmrtx7kkbU7WfVuDUfaOhmRmcLlc0bz8VPHcubkESSr7IeN4yn4pBDGLACq3H1L8MWXA1cDxwre3bcFn9PVGyJhlpBgnD2lgLOnFNDyqU5e2lzLUxv2sHL9bh56cyfZaUmcP62Qi2eO4sIZhRRkpUY6skSJUAq+GNjZbbkaOHNw4ohIf9KSE7l8ThGXzymipb2TlzfX8vymfTy/aR9PvbMHM5g7Lo/zphZw7pSRzJ+Qr/32w1goBR82ZnYTcBPA+PHjh/KtReJOWnIiH5tTxMfmFNHV5ZTvaeDZir288t5+/uulLSx74X1SkxIonZjPWZMKKJ04gnkleTr9chgJpeB3ASXdlscF1x03d78XuBcC++BP5DVE5MMSEoxTinM5pTiXr106ncaWdt7cepDXqg7wx/f38/1nN+MOSQnGnOJcSifkc1pJHvPG5VEyIh0zi/SvIIMglIJfA0wzs0kEin0xcN2gphKRk5Kdlswls0ZzyazRABxqamfdjjrWbDtI2bY6HnhjOz97dSsA+RnJzB2Xx6nFucwZm8PssTmMH5Gh0o8DAxa8u3eY2RJgFYHTJO93941mdidQ5u4rzewM4DEgH7jKzP7B3ecManIRCVluRjIXzRzFRTNHAdDe2UVlTSNvV9fz9s56NlQf4tWq/XR2Bf6wzk5NYtaYHKYXZTFjdDbTR2czoyibvIyUSP4acpx0oZOIANDS3snmvY1s3N1A+e4Gyvc0sHlvI40tf5qRamRWKlMKM5kyKosphVlMLsxkUkEm4/LTdV7+EAn3aZIiMgykJScyd1wec8flHVvn7tQ0tFBZ00hlTSPv1x7m/dojPLVhD4ea24+NS0owSkZkMKEgg4nBwh8/IoOS4CMrVVUTCfpXF5E+mRljctMZk5vOhTNGHVvv7hw80sb7tUfYduAI2/YHvm7d30TZtjoOt35wHtq8jGTG5qZTnJ9OcV7gMSYvjTG5aYzJTWdUdqr+AhgEKngROW5mRkFWKgVZqSyYNOIDz7k79U3t7DjYxM66JnYebGZXfRO76prZfuAIf6zaz5G2zg/8TIIFdv+MzkljdE7g66jsNAqzUxmVnUph8FGQlUJqkk7zDJUKXkTCyszIz0whPzOF00ryPvS8u9PQ0kHNoRb2HGpmz6EW9hxqYV9DCzUNLeyqb2HdjnoOHmnr9fVz0pIYmZXKyKxA4RdkpTAiM5WCzBRGBB/5GYGveRnJw/pCLxW8iAwpMyM3PZnc9GRmFGX3Oa6to4v9h1upbQw89jW2sv9wKwcOt7L/cBu1h1vZvLeRg1vaqG9up6/zRdKTE8nPSCYvI1D4+Rkp5ATfPy8j+ViWnLTg1/QkctKSyU5LivndRip4EYlKKUkJjM1LZ2xe+oBjOzq7qG9u58DhNuqa2qg70kZdUzt1TW0cPNJGfVM7h5oD6ypqGmhobudQczvtnf2fRZiRkkh2WhLZacnkpCWRFSz+7NQkslKTyEoLfM1OSyIzuC4zNYnMlKPfJ5KZmkRqUkJEritQwYtIzEtKTDi22yZU7k5zeyeHmtupb2qnsaWDQ83tx8q/saWDxpbA14aWPz1fXdfE4ZYOGls6aG7vHPiNgMQEIyMlkcyUJDJSE/napdNZdNrYE/11Q6aCF5FhyczISEkiIyWJMbkD/5XQm47OLo60dXK4tYMjrYHSb2oLfH+4tTP4tYPmtk6OtHXQ1Br4mp+RHObfpncqeBGRE5SUmEBuegK56UNT2Mcrto8giIhIn1TwIiJxSgUvIhKnVPAiInFKBS8iEqdU8CIicUoFLyISp1TwIiJxKmIzOplZLbD9BH98JLA/jHGGSizmjsXMEJu5lXnoxGLuo5knuHthKD8QsYI/GWZWFuqUVdEkFnPHYmaIzdzKPHRiMfeJZNYuGhGROKWCFxGJU7Fa8PdGOsAJisXcsZgZYjO3Mg+dWMx93Jljch+8iIgMLFa34EVEZAAxV/BmttDKPzwsAAAD2klEQVTMKs2sysxuj3SevpjZ/Wa2z8ze7bZuhJn9wczeC37Nj2TGnsysxMxeMLNyM9toZrcG10dtbjNLM7M3zeztYOZ/CK6fZGarg5+T/zGzlEhn7cnMEs3sLTN7MrgcC5m3mdk7ZrbezMqC66L28wFgZnlmtsLMNplZhZmdHQOZZwT/jY8+Gszsa8ebO6YK3swSgWXAFcBs4Fozmx3ZVH36b2Bhj3W3A8+5+zTgueByNOkA/sbdZwNnATcH/32jOXcrcLG7nwbMAxaa2VnAPwM/cPepQB1wQwQz9uVWoKLbcixkBrjI3ed1O2Uvmj8fAPcAz7j7TOA0Av/mUZ3Z3SuD/8bzgI8ATcBjHG9ud4+ZB3A2sKrb8lJgaaRz9ZN3IvBut+VKYEzw+zFAZaQzDpD/t8BlsZIbyADWAWcSuCAkqbfPTTQ8gHHB/0AvBp4ELNozB3NtA0b2WBe1nw8gF9hK8HhjLGTu5Xf4GPDaieSOqS14oBjY2W25OrguVox29z3B72uA0ZEM0x8zmwicDqwmynMHd3WsB/YBfwDeB+rdvSM4JBo/Jz8Evgl0BZcLiP7MAA783szWmtlNwXXR/PmYBNQCPw/uDrvPzDKJ7sw9LQYeCn5/XLljreDjhgf+FxyVpzCZWRbwG+Br7t7Q/blozO3unR74U3YcsACYGeFI/TKzTwD73H1tpLOcgPPcfT6B3aQ3m9n53Z+Mws9HEjAf+LG7nw4cocdujSjMfEzwOMwi4JGez4WSO9YKfhdQ0m15XHBdrNhrZmMAgl/3RTjPh5hZMoFy/7W7PxpcHfW5Ady9HniBwO6NPDM7Oql8tH1OzgUWmdk2YDmB3TT3EN2ZAXD3XcGv+wjsE15AdH8+qoFqd18dXF5BoPCjOXN3VwDr3H1vcPm4csdawa8BpgXPNkgh8KfLyghnOh4rgS8Fv/8SgX3cUcPMDPgZUOHu3+/2VNTmNrNCM8sLfp9O4JhBBYGi/7PgsKjK7O5L3X2cu08k8Bl+3t0/RxRnBjCzTDPLPvo9gX3D7xLFnw93rwF2mtmM4KpLgHKiOHMP1/Kn3TNwvLkjfQDhBA44XAlsJrCf9e8inaefnA8Be4B2AlsRNxDYz/oc8B7wLDAi0jl7ZD6PwJ98G4D1wceV0ZwbmAu8Fcz8LvDt4PrJwJtAFYE/b1MjnbWP/BcCT8ZC5mC+t4OPjUf/+4vmz0cw3zygLPgZeRzIj/bMwdyZwAEgt9u648qtK1lFROJUrO2iERGREKngRUTilApeRCROqeBFROKUCl5EJE6p4EVE4pQKXkQkTqngRUTi1P8C/28y0Jq1QWsAAAAASUVORK5CYII=\n", + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "name": "stdout", "output_type": "stream", "text": [ - "{'iterations': 62, 'cost_function': array([[ 0.04879932]])}\n", - "[[ 0.03621064 0.94089041 0.94055051 0.03022811]]\n" + "{'iterations': 68, 'cost_function': array([[ 0.04958664]])}\n", + "[[ 0.01220273 0.95827161 0.95016051 0.05815676]]\n", + "[array([[-3.29388938, 2.64675614],\n", + " [ 1.70522303, -2.64034303],\n", + " [-2.2292173 , -1.73490183]]), array([[ 3.5237224 , 3.48321649, -2.70213352]])]\n" ] } ], "source": [ - "mlp = MultiLayerPerceptron(L=2, n=[2, 2, 1], g=[\"tanh\", \"sigmoid\"], alpha=2)\n", + "mlp = MultiLayerPerceptron(L=2, n=[2, 3, 1], g=[\"tanh\", \"sigmoid\"], alpha=2, lambd=0.005)\n", "#mlp = MultiLayerPerceptron(L=1, n=[2, 1], g=[\"sigmoid\"], alpha=0.1)\n", "\n", "X = np.array([[0, 0],\n", @@ -261,9 +298,10 @@ " [1],\n", " [0]])\n", "\n", - "res = mlp.learning(X.T, Y.T, 4)\n", + "res = mlp.learning(X.T, Y.T, 4, max_iter=5000, plot=True)\n", "print(res)\n", "print(mlp.get_output())\n", + "print(mlp.get_weights())\n", "#mlp.set_all_training_examples(X.T, Y.T, 4)\n", "#mlp.prepare()\n", "#print(mlp.propagate())\n", diff --git a/mlp.py b/mlp.py old mode 100644 new mode 100755 index d79a5f5..978f1e2 --- a/mlp.py +++ b/mlp.py @@ -1,58 +1,67 @@ +#!/usr/bin/env python3 + import numpy as np +try: + import matplotlib.pyplot as mp +except: + mp = None + def sigmoid(x): return 1/(1+np.exp(-x)) + def deriv_sigmoid(x): a = sigmoid(x) return a * (1 - a) + def tanh(x): ep = np.exp(x) en = np.exp(-x) return (ep - en)/(ep + en) + def deriv_tanh(x): a = tanh(x) return 1 - (a * a) + +def relu(x): + ret = 0 + #fixme should map to compare + if x > 0: + ret = x + elif type(x) is np.ndarray: + ret = np.zeros(x.shape) + return ret + +def deriv_relu(x): + ret = 0 + if z < 0: + ret = 0.01 + else: + ret = 1 + +def leaky_relu(x): + ret = 0.01 * x + #fixme should map to compare + if x > 0: + ret = x + elif type(x) is np.ndarray: + ret = np.ones(x.shape)*0.01 + return ret + + class MultiLayerPerceptron(object): - @staticmethod - def relu(x): - ret = 0 - #fixme should map to compare - if x > 0: - ret = x - elif type(x) is np.ndarray: - ret = np.zeros(x.shape) - return ret - - @staticmethod - def deriv_relu(x): - ret = 0 - if z < 0: - ret = 0.01 - else: - ret = 1 - - @staticmethod - def leaky_relu(x): - ret = 0.01 * x - #fixme should map to compare - if x > 0: - ret = x - elif type(x) is np.ndarray: - ret = np.ones(x.shape)*0.01 - return ret - functions = { "sigmoid": {"function": sigmoid, "derivative": deriv_sigmoid}, "tanh": {"function": tanh, "derivative": deriv_tanh}, "relu": {"function": relu, "derivative": deriv_relu}, } - def __init__(self, L=1, n=None, g=None, alpha=0.01): + def __init__(self, L=1, n=None, g=None, alpha=0.01, lambd=0): """Initializes network geometry and parameters :param L: number of layers including output and excluding input. Defaut 1. :type L: int @@ -84,6 +93,7 @@ class MultiLayerPerceptron(object): self._Z = None self._m = 0 self._alpha = alpha + self._lambda = lambd def set_all_input_examples(self, X, m=1): """Set the input examples. @@ -153,6 +163,9 @@ class MultiLayerPerceptron(object): def get_output(self): return self._A[self._L] + def get_weights(self): + return self._W[1:] + def back_propagation(self, get_cost_function=False): """Back propagation @@ -169,8 +182,12 @@ class MultiLayerPerceptron(object): dA = [None] + [None] * self._L dA[l] = -self._Y/self._A[l] + ((1-self._Y)/(1-self._A[l])) if get_cost_function: + wnorms = 0 + for w in self._W[1:]: + wnorms += np.linalg.norm(w) J = -1/m * ( np.dot(self._Y, np.log(self._A[l]).T) + \ - np.dot((1 - self._Y), np.log(1-self._A[l]).T) ) + np.dot((1 - self._Y), np.log(1-self._A[l]).T) ) + \ + self._lambda/(2*m) * wnorms # regularization #dZ = self._A[l] - self._Y for l in range(self._L, 0, -1): @@ -182,12 +199,13 @@ class MultiLayerPerceptron(object): # dW[l] = 1/m * np.dot(dZ, self._A[l-1].T) # db[l] = 1/m * np.sum(dZ, axis=1, keepdims=True) for l in range(self._L, 0, -1): - self._W[l] = self._W[l] - self._alpha * dW[l] + self._W[l] = self._W[l] - self._alpha * dW[l] - \ + (self._alpha*self._lambda/m * self._W[l]) # regularization self._b[l] = self._b[l] - self._alpha * db[l] return J - def minimize_cost(self, min_cost, max_iter=100000, alpha=None): + def minimize_cost(self, min_cost, max_iter=100000, alpha=None, plot=False): """Propagate forward then backward in loop while minimizing the cost function. :param min_cost: cost function value to reach in order to stop algo. @@ -199,15 +217,24 @@ class MultiLayerPerceptron(object): if alpha is None: alpha = self._alpha self.propagate() + if plot: + y=[] + x=[] for i in range(max_iter): J = self.back_propagation(True) + if plot: + y.append(J[0][0]) + x.append(nb_iter) self.propagate() nb_iter = i + 1 if J <= min_cost: break + if mp and plot: + mp.plot(x,y) + mp.show() return {"iterations": nb_iter, "cost_function": J} - def learning(self, X, Y, m, min_cost=0.05, max_iter=100000, alpha=None): + def learning(self, X, Y, m, min_cost=0.05, max_iter=100000, alpha=None, plot=False): """Tune parameters in order to learn examples by propagate and backpropagate. :param X: the inputs training examples @@ -220,12 +247,12 @@ class MultiLayerPerceptron(object): """ self.set_all_training_examples(X, Y, m) self.prepare() - res = self.minimize_cost(min_cost, max_iter, alpha) + res = self.minimize_cost(min_cost, max_iter, alpha, plot) return res if __name__ == "__main__": - mlp = MultiLayerPerceptron(L=2, n=[2, 3, 1], g=["tanh", "sigmoid"], alpha=2) + mlp = MultiLayerPerceptron(L=2, n=[2, 3, 1], g=["tanh", "sigmoid"], alpha=2, lambd=0.005) #mlp = MultiLayerPerceptron(L=1, n=[2, 1], g=["sigmoid"], alpha=0.1) X = np.array([[0, 0], @@ -238,7 +265,15 @@ if __name__ == "__main__": [1], [0]]) - res = mlp.learning(X.T, Y.T, 4) + res = mlp.learning(X.T, Y.T, 4, max_iter=5000, plot=True) print(res) print(mlp.get_output()) + print(mlp.get_weights()) + #mlp.set_all_training_examples(X.T, Y.T, 4) + #mlp.prepare() + #print(mlp.propagate()) + #for i in range(100): + # print(mlp.back_propagation()) + # mlp.propagate() + #print(mlp.propagate())