Simple tree-based interpreter

Requires ANTLR 3.0b6 or later, because that release added functionality this example depends on.

Overview 

This is an upgraded version of http://www.antlr.org/wiki/display/ANTLR3/Expression+evaluator that adds simple function definitions. Here is some sample input:

   fact(0) = 1
   fact(n) = fact(n-1)*n
   square(x)=x*x
   catalan(n)=fact(2*n)/square(fact(n))

   fact(10)
   catalan(10)

and here is the output that the main programs below produce for that input:

   3628800 (about 3*10^6)
   184756 (about 1*10^5)
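The "(about ...)" part of each result line is the leading decimal digit followed by a power of ten equal to the number of digits minus one. The value is therefore truncated, not rounded: 3628800 gives 3*10^6, not 4*10^6. The following standalone sketch shows that computation. AboutNotation is a hypothetical helper and is not part of the grammars below.

```java
import java.math.BigInteger;

// Hypothetical helper (not part of the grammars) reproducing the "(about ...)" text.
final class AboutNotation {
    private AboutNotation() {}

    /** Leading digit of |value| times 10^(number of digits - 1); truncated, not rounded. */
    static String about(BigInteger value) {
        String digits = value.abs().toString();        // drop the sign before looking at digits
        String sign = value.signum() < 0 ? "-" : "";
        return "about " + sign + digits.charAt(0) + "*10^" + (digits.length() - 1);
    }

    public static void main(String[] args) {
        System.out.println(about(new BigInteger("3628800")));  // about 3*10^6
        System.out.println(about(new BigInteger("184756")));   // about 1*10^5
    }
}
```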

The basic idea is that a parser reads the complete input and builds an AST, which a tree grammar then evaluates (instead of evaluating directly during parsing).
The main problem with this design is that some sub-ASTs, namely the expressions in function definitions, must be traversed more than once. Each evaluation of a function call requires such a traversal, and the same function may be called several times, even within one expression.
Here are two solutions to this problem:

A) Starting a new tree parser for each traversal of an expression AST in a function definition.
B) Using a single tree parser, but changing the parsing position of the underlying tree node stream (this is a modification of Terence's original solution).
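The following minimal plain-Java sketch shows why sub-ASTs get re-traversed. It follows the idea of solution A and assumes a hypothetical mini-AST: Node, Num, Var, Bin, Call and Def are illustrative types, not ANTLR's CommonTree. Each call creates a fresh evaluator that shares the function table and walks the stored body again. The sketch needs Java 16 or later for records and pattern matching.

```java
import java.math.BigInteger;
import java.util.HashMap;
import java.util.Map;

// Hypothetical mini-AST standing in for CommonTree nodes; none of these types exist in ANTLR.
interface Node {}
record Num(BigInteger value) implements Node {}
record Var(String name) implements Node {}
record Bin(char op, Node left, Node right) implements Node {}
record Call(String function, Node argument) implements Node {}
// f(param) = body; constant-parameter cases are omitted to keep the sketch short.
record Def(String param, Node body) {}

final class MiniEval {
    private final Map<String, Def> functions;                  // shared state, passed in explicitly
    private final Map<String, BigInteger> locals = new HashMap<>();

    MiniEval(Map<String, Def> functions) { this.functions = functions; }

    BigInteger eval(Node n) {
        if (n instanceof Num num) return num.value();
        if (n instanceof Var v) return locals.getOrDefault(v.name(), BigInteger.ZERO);
        if (n instanceof Bin b) {
            BigInteger l = eval(b.left());
            BigInteger r = eval(b.right());
            return switch (b.op()) {
                case '+' -> l.add(r);
                case '-' -> l.subtract(r);
                case '*' -> l.multiply(r);
                default -> throw new IllegalArgumentException("unknown operator " + b.op());
            };
        }
        Call c = (Call) n;
        Def d = functions.get(c.function());                   // assumes the function is defined
        BigInteger arg = eval(c.argument());                   // evaluated in the caller's scope
        MiniEval callee = new MiniEval(functions);             // fresh evaluator per call (solution A)
        callee.locals.put(d.param(), arg);
        return callee.eval(d.body());                          // walks the same body sub-AST again
    }
}
```

With square(x)=x*x, evaluating square(square(3)) walks the body of square twice and yields 81.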

The example is a slightly extended version of one written by Terence, who deserves the credit for inventing it.

Features 

The example now has the following features:

  • Expressions can use operators +, -, *, /, and % (the last one for modulus) as well as parentheses in the usual way.
  • There are three types of commands:
    a) function definitions
    b) variable definitions
    c) expression evaluations
  • Function definitions:
    • It is possible to define functions with a single parameter as follows:
         f(parameter) = expr
      
    • The parameter can either be a simple name, or it can be a fixed number. In the latter case, the value of the function is defined for this specific value. This allows the recursive definition of functions, e.g.:
         fact(0) = 1
         fact(n) = fact(n-1) * n
      
      or
         fib(0) = 0
         fib(1) = 1
         fib(x) = fib(x-1) + fib(x-2)
      
      or
         E(0)=1
         E(n)=10*E(n-1)
      
    • During function evaluation, the definitions of the called function are checked in the order in which they were given, until either a constant parameter matches the argument value or a named parameter is encountered.
  • It is possible to define simple integer variables by writing var=expr, e.g.
       ae = 150*E(9)
       c = 300*E(3+3+3)
    
  • Directly writing an expression will evaluate it, e.g.
       1+1
       fact(30)
       fib(10)
    
    prints the following (the main programs below also append an "(about ...)" annotation to each result):
       2
       265252859812191058636308480000000
       55
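The first-match rule from the feature list can be sketched in plain Java as follows. FuncDef and FirstMatch are hypothetical names, and a null constParam stands for a named parameter. The grammars below implement the same check in findFunction on FUNC subtrees.

```java
import java.math.BigInteger;
import java.util.List;

// Hypothetical stand-in for ^(FUNC ID (ID|INT) expr); constParam == null means a named parameter.
record FuncDef(String name, BigInteger constParam, String body) {}

final class FirstMatch {
    /** Returns the first definition, in source order, whose name matches and whose
     *  parameter is either a name or a constant equal to the actual value. */
    static FuncDef find(List<FuncDef> defs, String name, BigInteger actual) {
        for (FuncDef f : defs) {
            if (!f.name().equals(name)) continue;
            if (f.constParam() != null && !f.constParam().equals(actual)) continue; // constant mismatch
            return f;   // constant matched, or a named parameter was reached
        }
        return null;    // no definition applies
    }
}
```

The search stops at the first named parameter, so order matters. If fib(x) were defined before fib(0), the constant cases would never be reached.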
    

Design

The design has the following peculiarities:

1A) In the implementation that starts a new tree parser for each function evaluation, function definitions are stored by the parser (in a finally block of the func rule). This allows forward calls, i.e. the following works:

    f(2)
    f(2)=4

and will output 4.

1B) In contrast, the implementation that uses a single tree parser and resets its node stream stores the function definitions during tree parsing. It needs the node-stream index of each function's expression, and that index is only available during tree parsing (via input.index()). Forward calls are therefore not possible, because definitions are registered in the same tree-parser traversal that performs the evaluation:

    f(2)
    f(2)=4

outputs an error message.
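The difference between 1A and 1B comes down to whether definitions are collected before evaluation starts. The following toy model in plain Java makes that concrete. It assumes a hypothetical Stmt record in which a definition and a call share the same key string, such as "f(2)".

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical statement: a definition "f(2)=4" (isDef, key "f(2)", value 4) or a call "f(2)" (value unused).
record Stmt(boolean isDef, String key, int value) {}

final class ForwardCalls {
    /** Like 1A: all definitions are collected before any call is evaluated. */
    static List<String> twoPass(List<Stmt> program) {
        Map<String, Integer> defs = new HashMap<>();
        for (Stmt s : program) {
            if (s.isDef()) defs.putIfAbsent(s.key(), s.value());   // first definition wins
        }
        List<String> out = new ArrayList<>();
        for (Stmt s : program) {
            if (!s.isDef()) out.add(lookup(defs, s.key()));
        }
        return out;
    }

    /** Like 1B: definitions are registered in the same walk that evaluates calls. */
    static List<String> onePass(List<Stmt> program) {
        Map<String, Integer> defs = new HashMap<>();
        List<String> out = new ArrayList<>();
        for (Stmt s : program) {
            if (s.isDef()) defs.putIfAbsent(s.key(), s.value());
            else out.add(lookup(defs, s.key()));
        }
        return out;
    }

    private static String lookup(Map<String, Integer> defs, String key) {
        Integer v = defs.get(key);
        return v != null ? v.toString() : "No match found for " + key;
    }
}
```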

2A) When a new tree parser is started for each function evaluation, no stack is needed for local variables: each parser has its own local memory, which holds only the function's parameter. On the other hand, the shared state must be passed explicitly to each newly created parser. This state consists of the global memory for explicitly assigned variables and the list of function definitions.

2B) When a single tree parser is used, no shared state needs to be passed around. On the other hand, local scopes must be handled explicitly. This could be done with an explicit stack. The current implementation, however, simply overwrites the parameter's entry in the localMemory map and restores the previous value after evaluating the function's expression.
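The following plain-Java sketch shows the overwrite-and-restore approach of 2B. SaveRestore and withBinding are hypothetical names. Map.put returns the previous value (or null). When there was no previous value, the sketch removes the entry, so no null mapping is left behind. The try/finally is an addition for robustness; the grammar below restores the value inline.

```java
import java.math.BigInteger;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Supplier;

// Hypothetical sketch of 2B: bind a parameter by overwriting its entry, then restore it.
final class SaveRestore {
    final Map<String, BigInteger> localMemory = new HashMap<>();

    /** Evaluates body with name bound to value; the previous binding (or its absence) is restored. */
    BigInteger withBinding(String name, BigInteger value, Supplier<BigInteger> body) {
        BigInteger prev = localMemory.put(name, value);   // put returns the old value, or null
        try {
            return body.get();
        } finally {
            if (prev == null) localMemory.remove(name);   // parameter was unbound before
            else localMemory.put(name, prev);
        }
    }
}
```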

It is certainly possible to design the memory handling in other ways (e.g. a single stack for all variables).
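The following sketch shows one such alternative, assuming one frame per active call. A hypothetical ScopeStack pushes a map on entry and pops it on exit. It looks names up only in the innermost frame before falling back to global memory.

```java
import java.math.BigInteger;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashMap;
import java.util.Map;

// Hypothetical alternative to overwriting entries: an explicit stack of call frames.
final class ScopeStack {
    private final Deque<Map<String, BigInteger>> frames = new ArrayDeque<>();
    private final Map<String, BigInteger> globalMemory = new HashMap<>();

    /** Called on function entry: a new frame holding only the parameter. */
    void enter(String param, BigInteger value) {
        Map<String, BigInteger> frame = new HashMap<>();
        frame.put(param, value);
        frames.push(frame);
    }

    /** Called on function exit: discard the innermost frame. */
    void exit() { frames.pop(); }

    void setGlobal(String name, BigInteger value) { globalMemory.put(name, value); }

    /** Innermost frame first, then globals; undefined names yield ZERO as in getValue(). */
    BigInteger getValue(String name) {
        Map<String, BigInteger> top = frames.peek();      // null when no call is active
        if (top != null && top.containsKey(name)) return top.get(name);
        BigInteger g = globalMemory.get(name);
        if (g != null) return g;
        System.err.println("undefined variable " + name);
        return BigInteger.ZERO;
    }
}
```

This design has one consequence. A function body can no longer see the parameter of its caller. With the overwrite approach of 2B, by contrast, every currently bound parameter remains visible in localMemory.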

3) All computations use BigInteger, so that large values such as fact(30) can be computed exactly.
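BigInteger is needed because a long tops out at about 9.2*10^18: 20! still fits, but 21! does not. The following sketch compares the two. Factorials is a hypothetical class. Math.multiplyExact is used so that overflow raises an ArithmeticException instead of silently wrapping.

```java
import java.math.BigInteger;

// Hypothetical comparison of arbitrary-precision and 64-bit factorials.
final class Factorials {
    static BigInteger factBig(int n) {
        BigInteger r = BigInteger.ONE;
        for (int i = 2; i <= n; i++) r = r.multiply(BigInteger.valueOf(i));
        return r;
    }

    /** Throws ArithmeticException on overflow instead of silently wrapping around. */
    static long factLong(int n) {
        long r = 1;
        for (int i = 2; i <= n; i++) r = Math.multiplyExact(r, (long) i);
        return r;
    }
}
```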

4) Function evaluation, however, is naive: every function value is recomputed each time it is needed. For the Fibonacci definition above, for example, the number of function calls grows exponentially with n, so performance becomes very poor for larger values (say n > 30). This could of course be redesigned (e.g. by caching computed values), but that is nothing ANTLR-specific.
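The following sketch shows the caching redesign, written for fib in plain Java rather than inside the interpreter. FibCache is hypothetical. Naive evaluation of fib(n) makes 2*fib(n+1)-1 calls, i.e. 21891 calls for n=20, while the cached version computes each value only once. Inside this interpreter, a cache keyed by function name and argument would only be safe if function bodies do not read global variables that are reassigned later. This is an assumption; otherwise the cache would have to be cleared on each assignment.

```java
import java.math.BigInteger;
import java.util.HashMap;
import java.util.Map;

// Hypothetical comparison of naive and cached evaluation of fib(0)=0, fib(1)=1, fib(x)=fib(x-1)+fib(x-2).
final class FibCache {
    static long naiveCalls = 0;                         // counts every invocation of naive()

    static BigInteger naive(int n) {
        naiveCalls++;
        if (n < 2) return BigInteger.valueOf(n);
        return naive(n - 1).add(naive(n - 2));
    }

    private final Map<Integer, BigInteger> cache = new HashMap<>();

    /** Each argument is computed once; get/put avoids recursive computeIfAbsent on a HashMap. */
    BigInteger cached(int n) {
        BigInteger hit = cache.get(n);
        if (hit != null) return hit;
        BigInteger v = n < 2 ? BigInteger.valueOf(n) : cached(n - 1).add(cached(n - 2));
        cache.put(n, v);
        return v;
    }
}
```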

5) All collections use generics, so the types stored in them are explicit.

Implementation with a new tree parser per function evaluation

Input grammar

grammar Expr1;

options {
    output=AST;
    ASTLabelType=CommonTree;
}

tokens {
    // define pseudo-operations
    FUNC;
    CALL;
}

@members {
    /** List of function definitions. Must point at the FUNC nodes. */
    List<CommonTree> functionDefinitions = new ArrayList<CommonTree>();
}

// START:stat
prog: ( stat )*
    ;

stat:   expr NEWLINE                    -> expr
    |   ID '=' expr NEWLINE             -> ^('=' ID expr)
    |   func NEWLINE                    -> func
    |   NEWLINE                         -> // ignore
    ;

func:   ID  '(' formalPar ')' '=' expr  -> ^(FUNC ID formalPar expr)
    ;
    finally {
        // Record the FUNC subtree during parsing, so calls may precede definitions (see 1A).
        functionDefinitions.add($func.tree);
    }

formalPar
    :   ID
    |   INT
    ;

// END:stat

// START:expr
expr:   multExpr (('+'^|'-'^) multExpr)*
    ;

multExpr
    :   atom (('*'|'/'|'%')^ atom)*
    ;

atom:   INT
    |   ID
    |   '(' expr ')'    -> expr
    |   ID '(' expr ')' -> ^(CALL ID expr)
    ;
// END:expr

// START:tokens
ID  :   ('a'..'z'|'A'..'Z')+
    ;

INT :   '0'..'9'+
    ;

NEWLINE
    :   '\r'? '\n'
    ;

WS  :   (' '|'\t')+ { skip(); }
    ;
// END:tokens

Tree grammar

tree grammar Eval1;

options {
    tokenVocab=Expr1;
    ASTLabelType=CommonTree;
}

// START:members
@header {
    import java.util.Map;
    import java.util.HashMap;
    import java.math.BigInteger;
}

@members {
    /** FUNC subtrees of all function definitions, collected by the parser. */
    private List<CommonTree> functionDefinitions;

    /** Remember local variables. Currently, this is only the function parameter.
     */
    private final Map<String, BigInteger> localMemory = new HashMap<String, BigInteger>();

    /** Remember global variables set by =. */
    private Map<String, BigInteger> globalMemory = new HashMap<String, BigInteger>();

    /** Set up an evaluator with a node stream; and a set of function definition ASTs. */
    public Eval1(CommonTreeNodeStream nodes, List<CommonTree> functionDefinitions) {
        this(nodes);
        this.functionDefinitions = functionDefinitions;
    }

    /** Set up a local evaluator for a nested function call. The evaluator gets the definition
     *  tree of the function; the set of all defined functions (to find locally called ones); a
     *  pointer to the global variable memory; and the value of the function parameter to be
     *  added to the local memory.
     */
    private Eval1(CommonTree function,
                 List<CommonTree> functionDefinitions,
                 Map<String, BigInteger> globalMemory,
                 BigInteger paramValue) {
        // Expected tree for function: ^(FUNC ID ( INT | ID ) expr)
        this(new CommonTreeNodeStream(function.getChild(2)), functionDefinitions);
        this.globalMemory = globalMemory;
        localMemory.put(function.getChild(1).getText(), paramValue);
    }

    /** Find matching function definition for a function name and parameter
     *  value. The first definition is returned where (a) the name matches
     *  and (b) the formal parameter agrees if it is defined as constant.
     */
    private CommonTree findFunction(String name, BigInteger paramValue) {
        SEARCH:
        for (CommonTree f : functionDefinitions) {
            // Expected tree for f: ^(FUNC ID (ID | INT) expr)
            if (f.getChild(0).getText().equals(name)) {
                // Check whether parameter matches
                CommonTree formalPar = (CommonTree) f.getChild(1);
                if (formalPar.getToken().getType() == INT
                    && !new BigInteger(formalPar.getToken().getText()).equals(paramValue)) {
                        // Constant in formalPar list does not match actual value -> no match.
                        continue SEARCH;
                }
                // Parameter (value for INT formal arg) as well as fct name agrees!
                return f;
            }
        }
        return null;
    }

    /** Look up a name: first in local memory (the parameter), then in global memory. */
    public BigInteger getValue(String name) {
        BigInteger value = localMemory.get(name);
        if ( value!=null ) {
            return value;
        }
        value = globalMemory.get(name);
        if ( value!=null ) {
            return value;
        }
        // not found in local memory or global memory
        System.err.println("undefined variable "+name);
        return BigInteger.ZERO;
    }
}
// END:members

// START:rules
prog:   stat*
    ;

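stat:   expr                       { BigInteger v = $expr.value;
                                     // abs() keeps a minus sign from being printed as the leading digit.
                                     String digits = v.abs().toString();
                                     System.out.println(v + " (about " + (v.signum() < 0 ? "-" : "") + digits.charAt(0)
                                                        + "*10^" + (digits.length()-1) + ")");
                                   }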
    |   ^('=' ID expr)             { globalMemory.put($ID.text, $expr.value); }
    |   ^(FUNC .+)                 // ignore FUNCs: the parser has already added them to functionDefinitions
    ;

expr returns [BigInteger value]
    :   ^('+' a=expr b=expr)       { $value = $a.value.add($b.value); }
    |   ^('-' a=expr b=expr)       { $value = $a.value.subtract($b.value); }
    |   ^('*' a=expr b=expr)       { $value = $a.value.multiply($b.value); }
    |   ^('/' a=expr b=expr)       { $value = $a.value.divide($b.value); }
    |   ^('%' a=expr b=expr)       { $value = $a.value.remainder($b.value); }
    |   ID                         { $value = getValue($ID.text); }
    |   INT                        { $value = new BigInteger($INT.text); }
    |   call                       { $value = $call.value; }
    ;

call returns [BigInteger value]
    :   ^(CALL ID expr)            { BigInteger p = $expr.value;
                                     CommonTree funcRoot = findFunction($ID.text, p);
                                     if (funcRoot == null) {
                                         System.err.println("No match found for " + $ID.text + "(" + p + ")");
                                         $value = BigInteger.ZERO;   // as in getValue(): report the error, continue with 0
                                     } else {
                                         // Here we set up the local evaluator to run over the
                                         // function definition with the parameter value.
                                         // This re-reads a sub-AST of our input AST!
                                         Eval1 e = new Eval1(funcRoot, functionDefinitions, globalMemory, p);
                                         $value = e.expr();
                                     }
                                   }
    ;
// END:rules

Main program

import org.antlr.runtime.*;
import org.antlr.runtime.tree.*;

public class Test1 {
    public static void main(String[] args) throws Exception {
        ANTLRInputStream input = new ANTLRInputStream(System.in);
        Expr1Lexer lexer = new Expr1Lexer(input);
        CommonTokenStream tokens = new CommonTokenStream(lexer);
        Expr1Parser parser = new Expr1Parser(tokens);
        CommonTree t  = (CommonTree) parser.prog().getTree();

        CommonTreeNodeStream nodes = new CommonTreeNodeStream(t);
        Eval1 evaluator = new Eval1(nodes, parser.functionDefinitions);
        evaluator.prog();
    }
}

Implementation with a single tree parser and node stream resetting

Input grammar

grammar Expr2;

options {
    output=AST;
    ASTLabelType=CommonTree;
}

tokens {
    // define pseudo-operations
    FUNC;
    CALL;
}

// START:stat
prog: ( stat )*
    ;

stat:   expr NEWLINE                    -> expr
    |   ID '=' expr NEWLINE             -> ^('=' ID expr)
    |   func NEWLINE                    -> func
    |   NEWLINE                         -> // ignore
    ;

func:   ID  '(' formalPar ')' '=' expr  -> ^(FUNC ID formalPar expr)
    ;

formalPar
    :   ID
    |   INT
    ;

// END:stat

// START:expr
expr:   multExpr (('+'^|'-'^) multExpr)*
    ;

multExpr
    :   atom (('*'|'/'|'%')^ atom)*
    ;

atom:   INT
    |   ID
    |   '(' expr ')'    -> expr
    |   ID '(' expr ')' -> ^(CALL ID expr)
    ;
// END:expr

// START:tokens
ID  :   ('a'..'z'|'A'..'Z')+
    ;

INT :   '0'..'9'+
    ;

NEWLINE
    :   '\r'? '\n'
    ;

WS  :   (' '|'\t')+ { skip(); }
    ;
// END:tokens

Tree grammar

tree grammar Eval2;

options {
    tokenVocab=Expr2;
    ASTLabelType=CommonTree;
}

// START:members
@header {
    import java.util.Map;
    import java.util.HashMap;
    import java.util.LinkedHashMap;
    import java.math.BigInteger;
}

@members {
    /** Function definitions, each mapped to the node-stream index where its expression starts
     *  (the position that is pushed for a call). The definitions must keep their order,
     *  therefore we use a LinkedHashMap here.
     */
    private Map<CommonTree, Integer> functionDefinitions = new LinkedHashMap<CommonTree, Integer>();

    /** Remember local variables. Currently, this is only the function parameter.
     */
    private final Map<String, BigInteger> localMemory = new HashMap<String, BigInteger>();

    /** Remember global variables set by =. */
    private Map<String, BigInteger> globalMemory = new HashMap<String, BigInteger>();

    /** Find matching function definition for a function name and parameter
     *  value. The first definition is returned where (a) the name matches
     *  and (b) the formal parameter agrees if it is defined as constant.
     */
    private CommonTree findFunction(String name, BigInteger paramValue) {
        SEARCH:
        for (CommonTree f : functionDefinitions.keySet()) {
            // Expected tree for f: ^(FUNC ID (ID | INT) expr)
            if (f.getChild(0).getText().equals(name)) {
                // Check whether parameter matches
                CommonTree formalPar = (CommonTree) f.getChild(1);
                if (formalPar.getToken().getType() == INT
                    && !new BigInteger(formalPar.getToken().getText()).equals(paramValue)) {
                        // Constant in formalPar list does not match actual value -> no match.
                        continue SEARCH;
                }
                // Parameter (value for INT formal arg) as well as fct name agrees!
                return f;
            }
        }
        return null;
    }

    /** Look up a name: first in local memory (bound parameters), then in global memory. */
    public BigInteger getValue(String name) {
        BigInteger value = localMemory.get(name);
        if ( value!=null ) {
            return value;
        }
        value = globalMemory.get(name);
        if ( value!=null ) {
            return value;
        }
        // not found in local memory or global memory
        System.err.println("undefined variable "+name);
        return BigInteger.ZERO;
    }
}
// END:members

// START:rules
prog
    :   stat+
    ;

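stat:   expr                       { BigInteger v = $expr.value;
                                     // abs() keeps a minus sign from being printed as the leading digit.
                                     String digits = v.abs().toString();
                                     System.out.println(v + " (about " + (v.signum() < 0 ? "-" : "") + digits.charAt(0)
                                                        + "*10^" + (digits.length()-1) + ")");
                                   }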
    |   ^('=' ID expr)             { globalMemory.put($ID.text, $expr.value); }
    |   ^(FUNC ID . /*ID|INT*/     { functionDefinitions.put($stat.start, input.index()); }
                    . /* expr */)
    ;

expr returns [BigInteger value]
    :   ^('+' a=expr b=expr)       { $value = $a.value.add($b.value); }
    |   ^('-' a=expr b=expr)       { $value = $a.value.subtract($b.value); }
    |   ^('*' a=expr b=expr)       { $value = $a.value.multiply($b.value); }
    |   ^('/' a=expr b=expr)       { $value = $a.value.divide($b.value); }
    |   ^('%' a=expr b=expr)       { $value = $a.value.remainder($b.value); }
    |   ID                         { $value = getValue($ID.text); }
    |   INT                        { $value = new BigInteger($INT.text); }
    |   call                       { $value = $call.value; }
    ;

call returns [BigInteger value]
    :   ^(CALL ID expr)            { BigInteger p = $expr.value;
                                     CommonTree funcRoot = findFunction($ID.text, p);
                                     if (funcRoot == null) {
                                         System.err.println("No match found for " + $ID.text + "(" + p + ")");
                                         $value = BigInteger.ZERO;   // as in getValue(): report the error, continue with 0
                                     } else {
                                         // Push parameter value into local memory; and expr tree onto node stream.
                                         String paramName = funcRoot.getChild(1).getText();
                                         BigInteger prevValue = localMemory.put(paramName, p);

                                         CommonTreeNodeStream commonInput = (CommonTreeNodeStream) input;
                                         int exprStartIndex = functionDefinitions.get(funcRoot);
                                         commonInput.push(exprStartIndex);

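                                         $value = expr();

                                         // Restore node stream and local variable to their previous values.
                                         commonInput.pop();
                                         if (prevValue == null) {
                                             localMemory.remove(paramName);   // parameter was not bound before this call
                                         } else {
                                             localMemory.put(paramName, prevValue);
                                         }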
                                     }
                                   }
    ;
// END:rules
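In the call rule above, the tree parser remembers where each function's expression starts (input.index() in the FUNC alternative). For each call, it jumps there with push and comes back with pop. The following toy analogue in plain Java illustrates that mechanism. PrefixStream and its methods are hypothetical. Unlike CommonTreeNodeStream, the sketch uses a flat prefix list without DOWN/UP navigation nodes, and every operator is binary.

```java
import java.math.BigInteger;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.List;

// Hypothetical toy node stream: a flat prefix list with a movable cursor and a stack of saved positions.
final class PrefixStream {
    private final List<String> nodes;
    private int p = 0;                                    // current position
    private final Deque<Integer> saved = new ArrayDeque<>();

    PrefixStream(List<String> nodes) { this.nodes = nodes; }

    int index() { return p; }
    void seek(int i) { p = i; }

    /** Save the current position and jump to index (compare CommonTreeNodeStream.push). */
    void push(int index) { saved.push(p); p = index; }

    /** Return to the most recently saved position (compare CommonTreeNodeStream.pop). */
    void pop() { p = saved.pop(); }

    /** Evaluates one prefix expression at the cursor; the name "x" reads the parameter. */
    BigInteger expr(BigInteger x) {
        String t = nodes.get(p++);
        switch (t) {
            case "+": return expr(x).add(expr(x));
            case "*": return expr(x).multiply(expr(x));
            case "x": return x;
            default:  return new BigInteger(t);
        }
    }
}
```

For example, with the body of square(x)=x*x stored at index 0 as "* x x" and the call argument 1+2 at index 3 as "+ 1 2", the sketch proceeds in three steps. First it evaluates the argument. Next it pushes index 0 and evaluates the body with x = 3. Finally it pops, which leaves the cursor just after the argument again. Each further call pushes index 0 again and re-reads the same body.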

Main program

import org.antlr.runtime.*;
import org.antlr.runtime.tree.*;

public class Test2 {
    public static void main(String[] args) throws Exception {
        ANTLRInputStream input = new ANTLRInputStream(System.in);
        Expr2Lexer lexer = new Expr2Lexer(input);
        CommonTokenStream tokens = new CommonTokenStream(lexer);
        Expr2Parser parser = new Expr2Parser(tokens);
        CommonTree t  = (CommonTree) parser.prog().getTree();

        CommonTreeNodeStream nodes = new CommonTreeNodeStream(t);
        Eval2 evaluator = new Eval2(nodes);
        evaluator.prog();
    }
}