{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Multivariate least squares" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib\n", "%matplotlib inline\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's do a simple multivariate fit: a fit of a plane to some data.\n", "Write a function for the model\n", "$$ z = a + bx + cy$$\n", "and one for its derivatives with respect to the parameters $a$, $b$, and $c$.\n", "Then simulate a data set.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def model(x,y,par) :\n", "    \"\"\" Function of two variables, linear in parameters: z = a + b*x + c*y\n", "    \"\"\"\n", "    return par[0] + par[1]*x + par[2]*y\n", "\n", "def deriv(x,y) :\n", "    \"\"\" Derivative with respect to each parameter (a, b, c);\n", "        returns an array of shape (3, len(x))\n", "    \"\"\"\n", "    return np.array([np.ones_like(x), x, y])\n", "\n", "# simulate some data (the values below are illustrative choices)\n", "npts=100\n", "x=np.random.uniform(0.,1.,size=npts)   # x independent variable\n", "y=np.random.uniform(0.,1.,size=npts)   # y independent variable\n", "truepar=np.array([1.,2.,3.])           # true parameter values (a, b, c)\n", "sig=np.full(npts,0.1)                  # uncertainties on data values\n", "z=model(x,y,truepar)+np.random.normal(0.,sig,size=len(x))\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can use the simple form for least squares. Here $A$ is the design matrix: one row per data point and one column per parameter, with elements equal to the derivatives of the model with respect to each parameter. $z$ is the vector of data values. Dividing by $\\sigma$ means dividing each row by that data point's uncertainty:\n", "$$\\left({A\\over \\sigma}\\right)^T {A\\over \\sigma}\\, par = \\left({A\\over \\sigma}\\right)^T {z\\over \\sigma}$$\n", "The inverse of the matrix on the left-hand side is the covariance matrix of the fitted parameters. The square roots of its diagonal elements are the parameter uncertainties." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# design matrix: one row per data point, one column per parameter,\n", "# with each row divided by that point's uncertainty\n", "design=(deriv(x,y)/sig).T\n", "\n", "# solve for best fitting parameters in least squares sense\n", "data=z/sig\n", "ATA=np.dot(design.T,design)\n", "par=np.linalg.solve(ATA,np.dot(design.T,data))\n", "print('par: ',par)\n", "# use inverse matrix to get uncertainties: inv is the covariance matrix of the parameters\n", "inv=np.linalg.inv(ATA)\n", "print(np.dot(inv,np.dot(design.T,data)))  # same solution, computed with the inverse\n", "print(inv)\n", "print('uncertainties: ',np.sqrt(np.diag(inv)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here are some possibilities for visualizing the results:\n", "- a 3D plot of the data with the fitted plane and, in green, the true plane\n", "- two 2D plots that use color to show the dependent variable" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# for plotting in 3D (importing Axes3D registers the '3d' projection in older Matplotlib versions)\n", "from mpl_toolkits.mplot3d import Axes3D\n", "plt3d = plt.figure().add_subplot(projection='3d')\n", "plt3d.scatter(x, y, z)\n", "\n", "# plot the fitted plane and, in green, the true plane\n", "xx, yy = np.meshgrid(np.arange(0,1,0.1),np.arange(0,1,0.1))\n", "zfit= model(xx,yy,par)\n", "plt3d.plot_surface(xx, yy, zfit, alpha=0.8)\n", "plt3d.plot_surface(xx,yy,model(xx,yy,truepar),color='g',alpha=0.8)\n", "\n", "# 2D representation using color as dependent variable\n", "plt.figure()\n", "plt.scatter(x,y,c=z)\n", "plt.colorbar()\n", "plt.figure()\n", "xx,yy=np.mgrid[0:1:0.1,0:1:0.1]\n", "# mgrid varies x along the first (row) axis, so transpose to put x on the horizontal axis\n", "plt.imshow(model(xx,yy,par).T,origin='lower')\n", "plt.colorbar()\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.11" } }, "nbformat": 4, "nbformat_minor": 1 }