Source code for chaste_codegen._chaste_printer

from cellmlmanip.printer import Printer
from sympy import (
    Mul,
    Not,
    Piecewise,
    Pow,
    Rational,
    S,
)
from sympy.core.mul import _keep_coeff
from sympy.printing import cxxcode
from sympy.printing.precedence import precedence


C_MAX_INT = 2147483647
C_MIN_INT = -2147483647


[docs]class ChastePrinter(Printer): """ Converts Sympy expressions to strings for use in Chaste code generation. To use, create a :class:`ChastePrinter` instance, and call its method :meth:`doprint()` with a Sympy expression argument. Arguments: ``symbol_function`` A function that converts symbols to strings (variable names). ``derivative_function`` A function that converts derivatives to strings. ``lookup_table_function`` A function that prints lookup table expressions or returns None if the expression is not in the lookup table. """ _function_names = { 'abs_': 'fabs', 'acos_': 'acos', 'cos_': 'cos', 'exp_': 'exp', 'sqrt_': 'sqrt', 'sin_': 'sin', 'Abs': 'fabs', 'acos': 'acos', 'acosh': 'acosh', 'asin': 'asin', 'asinh': 'asinh', 'atan': 'atan', 'atan2': 'atan2', 'atanh': 'atanh', 'ceiling': 'ceil', 'cos': 'cos', 'cosh': 'cosh', 'exp': 'exp', 'expm1': 'expm1', 'factorial': 'factorial', 'floor': 'floor', 'log': 'log', 'log10': 'log10', 'log1p': 'log1p', 'log2': 'log2', 'sin': 'sin', 'sinh': 'sinh', 'sqrt': 'sqrt', 'tan': 'tan', 'tanh': 'tanh', 'sign': 'Signum', 'GetIntracellularAreaStimulus': 'GetIntracellularAreaStimulus', 'HeartConfig::Instance()->GetCapacitance': 'HeartConfig::Instance()->GetCapacitance', 'GetExperimentalVoltageAtTimeT': 'GetExperimentalVoltageAtTimeT' } _extra_trig_names = { 'sec': 'cos', 'csc': 'sin', 'cot': 'tan', 'sech': 'cosh', 'csch': 'sinh', 'coth': 'tanh', } _extra_inverse_trig_names = { 'asec': 'acos', 'acsc': 'asin', 'acot': 'atan', 'asech': 'acosh', 'acsch': 'asinh', 'acoth': 'atanh', } _literal_names = { 'e': 'e', 'nan': 'NAN', 'pi': 'M_PI', } def __init__(self, symbol_function=None, derivative_function=None, lookup_table_function=lambda e: None): super().__init__(symbol_function, derivative_function) self.lookup_table_function = lookup_table_function def _print(self, expr, **kwargs): """Internal dispatcher. Here we intercept lookup table expressions if we have lookup tables. Otherwise the base class method is used. """ printed_expr = self.lookup_table_function(expr) if printed_expr: return printed_expr return super()._print(expr, **kwargs) def _print_And(self, expr): """ Handles logical And. """ my_prec = precedence(expr) return ' && '.join(['(' + self._bracket(x, my_prec) + ')' for x in expr.args]) def _print_BooleanFalse(self, expr): """ Handles False """ return 'false' def _print_BooleanTrue(self, expr): """ Handles True """ return 'true' def _print_Or(self, expr): """ Handles logical Or. """ my_prec = precedence(expr) return ' || '.join(['(' + self._bracket(x, my_prec) + ')' for x in expr.args]) def _print_ordinary_pow(self, expr): """ Handles Pow(), handles just ordinary powers without division. For C++ printing we need to write ``x**y`` as ``pow(x, y)`` with lowercase ``p``.""" p = precedence(expr) if expr.exp == 0.5: return 'sqrt(' + self._bracket(expr.base, p) + ')' return 'pow(' + self._bracket(expr.base, p) + ', ' + self._bracket(expr.exp, p) + ')' def _print_ternary(self, cond, expr): parts = '' parts += '(' parts += self._print(cond) parts += ') ? (' parts += self._print(expr) parts += ') : (' return parts def _print_IntegerConstant(self, expr): return self._print_int(float(expr)) def _print_float(self, expr): """ Handles ``float``s. """ # print integers as int if they are between min & max int in c++ if expr.is_integer() and C_MIN_INT < expr < C_MAX_INT: return cxxcode(int(expr), standard='C++11') else: return cxxcode(float(expr), standard='C++11') def _print_int(self, expr): """ Handles ``ints``s. """ return self._print_float(float(expr)) def _print_ITE(self, expr): """ Handles ITE (if then else) objects by rewriting them as Piecewise """ return self._print_Piecewise(expr.rewrite(Piecewise)) def _print_Mul(self, expr): """ Handles multiplication & division, with n terms. Division is specified as a power: ``x / y --> x * y**-1``. Subtraction is specified as ``x - y --> x + (-1 * y)``. """ # This method is mostly copied from sympy.printing.Str # Check overall sign of multiplication sign = '' c, e = expr.as_coeff_Mul() if c < 0: expr = _keep_coeff(-c, e) sign = '-' # Collect all pows with more than one base element and exp = -1 pow_brackets = [] # Gather terms for numerator and denominator a, b = [], [] for item in Mul.make_args(expr): if item != 1.0: # In multiplications remove 1.0 * ... # Check if this is a negative power and it's not in a lookup table, so we can write it as a division if (item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative and not self.lookup_table_function(item)): if item.exp != -1: # E.g. x * y**(-2 / 3) --> x / y**(2 / 3) # Add as power b.append(Pow(item.base, -item.exp, evaluate=False)) else: # Add without power b.append(Pow(item.base, -item.exp)) # Check if it's a negative power that needs brackets # Sympy issue #14160 if (len(item.args[0].args) != 1 and isinstance(item.base, Mul)): pow_brackets.append(item) # Split Rationals over a and b, ignoring any 1s elif item.is_Rational: if item.p != 1: a.append(Rational(item.p)) if item.q != 1: b.append(Rational(item.q)) else: a.append(item) # Replace empty numerator with one a = a or [S.One] # Convert terms to code my_prec = precedence(expr) a_str = [self._bracket(x, my_prec) for x in a] b_str = [self._bracket(x, my_prec) for x in b] # Fix brackets for Pow with exp -1 with more than one Symbol for item in pow_brackets: assert item.base in b, "item.base should be kept in b for powers" b_str[b.index(item.base)] = '(' + b_str[b.index(item.base)] + ')' # Combine numerator and denomenator and return a_str = sign + ' * '.join(a_str) if len(b) == 0: return a_str b_str = ' * '.join(b_str) return a_str + ' / ' + (b_str if len(b) == 1 else '(' + b_str + ')') def _print_Not(self, expr): return '!(' + self._print(Not(expr)) + ')'