# coding: utf-8

# # Traversing Expression Trees
#
# NOTE: exported from a Jupyter notebook. Cells that only *displayed* a value
# in the notebook (a bare expression as the last line) use print() here, since
# a bare expression does nothing when this file runs as a script.

# In[2]:


import pymbolic.primitives as p

x = p.Variable("x")


# In[3]:


u = (x+3)**5
print(u)


# ## Traversal
# There are many ways to walk this expression:
#
# * One big recursive function with many `if isinstance` checks
# * The "visitor pattern": define a class that dispatches to a different
#   method for each node type
#
# Each node class records which mapper method handles it:

# In[4]:


print(p.Sum.mapper_method)


# In[5]:


from pymbolic.mapper import WalkMapper


class MyMapper(WalkMapper):
    def map_sum(self, expr):
        print("sum", expr.children)


# In[6]:


u = (x+3)**5
print(u)


# In[7]:


mymapper = MyMapper()
mymapper(u)


# ## Recursive Traversal
# What if there is another sum nested inside our existing one?

# In[8]:


u = (x+3)**5 + 5
print(u)


# In[9]:


mymapper(u)


# What do you notice? Is something missing?
#
# Improve the implementation as `MyMapper2`:

# In[10]:


from pymbolic.mapper import WalkMapper


class MyMapper2(WalkMapper):
    def map_sum(self, expr):
        print("sum", expr.children)
        # Overriding map_sum replaces the inherited traversal for sums, so
        # recurse into the children explicitly to reach nested sums.
        for ch in expr.children:
            self.rec(ch)


# In[11]:


mymapper2 = MyMapper2()
mymapper2(u)


# ## Mapper Inheritance
# * Above: What about `map_variable`? `map_power`?
# * Mappers inherit all non-overridden behavior from their superclasses.
#
# This makes it easy to *inherit a base behavior* and then selectively change
# a few pieces.

# ## Mappers with Values
# * Mappers do more than just *traverse*
# * They can also return a value
# * What type? Any desired one.
#
# For example, a mapper could return a string:

# In[12]:


from pymbolic.mapper import RecursiveMapper


class MyStringifier(RecursiveMapper):
    """Render sums, products, variables and constants as a plain string."""

    def map_sum(self, expr):
        return "+".join(self.rec(ch) for ch in expr.children)

    def map_product(self, expr):
        # Parenthesize sum factors: without this, (x+1)*2 would be rendered
        # as "x+1*2", which reads as x + (1*2).
        return "*".join(
            "(" + self.rec(ch) + ")" if isinstance(ch, p.Sum) else self.rec(ch)
            for ch in expr.children)

    def map_variable(self, expr):
        return expr.name

    def map_constant(self, expr):
        return str(expr)


# In[13]:


u = (x * 5)+(x * 7)
mystrifier = MyStringifier()
print(mystrifier(u))


# Mappers can also return another expression.
# `IdentityMapper` is a base that rebuilds an expression, returning one equal
# to its input. It is not necessarily a deep copy: unchanged subexpressions
# may be shared, and depending on the pymbolic version the result may even be
# the input object itself. The two prints below check equality and identity:

# In[14]:


from pymbolic.mapper import IdentityMapper

idmap = IdentityMapper()
u2 = idmap(u)
print(u2 == u)
print(u2 is u)


# ## Term Rewriting
# `IdentityMapper` can be used as a convenient base for term rewriting.
#
# As a very simple example, let us
#
# * Change the name of all variables by appending a prime
# * Change all products to sums

# In[15]:


class MyIdentityMapper(IdentityMapper):
    """Prime every variable name and turn every product into a sum."""

    def map_variable(self, expr):
        return p.Variable(expr.name + "'")

    def map_product(self, expr):
        # Rebuild the children first so nested variables/products are
        # rewritten too, then combine them with Sum instead of Product.
        return p.Sum(tuple(self.rec(ch) for ch in expr.children))


# In[16]:


u = (x*3)*(x+17)**3
myidmap = MyIdentityMapper()
print(myidmap(u))


# In[ ]: