You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

167 lines
14 KiB
Plaintext

7 years ago
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# __future__ imports must be the first statement of the cell: when the\n",
"# notebook is exported to a .py script, %matplotlib becomes an ordinary\n",
"# function call and a __future__ import after it is a SyntaxError.\n",
"from __future__ import division, print_function\n",
"\n",
"%matplotlib inline\n",
"\n",
"from sympy import Symbol, diff, solve, lambdify, simplify\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Model parameters: we fit a line y = b1*x + b2, where b1 is the slope and\n",
"# b2 the intercept.\n",
"b1, b2 = Symbol('b1'), Symbol('b2')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Observed data points as (x, y) pairs\n",
"data = [(1, 14), (2, 13), (3, 12), (4, 10),\n",
"        (5, 9), (7, 8), (9, 5)]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Function to minimize: S = 185*b1**2 + 62*b1*b2 - 524*b1 + 7*b2**2 - 142*b2 + 779\n"
]
}
],
"source": [
"# S is the function to minimize: the sum, over all data points, of the\n",
"# squared vertical error (residual) b1*x + b2 - y -- the least-squares\n",
"# objective, simplified into a single polynomial in b1 and b2.\n",
"S = simplify(sum((point[0] * b1 + b2 - point[1]) ** 2 for point in data))\n",
"print(\"Function to minimize: S = {}\".format(S))\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"S is minimal for b1 = -367/334, b2 = 5013/334\n"
]
}
],
"source": [
"# S is a sum of squares of expressions linear in (b1, b2), hence convex, so\n",
"# any point where both partial derivatives vanish is a minimum of S.\n",
"d1, d2 = S.diff(b1), S.diff(b2)\n",
"solutions = solve([d1, d2], [b1, b2])\n",
"print(\"S is minimal for b1 = {}, b2 = {}\".format(solutions[b1], solutions[b2]))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitted line: y = -367*x/334 + 5013/334\n"
]
}
],
"source": [
"# Substitute the optimal parameters into the model y = b1*x + b2 to obtain\n",
"# the fitted line as an expression in x.\n",
"x = Symbol('x')\n",
"fitted_line = (b1 * x + b2).subs({b1: solutions[b1], b2: solutions[b2]})\n",
"print(\"Fitted line: y = {}\".format(fitted_line))\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Construct something we can plot with matplotlib.\n",
"#\n",
"# lambdify returns a bare scalar (not an array) when the expression does not\n",
"# depend on x, e.g. for a horizontal fit with b1 == 0; plt.plot cannot pair\n",
"# that scalar with x_plot, so the result is broadcast to the input's shape.\n",
"_fitted_line_numpy = lambdify(x, fitted_line, modules=['numpy'])\n",
"\n",
"def fitted_line_func(xs):\n",
"    \"\"\"Evaluate the fitted line at xs; always returns an array shaped like xs.\"\"\"\n",
"    return np.broadcast_to(_fitted_line_numpy(xs), np.shape(xs)).copy()\n",
"\n",
"x_plot = np.linspace(min(p[0] for p in data),\n",
"                     max(p[0] for p in data), 100)\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD8CAYAAABn919SAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAHd9JREFUeJzt3Xucz2X+//HHi1EhKrGtxUbrh2mcm+SQ0s5mYyPSpra2\nWt1otavUKlpFpRKtiq2kWKUiSsVGSienmAyRw2jbfjaJ1sipkpDr+8f1mVaiOXw+n7k+h+f9dnMz\n85nxeT9vbnp2zfW+3tdlzjlERCT5lQsdQEREYkOFLiKSIlToIiIpQoUuIpIiVOgiIilChS4ikiJU\n6CIiKUKFLiKSIlToIiIpIqMsL1a9enVXt27dsrykiEjSW7Zs2VbnXI2ivq9MC71u3brk5eWV5SVF\nRJKemX1cnO/TlIuISIpQoYuIpAgVuohIilChi4ikCBW6iEiKKLLQzewfZrbFzFYf5msDzMyZWfX4\nxBMRkeIqzgj9CeC8Q180szrAucCGGGc6rJ7jFtNz3OKyuJSISFIqstCdc/OBbYf50gPAzYDOsBMR\nSQClerDIzLoCnzrnVppZjCN9X+GoPHf9tu99PvWaNnG9rohIsinxTVEzqwQMBoYU8/v7mFmemeUV\nFBSU9HKec+T8Oxd0oLWIyBGZK0ZJmlld4GXnXGMzawK8AeyOfLk2sAlo5Zz77MfeJzs725Xq0f9/\n/hO6dmVNg5Zk/XMKNGhQ8vcQEUlSZrbMOZdd1PeVeITunFvlnPuJc66uc64usBFoWVSZR+U3v2Hc\nZQOp+8m/oGlTuPtu2Ls3bpcTEUlGxVm2OAVYDDQ0s41mdnX8Yx2iXDmuefpeKn/0L7jgArj1Vjjt\nNFisVS8iIoWKs8rlUudcTedcBedcbefchEO+Xtc5tzV+EQ9SsyZMneqnYHbuhHbt4E9/gl27yuTy\nIiKJLDmfFD3/fFizBq67DsaOhcxMeOml0KlERIJKzkIHqFIFHnwQliyB6tWhe3e48EL49NPQyURE\ngkjeQi/UqhXk5cG998Irr8Cpp/pR+4EDoZOJiJSp5C90gAoVYOBAWLUKTj8drr0W2rf30zIiImki\nNQq9UP36MHcuPPkkfPABtGgBQ4bAnj2hk4mIxF1qFTqAGVxxBeTnwyWXwLBh0KwZzJsXOpmISFyl\nXqEXqlEDJk2C116D/fuhQwfo3Ru2bw+dTEQkLlK30Aude66fW7/5Zpg4ERo1gmef1b4wIpJyUr/Q\nASpVghEj/GqYn/8cLr0UunSBjz8OnUxEJGbSo9ALNW/utwt44AF4+23IyvJr2b/9NnQyEZGopVeh\nA2RkQP/+fknjWWfBDTdA69awYkXoZCIiUUm/Qi908skwa5afT9+wAbKz/Vr23buL/rMiIgkofQsd\n/BLHnj39EserroKRI6FxY78yRkQkyaR3oReqVg3Gj4e33vJPnf761/D730MRJyzp4GoRSSQq9IN1\n6AArV8Jtt/ltejMz/Vp2LXEUkSRQrCPoYqXUR9CFsGYN9OkD77wDOTkwbhz84hfADw+uPqNeNUAH\nV4tIfMTtCLq0kZUFCxbAI4/A0qV+bn3ECNi3L3QyEZHD0gi9OD79FPr1gxdf9GeaPv44tGr13Uhd\nI3MRiSeN0GOpVi144QX/a+tWv269f3+O2fNV6GQiIt9RoZdE9+6wdq3fb33MGJ4c1YupPyub41RF\nRIqiQi+p446Dhx6CRYugalXo2hUuvhg++yx0MhFJcyr00mrTBpYvh7vugpkz/RLHxx/X0XciEowK\nPRpHHQWDB8P77/uNv/r08WvZ160LnUxE0pAKPRYaNIA334QJE2D1an9C0h13wDffhE4mImlEhR4r\nZtCrl98XpkcPuP12f6bpwoWhk4lImlChx9pJJ8HkyTB7tt+5sX17+OMfYceO0MlEJMWp0OOlUye/\nfcCNN/qbpaeeCtOna18YEYkbFXo8Va4Mo0bB
u+/CT38KF13k17Jv3Bg6mYikIBV6WTjtNF/q993n\n91rPzIS//11H34lITKnQy0pGBgwY4Kdh2raF666Ddu1g1arQyUQkRajQy1q9ejBnDjz9NHz0EbRs\nCX/9K3z9dehkIpLkVOghmMFll/kHkC6/HIYP97s4vvlm6GQiksRU6CGdeCJMnAivv+5Xv+Tk+LXs\nn38eOpmIJCEVeiLIyfFz6bfcAk895W+aTp6sJY4iUiIq9ERRsSLcc4/f8KtePT8l06kTrF8fOpmI\nJIkiC93M/mFmW8xs9UGv3Wdm68zsfTN70cyOj2/MNNKkiT/HdMwYv0Vv48Z+Lfv+/cV+i57jFn93\nmpKIpI/ijNCfAM475LW5QGPnXFPgX8AtMc6V3sqX90ferV3rp2MGDIBWrWDZstDJRCSBFVnozrn5\nwLZDXnvNOVc4ZFwC1I5DNqlTB2bMgOeeg82bfan/5S/w1eGPviscmeeu30bu+m0aqYukmVjMofcC\nXjnSF82sj5nlmVleQUFBDC6XZsz8lgH5+dC7N9x/P2RlwStH/CsXkTRlrhgrKcysLvCyc67xIa8P\nBrKBC10x3ig7O9vl5eWVLql4Cxb4gzTWrYNLL4UHHvA7PB6kcFQ+9Zo2IRKKSIyZ2TLnXHZR31fq\nEbqZXQmcD1xWnDKXGGnfHlas8PutT5/ulzhOnKgljiJSukI3s/OAgUBX59zu2EaSIh19NAwd6os9\nK8s/jJSTAx9+CPiRuUbnIumnOMsWpwCLgYZmttHMrgYeAqoAc81shZk9GueccjiZmTBvHowb59ev\nN2kCd98Ne/eGTiYiARRrDj1WNIceR5s3w/XX+xUxjRv7QzVatw6dSkRiIO5z6JJgataEadNg5kx/\n3F3btvDnP8OuXaGTiUgZUaGnmi5d/ANJ/frBI4/4o+9mzAidSkTKgAo9FVWpAqNHw5IlfkfHbt2g\nRw/YtCl0MhGJIxV6KmvVCvLy/H7rs2f7m6hjx8KBA6GTiUgcqNBTXYUKMGiQ35739NPh2mv9WvY1\na0InE5EYU6Gni/r1Ye5ceOIJ+OADaNEChgyBPXtCJxORGFGhpxMzuPJKvy9Mz54wbBg0a+bXsotI\n0lOhp6MaNfzJSHPmwL590KGD3/hr+/bQyUQkCir0dPbrX8Pq1XDzzX4/mMxMmDpV+8KIJCkVerqr\nVAlGjPCrYerUgUsu8WvZN2wInUxESkiFLl7z5n7d+gMPwNtv+weSHnwQvv02dDIRKSYVuvxP+fLQ\nv79f0nj22XDDDX4/mBUrQicTkWJQocsPnXwyvPwyPPusn3rJzoaBA2G3dkoWSWQqdDk8M7+0MT8f\nrroKRo702/POnRs6mYgcgQpdfly1ajB+PLz1FmRkQMeOcMUVoPNhRRKOCl2Kp0MHWLkSbrvNT8Vk\nZsKkSVriKJJAVOhSfMccA3feCe+9Bw0b+qdOO3aEjz4KnUxEUKFLaWRlwYIFfr/1d9/1JySNGOGf\nOhWRYFToUjrlykHfvv4wjc6d/Y6O2dm+4EUkCBW6RKdWLZg+HV58EbZu9evWr78evvgidDKRtKNC\nl9jo1s2P1vv2hb//3U/LvPxy6FQiaUWFLrFz3HHw8MOwcCFUrer3hLn4Yti8OXQykbSgQpfYa9sW\nli+Hu+6CmTP9EsfHH9fRdyJxpkKX+DjqKBg8GN5/H1q2hD594JxzYN260MlEUpYKXeKrQQN44w2Y\nMMGfa9qsmV/L/s03oZOJpBwVusSfGfTq5feF6dEDhg71Z5ouXBg6mUhKUaFL2TnpJJg8GWbNgq++\ngvbt/aqYHTtCJxNJCSp0KXudO/s912+8ER57zB+mMX269oURiZIKXcI49lgYNQpyc/3I/aKL/Fr2\nTz4JnUwkaanQJazsbFi6FO67z++1fuqp8NBDOvpOpBRU6BJeRgYMGOCnYdq1g3794Mwz/aoYESk2\nFbokjnr1
4JVX4Jln/Ja8LVv6texffx06mUhSUKFLYjGD3/3OL3G8/HK45x5o2hTefDN0MpGEp0KX\nxHTiiTBxIrz+ul/9kpM
"text/plain": [
"<matplotlib.figure.Figure at 0x7fdb7f6c8b70>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot data points and fitted line, labelled so the figure stands on its own.\n",
"fig, ax = plt.subplots()\n",
"ax.scatter([p[0] for p in data], [p[1] for p in data], marker=\"+\", label=\"data\")\n",
"ax.plot(x_plot, fitted_line_func(x_plot), 'r', label=\"least-squares fit\")\n",
"ax.set_xlabel(\"x\")\n",
"ax.set_ylabel(\"y\")\n",
"ax.set_title(\"Least-squares line fit\")\n",
"ax.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 1
}