Source code for petab.v1.math.SympyVisitor

"""PEtab-math to sympy conversion."""

import sympy as sp
from sympy.logic.boolalg import Boolean, BooleanFalse, BooleanTrue

from ._generated.PetabMathExprParser import PetabMathExprParser
from ._generated.PetabMathExprParserVisitor import PetabMathExprParserVisitor

__all__ = ["MathVisitorSympy"]

# Mappings of PEtab math functions to sympy functions

# trigonometric functions
# The direct functions carry the same name in PEtab and in sympy; the inverse
# functions are spelled ``arc<name>`` in PEtab but ``a<name>`` in sympy.
_trig_funcs = {
    **{
        name: getattr(sp, name)
        for name in (
            "sin",
            "cos",
            "tan",
            "sec",
            "csc",
            "cot",
            "sinh",
            "cosh",
            "tanh",
            "sech",
            "csch",
            "coth",
        )
    },
    **{
        f"arc{name}": getattr(sp, f"a{name}")
        for name in (
            "cos",
            "sin",
            "tan",
            "sec",
            "csc",
            "cot",
            "sinh",
            "cosh",
            "tanh",
            "sech",
            "csch",
            "coth",
        )
    },
}
def _log_with_zero_guard(base=None):
    """Build a logarithm callable that maps an exact zero to ``-oo``.

    ``sp.log(0)`` evaluates to complex infinity (``zoo``); the other
    logarithm variants handled in this module return ``-oo`` for a zero
    argument instead, so all of them share this guard.

    :param base:
        Base of the logarithm, or ``None`` for the natural logarithm.
    :return:
        A callable ``f(x, evaluate=True)`` returning a sympy expression.
    """

    def _log(x, evaluate=True):
        # Only short-circuit when sympy can prove that the argument is zero.
        if x.is_zero is True:
            return -sp.oo
        if base is None:
            return sp.log(x, evaluate=evaluate)
        return sp.log(x, base, evaluate=evaluate)

    return _log


# functions taking exactly one argument (besides the trigonometric ones)
_unary_funcs = {
    "exp": sp.exp,
    "log10": _log_with_zero_guard(10),
    "log2": _log_with_zero_guard(2),
    "ln": _log_with_zero_guard(),
    "sqrt": sp.sqrt,
    "abs": sp.Abs,
    "sign": sp.sign,
}
# functions taking exactly two arguments
_binary_funcs = dict(pow=sp.Pow, min=sp.Min, max=sp.Max)

# Identifiers that must not be used as variable names.
# Entries are lowercase; lookups lowercase the candidate name first.
_reserved_names = set(("inf", "nan", "true", "false"))


[docs] class MathVisitorSympy(PetabMathExprParserVisitor): """ ANTLR4 visitor for PEtab-math-to-sympy conversion. Visitor for PEtab math expression AST generated using ANTLR4. Converts PEtab math expressions to sympy expressions. Most users will not need to interact with this class directly, but rather use :func:`petab.math.sympify_petab`. Evaluation of any sub-expressions currently relies on sympy's defaults. For a general introduction to ANTLR4 visitors, see: https://github.com/antlr/antlr4/blob/7d4cea92bc3f7d709f09c3f1ac77c5bbc71a6749/doc/python-target.md :param evaluate: Whether to evaluate the expression. """ def __init__(self, evaluate=True): super().__init__() self.evaluate = evaluate
[docs] def visitPetabExpression( self, ctx: PetabMathExprParser.PetabExpressionContext ) -> sp.Expr | sp.Basic: """Visit the root of the expression tree.""" return self.visit(ctx.getChild(0))
[docs] def visitNumber(self, ctx: PetabMathExprParser.NumberContext) -> sp.Float: """Convert number to sympy Float.""" return sp.Float(ctx.getText())
[docs] def visitVar(self, ctx: PetabMathExprParser.VarContext) -> sp.Symbol: """Convert identifier to sympy Symbol.""" if ctx.getText().lower() in _reserved_names: raise ValueError(f"Use of reserved name {ctx.getText()!r}") return sp.Symbol(ctx.getText(), real=True)
[docs] def visitMultExpr( self, ctx: PetabMathExprParser.MultExprContext ) -> sp.Expr: """Convert multiplication and division expressions to sympy.""" if ctx.getChildCount() == 3: operand1 = bool2num(self.visit(ctx.getChild(0))) operand2 = bool2num(self.visit(ctx.getChild(2))) if ctx.ASTERISK(): return sp.Mul(operand1, operand2, evaluate=self.evaluate) if ctx.SLASH(): return ( operand1 / operand2 if self.evaluate else sp.Mul( operand1, sp.Pow(operand2, -1, evaluate=False), evaluate=False, ) ) raise AssertionError(f"Unexpected expression: {ctx.getText()}")
[docs] def visitAddExpr(self, ctx: PetabMathExprParser.AddExprContext) -> sp.Expr: """Convert addition and subtraction expressions to sympy.""" op1 = bool2num(self.visit(ctx.getChild(0))) op2 = bool2num(self.visit(ctx.getChild(2))) if ctx.PLUS(): return sp.Add(op1, op2, evaluate=self.evaluate) if ctx.MINUS(): return sp.Add(op1, -op2, evaluate=self.evaluate) raise AssertionError( f"Unexpected operator: {ctx.getChild(1).getText()} " f"in {ctx.getText()}" )
[docs] def visitArgumentList( self, ctx: PetabMathExprParser.ArgumentListContext ) -> list[sp.Basic | sp.Expr]: """Convert function argument lists to a list of sympy expressions.""" return [self.visit(c) for c in ctx.children[::2]]
[docs] def visitFunctionCall( self, ctx: PetabMathExprParser.FunctionCallContext ) -> sp.Expr: """Convert function call to sympy expression.""" if ctx.getChildCount() < 4: raise AssertionError(f"Unexpected expression: {ctx.getText()}") func_name = ctx.getChild(0).getText() args = self.visit(ctx.getChild(2)) if func_name != "piecewise": # all functions except piecewise expect numerical arguments args = list(map(bool2num, args)) if func_name in _trig_funcs: if len(args) != 1: raise AssertionError( f"Unexpected number of arguments: {len(args)} " f"in {ctx.getText()}" ) return _trig_funcs[func_name](*args, evaluate=self.evaluate) if func_name in _unary_funcs: if len(args) != 1: raise AssertionError( f"Unexpected number of arguments: {len(args)} " f"in {ctx.getText()}" ) return _unary_funcs[func_name](*args, evaluate=self.evaluate) if func_name in _binary_funcs: if len(args) != 2: raise AssertionError( f"Unexpected number of arguments: {len(args)} " f"in {ctx.getText()}" ) return _binary_funcs[func_name](*args, evaluate=self.evaluate) if func_name == "log": if len(args) not in [1, 2]: raise AssertionError( f"Unexpected number of arguments: {len(args)} " f"in {ctx.getText()}" ) return ( -sp.oo if args[0].is_zero is True else sp.log(*args, evaluate=self.evaluate) ) if func_name == "piecewise": if (len(args) - 1) % 2 != 0: raise AssertionError( f"Unexpected number of arguments: {len(args)} " f"in {ctx.getText()}" ) # sympy's Piecewise requires an explicit condition for the final # `else` case args.append(sp.true) sp_args = ( (true_expr, num2bool(condition)) for true_expr, condition in zip( args[::2], args[1::2], strict=True ) ) return sp.Piecewise(*sp_args, evaluate=self.evaluate) raise ValueError(f"Unknown function: {ctx.getText()}")
[docs] def visitParenExpr(self, ctx: PetabMathExprParser.ParenExprContext): """Convert parenthesized expression to sympy.""" return self.visit(ctx.getChild(1))
[docs] def visitPowerExpr( self, ctx: PetabMathExprParser.PowerExprContext ) -> sp.Pow: """Convert power expression to sympy.""" if ctx.getChildCount() != 3: raise AssertionError( f"Unexpected number of children: {ctx.getChildCount()} " f"in {ctx.getText()}" ) operand1 = bool2num(self.visit(ctx.getChild(0))) operand2 = bool2num(self.visit(ctx.getChild(2))) return sp.Pow(operand1, operand2, evaluate=self.evaluate)
[docs] def visitUnaryExpr( self, ctx: PetabMathExprParser.UnaryExprContext ) -> sp.Basic | sp.Expr: """Convert unary expressions to sympy.""" if ctx.getChildCount() == 2: operand = bool2num(self.visit(ctx.getChild(1))) match ctx.getChild(0).getText(): case "-": return -operand case "+": return operand raise AssertionError(f"Unexpected expression: {ctx.getText()}")
[docs] def visitComparisonExpr( self, ctx: PetabMathExprParser.ComparisonExprContext ) -> sp.Basic | sp.Expr: """Convert comparison expressions to sympy.""" if ctx.getChildCount() != 3: raise AssertionError(f"Unexpected expression: {ctx.getText()}") lhs = self.visit(ctx.getChild(0)) op = ctx.getChild(1).getText() rhs = self.visit(ctx.getChild(2)) ops = { "==": sp.Equality, "!=": sp.Unequality, "<": sp.StrictLessThan, ">": sp.StrictGreaterThan, "<=": sp.LessThan, ">=": sp.GreaterThan, } if op in ops: lhs = bool2num(lhs) rhs = bool2num(rhs) return ops[op](lhs, rhs, evaluate=self.evaluate) raise AssertionError(f"Unexpected operator: {op}")
[docs] def visitBooleanNotExpr( self, ctx: PetabMathExprParser.BooleanNotExprContext ) -> sp.Basic | sp.Expr: """Convert boolean NOT expressions to sympy.""" if ctx.getChildCount() == 2: return ~num2bool(self.visit(ctx.getChild(1))) raise AssertionError(f"Unexpected expression: {ctx.getText()}")
[docs] def visitBooleanAndOrExpr( self, ctx: PetabMathExprParser.BooleanAndOrExprContext ) -> sp.Basic | sp.Expr: """Convert boolean AND and OR expressions to sympy.""" if ctx.getChildCount() != 3: raise AssertionError(f"Unexpected expression: {ctx.getText()}") operand1 = num2bool(self.visit(ctx.getChild(0))) operand2 = num2bool(self.visit(ctx.getChild(2))) if ctx.BOOLEAN_AND(): return operand1 & operand2 if ctx.BOOLEAN_OR(): return operand1 | operand2 raise AssertionError(f"Unexpected expression: {ctx.getText()}")
[docs] def visitBooleanLiteral( self, ctx: PetabMathExprParser.BooleanLiteralContext ) -> Boolean: """Convert boolean literals to sympy.""" if ctx.TRUE(): return sp.true if ctx.FALSE(): return sp.false raise AssertionError(f"Unexpected boolean literal: {ctx.getText()}")
def bool2num(x: sp.Basic | sp.Expr) -> sp.Basic | sp.Expr:
    """Convert sympy Booleans to Floats.

    ``sp.false`` becomes ``Float(0)`` and ``sp.true`` becomes ``Float(1)``;
    anything else is returned unchanged.
    """
    if isinstance(x, BooleanFalse):
        return sp.Float(0)
    if isinstance(x, BooleanTrue):
        return sp.Float(1)
    return x


def num2bool(x: sp.Basic | sp.Expr) -> sp.Basic | sp.Expr:
    """Convert sympy Floats to booleans.

    Boolean values and boolean expressions are returned unchanged. An
    expression sympy can prove to be zero becomes ``sp.false`` and one it
    can prove to be non-zero becomes ``sp.true``. Anything else is turned
    into the symbolic relation ``x != 0``.
    """
    if isinstance(x, BooleanTrue | BooleanFalse):
        return x
    # Note: sp.Float(0) == 0 is False in sympy>=1.13
    if x.is_zero is True:
        return sp.false
    if x.is_zero is False:
        return sp.true
    if isinstance(x, Boolean):
        return x
    # Python's ``!=`` on sympy objects is a structural comparison that yields
    # a plain bool, which would make every undetermined expression "true".
    # ``sp.Ne`` builds the symbolic relation instead.
    return sp.Ne(x, 0.0)