{"nbformat": 3, "nbformat_minor": 0, "worksheets": [{"metadata": {}, "cells": [
{"cell_type": "markdown", "metadata": {}, "source": ["# Solving Least-Squares Problems\n", "\n", "We solve an overdetermined system $Ax \\approx b$ (with $A$ of size $m\\times n$, $m > n$) in the least-squares sense: first with a reduced QR factorization, then with a complete QR factorization, and finally with the normal equations, checking that the solutions agree."]},
{"prompt_number": 24, "input": ["import numpy as np\n", "import numpy.linalg as la\n", "import scipy.linalg as spla"], "metadata": {}, "outputs": [], "collapsed": false, "cell_type": "code", "language": "python"},
{"prompt_number": 25, "input": ["m = 6\n", "n = 4\n", "\n", "A = np.random.randn(m, n)\n", "b = np.random.randn(m)"], "metadata": {}, "outputs": [], "collapsed": false, "cell_type": "code", "language": "python"},
{"cell_type": "markdown", "metadata": {}, "source": ["$A$ has more rows than columns, so $Ax = b$ is overdetermined and in general has no exact solution. Let's try solving it as a linear system using `la.solve` anyway:"]},
{"prompt_number": 26, "input": ["try:\n", "    la.solve(A, b)\n", "except la.LinAlgError as err:\n", "    # Expected: la.solve only handles square systems, and A is m x n with m > n.\n", "    print(\"LinAlgError:\", err)"], "metadata": {}, "outputs": [{"output_type": "stream", "stream": "stdout", "text": ["LinAlgError: Last 2 dimensions of the array must be square\n"]}], "collapsed": false, "cell_type": "code", "language": "python"},
{"cell_type": "markdown", "metadata": {}, "source": ["OK, let's do QR-based least-squares then."]},
{"prompt_number": 27, "input": ["Q, R = la.qr(A)"], "metadata": {}, "outputs": [], "collapsed": false, "cell_type": "code", "language": "python"},
{"cell_type": "markdown", "metadata": {}, "source": ["What did we get? Full QR or reduced QR?"]},
{"prompt_number": 28, "input": ["Q.shape"], "metadata": {}, "outputs": [{"output_type": "pyout", "prompt_number": 28, "metadata": {}, "text": ["(6, 4)"]}], "collapsed": false, "cell_type": "code", "language": "python"},
{"prompt_number": 29, "input": ["R.shape"], "metadata": {}, "outputs": [{"output_type": "pyout", "prompt_number": 29, "metadata": {}, "text": ["(4, 4)"]}], "collapsed": false, "cell_type": "code", "language": "python"},
{"cell_type": "markdown", "metadata": {}, "source": ["Is that a problem?\n", "\n", "* Do we really need the bottom part of $R$? (A bunch of zeros)\n", "* Do we really need the far right part of $Q$? 
(=the bottom part of $Q^T$)\n", "\n", "-----------------\n", "OK, so find the minimizing $x$:"]},
{"prompt_number": 39, "input": ["x = spla.solve_triangular(R, Q.T.dot(b), lower=False)"], "metadata": {}, "outputs": [], "collapsed": false, "cell_type": "code", "language": "python"},
{"cell_type": "markdown", "metadata": {}, "source": ["We predicted that $\\|Ax-b\\|_2$ would be the same as $\\|Rx-Q^Tb\\|_2$:"]},
{"prompt_number": 45, "input": ["la.norm(A.dot(x)-b, 2)"], "metadata": {}, "outputs": [{"output_type": "pyout", "prompt_number": 45, "metadata": {}, "text": ["1.4448079009090737"]}], "collapsed": false, "cell_type": "code", "language": "python"},
{"prompt_number": 47, "input": ["la.norm(R.dot(x) - Q.T.dot(b))"], "metadata": {}, "outputs": [{"output_type": "pyout", "prompt_number": 47, "metadata": {}, "text": ["1.5700924586837752e-16"]}], "collapsed": false, "cell_type": "code", "language": "python"},
{"cell_type": "markdown", "metadata": {}, "source": ["--------------\n", "Heh--*reduced* QR left out the last $m-n$ columns of $Q$. Here $Q$ is only $m\\times n$, so $Q^Tb$ throws away the component of $b$ orthogonal to the range of $A$ -- and at the least-squares solution that component is exactly the residual $b - Ax$. Since $x$ solves $Rx = Q^Tb$ exactly, $\\|Rx-Q^Tb\\|_2$ is just rounding error, while $\\|Ax-b\\|_2$ lives entirely in the part we dropped.\n", "\n", "Let's try again with complete QR:"]},
{"prompt_number": 59, "input": ["Q2, R2 = la.qr(A, mode=\"complete\")"], "metadata": {}, "outputs": [], "collapsed": false, "cell_type": "code", "language": "python"},
{"prompt_number": 60, "input": ["# R2 is m x n, but only its top n rows are nonzero: solve with those rows\n", "# and the matching top n entries of Q2^T b.\n", "x2 = spla.solve_triangular(R2[:n], Q2.T[:n].dot(b), lower=False)"], "metadata": {}, "outputs": [], "collapsed": false, "cell_type": "code", "language": "python"},
{"prompt_number": 63, "input": ["la.norm(A.dot(x2)-b, 2)"], "metadata": {}, "outputs": [{"output_type": "pyout", "prompt_number": 63, "metadata": {}, "text": ["1.4448079009090737"]}], "collapsed": false, "cell_type": "code", "language": "python"},
{"prompt_number": 64, "input": ["la.norm(R2.dot(x2) - Q2.T.dot(b))"], "metadata": {}, "outputs": [{"output_type": "pyout", "prompt_number": 64, "metadata": {}, "text": ["1.444807900909074"]}], "collapsed": false, "cell_type": "code", "language": "python"},
{"cell_type": "markdown", "metadata": {}, "source": ["Did we get the same `x` both times?"]},
{"prompt_number": 69, "input": ["x - x2"], "metadata": {}, "outputs": [{"output_type": "pyout", "prompt_number": 69, "metadata": {}, "text": ["array([ 0., 0., 0., 0.])"]}], "collapsed": false, "cell_type": "code", "language": "python"},
{"cell_type": "markdown", "metadata": {}, "source": ["Finally, let's compare against the normal equations $A^TAx = A^Tb$. (Forming $A^TA$ squares the condition number of $A$, so for ill-conditioned $A$ this can be much less accurate than QR.)"]},
{"prompt_number": 70, "input": ["x3 = la.solve(A.T.dot(A), A.T.dot(b))"], "metadata": {}, "outputs": [], "collapsed": false, "cell_type": "code", "language": "python"},
{"prompt_number": 71, "input": ["x3 - x"], "metadata": {}, "outputs": [{"output_type": "pyout", "prompt_number": 71, "metadata": {}, "text": ["array([ 4.99600361e-16, -1.66533454e-16, 7.77156117e-16,\n", " -1.11022302e-16])"]}], "collapsed": false, "cell_type": "code", "language": "python"}
]}], "metadata": {"signature": 
"sha256:aff605c43dd119431c7501d23308c2a8cec7a71ea46b21a799cec6cb5179af1a", "name": ""}}