{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Traversing Expression Trees" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import pymbolic.primitives as p\n", "x = p.Variable(\"x\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "Power(Sum((Variable('x'), 3)), 5)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "u = (x+3)**5\n", "u" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Traversal" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There are several ways to walk this expression:\n", "\n", "* One big recursive function with many `if isinstance` checks\n", "* The \"visitor pattern\": define a class and dispatch to a different method for each node type\n", "\n", "pymbolic takes the second approach: each node class names the method it dispatches to in its `mapper_method` attribute." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "'map_sum'" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "p.Sum.mapper_method" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from pymbolic.mapper import WalkMapper\n", "\n", "class MyMapper(WalkMapper):\n", "    def map_sum(self, expr):\n", "        print(\"sum\", expr.children)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "Power(Sum((Variable('x'), 3)), 5)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "u = (x+3)**5\n", "u" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "sum (Variable('x'), 3)\n" ] } ], "source": [ "mymapper = MyMapper()\n", "mymapper(u)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Recursive 
Traversal" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What if our existing sum is nested inside another sum?" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "Sum((Power(Sum((Variable('x'), 3)), 5), 5))" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "u = (x+3)**5 + 5\n", "u" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "sum (Power(Sum((Variable('x'), 3)), 5), 5)\n" ] } ], "source": [ "mymapper(u)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What do you notice? Is something missing?\n", "\n", "An improved implementation, `MyMapper2`:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from pymbolic.mapper import WalkMapper\n", "\n", "class MyMapper2(WalkMapper):\n", "    def map_sum(self, expr):\n", "        print(\"sum\", expr.children)\n", "\n", "        # Recurse into the children so that nested sums are visited too\n", "        for ch in expr.children:\n", "            self.rec(ch)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "sum (Power(Sum((Variable('x'), 3)), 5), 5)\n", "sum (Variable('x'), 3)\n" ] } ], "source": [ "mymapper2 = MyMapper2()\n", "mymapper2(u)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Mapper Inheritance" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* In the example above, what about `map_variable` and `map_power`?\n", "* Mappers inherit all non-overridden behavior from their superclasses.\n", "\n", "This makes it easy to *inherit a base behavior* and then selectively change a few pieces." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Mappers with Values" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Mappers do more than just *traverse*\n", "* They can also return a value\n", "  * What type? Any type you like.\n", "\n", "For example, a mapper could return a string:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from pymbolic.mapper import RecursiveMapper\n", "\n", "class MyStringifier(RecursiveMapper):\n", "    def map_sum(self, expr):\n", "        return \"+\".join(self.rec(ch) for ch in expr.children)\n", "\n", "    def map_product(self, expr):\n", "        return \"*\".join(self.rec(ch) for ch in expr.children)\n", "\n", "    def map_variable(self, expr):\n", "        return expr.name\n", "\n", "    def map_constant(self, expr):\n", "        return str(expr)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "'x*5+x*7'" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "u = (x * 5)+(x * 7)\n", "\n", "mystrifier = MyStringifier()\n", "\n", "mystrifier(u)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Mappers can also return another expression. 
`IdentityMapper` is a base class that returns an identical (deep) copy of an expression:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True\n", "False\n" ] } ], "source": [ "from pymbolic.mapper import IdentityMapper\n", "\n", "idmap = IdentityMapper()\n", "u2 = idmap(u)\n", "print(u2 == u)\n", "print(u2 is u)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Term Rewriting" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`IdentityMapper` can be used as a convenient base for term rewriting.\n", "\n", "As a very simple example, let us:\n", "\n", "* Change the name of all variables by appending a prime\n", "* Change all products to sums" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class MyIdentityMapper(IdentityMapper):\n", "    def map_variable(self, expr):\n", "        return p.Variable(expr.name + \"'\")\n", "\n", "    def map_product(self, expr):\n", "        return p.Sum(tuple(self.rec(ch) for ch in expr.children))\n" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "x' + 3 + (x' + 17)**3\n" ] } ], "source": [ "u = (x*3)*(x+17)**3\n", "\n", "myidmap = MyIdentityMapper()\n", "print(myidmap(u))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.0+" } }, "nbformat": 4, "nbformat_minor": 0 }