Partial application and piping with AST transformation



In the previous article I wrote about how to implement partial application and piping using operator overloading and decorators. But we can use a slightly different approach: AST transformation.

For example, suppose we have this code:

# Example input for the transformer: `...` marks the argument slot that the
# generated lambda will fill.
# NOTE: executed as-is (without the AST transform) this raises TypeError,
# because add(..., 5) evaluates `Ellipsis + 5`.
def add(x, y): return x + y
addFive = add(..., 5)
print(addFive(10))

We can look at the AST of this code using the ast module from the standard library and the dump function from the gist:

import ast

# src.py holds the example code above. ast.parse expects source text
# (str or bytes), not a file object, so read the file first. Reading bytes
# lets the parser honor the file's declared source encoding, and `with`
# closes the file afterwards.
with open('src.py', 'rb') as source_file:
    code = source_file.read()
tree = ast.parse(code)
print(dump(tree))  # `dump` is the pretty-printer from the gist, not ast.dump

It looks like this:

Module ( body = [ FunctionDef ( name = 'add' , args = arguments ( args = [ arg ( arg = 'x' , annotation = None ), arg ( arg = 'y' , annotation = None ), ], vararg = None , kwonlyargs = [], kw_defaults = [], kwarg = None , defaults = []), body = [ Return ( value = BinOp ( left = Name ( id = 'x' , ctx = Load ()), op = Add (), right = Name ( id = 'y' , ctx = Load ()))), ], decorator_list = [], returns = None ), Assign ( targets = [ Name ( id = 'addFive' , ctx = Store ()), ], value = Call ( func = Name ( id = 'add' , ctx = Load ()), args = [ Ellipsis (), Num ( n = 5 ), ], keywords = [])), Expr ( value = Call ( func = Name ( id = 'print' , ctx = Load ()), args = [ Call ( func = Name ( id = 'addFive' , ctx = Load ()), args = [ Num ( n = 10 ), ], keywords = []), ], keywords = [])), ])

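And we can easily spot the call with an ellipsis argument (this dump was produced on Python 3.7 or older; since Python 3.8 the parser emits Constant(value=Ellipsis) and Constant(value=5) instead of Ellipsis() and Num(n=5)):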

Call ( func = Name ( id = 'add' , ctx = Load ()), args = [ Ellipsis (), Num ( n = 5 ), ], keywords = [])

We need to wrap each call that has an ellipsis argument in a lambda and replace ... with the lambda's argument. We can do this with ast.NodeTransformer, which calls the visit_Call method for each Call node:

class EllipsisPartialTransform(ast.NodeTransformer):
    """Rewrite calls that pass ``...`` positionally into one-argument lambdas.

    ``f(a, ..., b)`` becomes
    ``lambda __ellipsis_partial_arg_N: f(a, __ellipsis_partial_arg_N, b)``.
    Every ``...`` in the same call is filled by the same lambda argument.
    Only positional arguments are inspected; keyword values are left as-is.
    """

    def __init__(self):
        # Per-instance counter so each generated lambda gets a distinct name.
        self._counter = 0

    def _get_arg_name(self):
        """Return a unique argument name for the next lambda."""
        name = '__ellipsis_partial_arg_{}'.format(self._counter)
        self._counter += 1
        return name

    def _is_ellipsis(self, arg):
        """Return True if the AST node *arg* is the literal ``...``.

        Python 3.8+ parses ``...`` as ``ast.Constant(value=Ellipsis)``; older
        versions emit a dedicated ``Ellipsis`` node. The ``ast.Ellipsis``
        alias is deprecated since 3.8 and removed in 3.14, so the legacy node
        is matched by class name rather than referenced directly.
        """
        constant_type = getattr(ast, 'Constant', None)
        if constant_type is not None and isinstance(arg, constant_type):
            return arg.value is Ellipsis
        return type(arg).__name__ == 'Ellipsis'

    def _replace_argument(self, node, arg_name):
        """Replace each ``...`` positional argument of *node* with *arg_name*."""
        # A fresh Name per occurrence: AST nodes should not be shared
        # between several places in the tree.
        node.args = [
            ast.Name(id=arg_name, ctx=ast.Load()) if self._is_ellipsis(arg) else arg
            for arg in node.args
        ]
        return node

    def _make_arguments(self, arg_name):
        """Build the ``ast.arguments`` of a lambda taking only *arg_name*."""
        arguments = ast.arguments(
            args=[ast.arg(arg=arg_name, annotation=None)],
            vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[])
        # Python 3.8 added the ``posonlyargs`` field; compile() rejects an
        # ``arguments`` node without it on 3.8-3.12. Older versions lack it.
        if 'posonlyargs' in ast.arguments._fields:
            arguments.posonlyargs = []
        return arguments

    def _wrap_in_lambda(self, node):
        """Wrap call *node* in a lambda whose argument fills its ``...`` slots."""
        arg_name = self._get_arg_name()
        node = self._replace_argument(node, arg_name)
        lambda_node = ast.Lambda(args=self._make_arguments(arg_name), body=node)
        # Give the lambda the call's source position so tracebacks and
        # compile errors point at the original line instead of line 1.
        return ast.copy_location(lambda_node, node)

    def visit_Call(self, node):
        if any(self._is_ellipsis(arg) for arg in node.args):
            node = self._wrap_in_lambda(node)
            # Fill positions of the newly created child nodes from the lambda.
            node = ast.fix_missing_locations(node)
        # Keep descending so nested calls with ``...`` are rewritten too.
        return self.generic_visit(node)

Now we can transform the AST with the visit method and dump the result:

transformer = EllipsisPartialTransform()
tree = transformer.visit(tree)
print(dump(tree))  # pretty-printed AST of the rewritten module

And you can see the changes:

Module ( body = [ FunctionDef ( name = 'add' , args = arguments ( args = [ arg ( arg = 'x' , annotation = None ), arg ( arg = 'y' , annotation = None ), ], vararg = None , kwonlyargs = [], kw_defaults = [], kwarg = None , defaults = []), body = [ Return ( value = BinOp ( left = Name ( id = 'x' , ctx = Load ()), op = Add (), right = Name ( id = 'y' , ctx = Load ()))), ], decorator_list = [], returns = None ), Assign ( targets = [ Name ( id = 'addFive' , ctx = Store ()), ], value = Lambda ( args = arguments ( args = [ arg ( arg = '__ellipsis_partial_arg_0' , annotation = None ), ], vararg = None , kwonlyargs = [], kw_defaults = [], kwarg = None , defaults = []), body = Call ( func = Name ( id = 'add' , ctx = Load ()), args = [ Name ( id = '__ellipsis_partial_arg_0' , ctx = Load ()), Num ( n = 5 ), ], keywords = []))), Expr ( value = Call ( func = Name ( id = 'print' , ctx = Load ()), args = [ Call ( func = Name ( id = 'addFive' , ctx = Load ()), args = [ Num ( n = 10 ), ], keywords = []), ], keywords = [])), ])

The AST is not easy to read, so we can use astunparse to turn it back into source code (on Python 3.9+ the standard library's ast.unparse does the same job):

import astunparse

# Turn the rewritten AST back into (parenthesis-heavy) source text.
print(astunparse.unparse(tree))

The result is a bit ugly, but more readable than the AST:

def add(x, y):
    return (x + y)
addFive = (lambda __ellipsis_partial_arg_0: add(__ellipsis_partial_arg_0, 5))
print(addFive(10))

To test the result, we can compile the AST and run it:

code_object = compile(tree, '<string>', mode='exec')
exec(code_object)  # prints 15

And it works! Back to piping: for example, suppose we have this code:

# Pipe example: each `@` feeds its left operand into the callable on its right.
# Executed without MatMulPipeTransformation this raises TypeError, since
# str does not support the `@` operator.
"hello world" @ str.upper @ print

Its AST would be:

Module ( body = [ Expr ( value = BinOp ( left = BinOp ( left = Str ( s = 'hello world' ), op = MatMult (), right = Attribute ( value = Name ( id = 'str' , ctx = Load ()), attr = 'upper' , ctx = Load ())), op = MatMult (), right = Name ( id = 'print' , ctx = Load ()))), ])

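A BinOp with op=MatMult() is where we use the matrix multiplication operator (@). We need to transform it into a call of the right operand with the left operand as its argument. Note that this rewrites every @, so transformed code can no longer use it for real matrix multiplication: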

class MatMulPipeTransformation(ast.NodeTransformer):
    """Rewrite ``left @ right`` into ``right(left)`` so ``@`` acts as a pipe.

    ``x @ f @ g`` parses as ``(x @ f) @ g`` and therefore becomes ``g(f(x))``.
    Every ``@`` binary operation is rewritten, so transformed code can no
    longer use ``@`` for real matrix multiplication. Augmented ``@=``
    assignments are not touched.
    """

    def _replace_with_call(self, node):
        """Call right part of operation with left part as an argument."""
        call = ast.Call(func=node.right, args=[node.left], keywords=[])
        # Keep the BinOp's source position so errors inside a pipe report
        # the original line instead of line 1.
        return ast.copy_location(call, node)

    def visit_BinOp(self, node):
        if isinstance(node.op, ast.MatMult):
            node = self._replace_with_call(node)
            node = ast.fix_missing_locations(node)
        # Visit children afterwards so nested pipes in `left` are rewritten too.
        return self.generic_visit(node)

The transformed AST would be:

Module ( body = [ Expr ( value = Call ( func = Name ( id = 'print' , ctx = Load ()), args = [ Call ( func = Attribute ( value = Name ( id = 'str' , ctx = Load ()), attr = 'upper' , ctx = Load ()), args = [ Str ( s = 'hello world' ), ], keywords = []), ], keywords = [])), ])

And the resulting code is just nested calls:

print ( str . upper ( 'hello world' )) # HELLO WORLD

Now it's time to combine both transformers. For example, suppose we have this code:

# Combined example: `...` marks where the piped value goes in each partial
# call, and `@` pipes values left to right. Steps: keep even numbers, square
# them, pair them with 200..249, sum each pair, add everything up, format,
# upper-case, print.
# Executed without both AST transforms this raises TypeError
# (filter(..., Ellipsis) is not iterable).
from functools import reduce
import operator

range(100) @ filter(lambda x: x % 2 == 0, ...) \
    @ map(lambda x: x ** 2, ...) \
    @ zip(..., range(200, 250)) \
    @ map(sum, ...) \
    @ reduce(operator.add, ...) \
    @ str.format('result: {}', ...) \
    @ str.upper \
    @ print

We can transform and run it with:

# src.py holds the combined pipe example above. ast.parse needs the source
# text, not a file object; read it as bytes (so the declared source encoding
# is honored) and let `with` close the file.
with open('src.py', 'rb') as source_file:
    code = source_file.read()
tree = ast.parse(code)
# Resolve `...` partials first, then turn `@` pipes into nested calls.
tree = EllipsisPartialTransform().visit(tree)
tree = MatMulPipeTransformation().visit(tree)
exec(compile(tree, '<string>', 'exec'))

It works, and the output is as expected:

RESULT: 172925

However, the resulting code is a bit messy:

from functools import reduce import operator print ( str . upper (( lambda __ellipsis_partial_arg_5 : str . format ( 'result: {}' , __ellipsis_partial_arg_5 ))( ( lambda __ellipsis_partial_arg_4 : reduce ( operator . add , __ellipsis_partial_arg_4 ))( ( lambda __ellipsis_partial_arg_3 : map ( sum , __ellipsis_partial_arg_3 ))( ( lambda __ellipsis_partial_arg_2 : zip ( __ellipsis_partial_arg_2 , range ( 200 , 250 )))( ( lambda __ellipsis_partial_arg_1 : map (( lambda x : ( x ** 2 )), __ellipsis_partial_arg_1 ))( ( lambda __ellipsis_partial_arg_0 : filter (( lambda x : (( x % 2 ) == 0 )), __ellipsis_partial_arg_0 ))( range ( 100 )))))))))

This approach is better than the previous one: we don't need to manually wrap every function with ellipsis_partial or use the _ helper, and we don't need the custom Partial class. But with this approach we have to transform the AST manually, so in the next part I'll show how to do it automatically with a module finder/loader.

Gist with sources, previous part, next part.