import time
from sympy import Derivative, Float
import chaste_codegen as cg
from chaste_codegen._rdf import (
OXMETA,
PYCMLMETA,
get_MultipleUsesAllowed_tags,
get_variables_transitively,
)
from chaste_codegen.model_with_conversions import (
CYTOSOLIC_CALCIUM_CONCENTRATION_INDEX,
MEMBRANE_VOLTAGE_INDEX,
get_equations_for,
state_var_key_order,
)
# Local time captured once when this module is loaded; handed to the templates as 'generation_date'.
TIME_STAMP = time.strftime('%Y-%m-%d %H:%M:%S')
def get_variable_name(s, interface=False):
    """Build the name under which a variable or derivative is printed.

    ``s``
        A variable, or a :class:`sympy.Derivative` (in which case the name is based on ``s.expr``).
    ``interface``
        Whether a (non-derivative) variable should get the chaste_interface prefix.

    ``$`` characters in the name are replaced by ``__`` and the name is made to start with an underscore
    before the prefix is added. Derivatives always get the ``d_dt_chaste_interface_var`` prefix,
    regardless of ``interface``.
    """
    is_derivative = isinstance(s, Derivative)

    # Normalise the bare name first, the prefix is chosen afterwards
    base_name = str(s.expr if is_derivative else s).replace('$', '__')
    if not base_name.startswith('_'):
        base_name = '_' + base_name

    if is_derivative:
        prefix = 'd_dt_chaste_interface_var'
    elif interface:
        prefix = 'var_chaste_interface_'
    else:
        prefix = 'var'
    return prefix + base_name
class ChasteModel(object):
    """ Holds information about a cellml model for which chaste code is to be generated.

    It also holds relevant formatted equations and derivatives.

    Please Note: this class cannot generate chaste code directly, instead use a subclass of the model type
    """
    # Not referenced inside this class; presumably the extensions of the generated files (see ``file_name``).
    DEFAULT_EXTENSIONS = ('.hpp', '.cpp')

    def __enter__(self):
        """ Pre analysis preparation. Required to be able to use model in context (with).

        Defined for overwriting in sub-classes if pre-processing is needed to use the model for code generation.
        Returns the instance itself.
        """
        return self

    def __init__(self, model, file_name, **kwargs):
        """ Initialise a ChasteModel instance

        Arguments

        ``model``
            A :class:`cellmlmanip.Model` object.
        ``file_name``
            The name you want to give your generated files WITHOUT the .hpp and .cpp extension
            (e.g. aslanidi_model_2009 leads to aslanidi_model_2009.cpp and aslanidi_model_2009.hpp)
        ``kwargs``
            Optional settings read here: ``use_modifiers`` (default ``False``),
            ``class_name`` (default ``'ModelFromCellMl'``), ``header_ext`` (default ``'.hpp'``),
            ``dynamically_loadable`` (default ``False``), ``use_model_factory`` (default ``False``)
            and ``cellml_base`` (default ``''``).
        """
        # Store default options
        self.file_name = file_name
        self._templates = []
        self.generated_code = []

        # Items to print as chaste_interface_...
        self._in_interface = set()

        self._model = model

        # retrieve probabilities that don't stay in 0 ... 1 range and shouldn't be checked
        not_quite_probabilities = \
            set(get_variables_transitively(self._model, (OXMETA, 'not_a_probability_even_though_it_should_be')))

        self._stimulus_equations = self._get_stimulus()

        self.use_modifiers = kwargs.get('use_modifiers', False)
        self._modifiers = self._model.modifiers if self.use_modifiers else ()

        # The order below matters: later getters read the attributes set by earlier ones
        self._extended_ionic_vars = self._get_extended_ionic_vars()
        self._derivative_equations = self._get_derivative_equations()
        self._derivative_eqs_excl_voltage = self._get_derivative_eqs_excl_voltage()
        self._derivative_eqs_voltage = self._get_derivative_eqs_voltage()
        self._derived_quant = self._get_derived_quant()
        self._derived_quant_eqs = self._get_derived_quant_eqs()

        # Update items to print as chaste_interface_...
        self._in_interface.update(self._model.state_vars)
        self._in_interface.add(self._model.time_variable)
        self._in_interface.add(self._model.i_ionic_lhs)
        self._in_interface.update(self._model.stimulus_params)

        # Sort before printing
        # Sort the state variables, in similar order to pycml to prevent breaking existing code.
        # V and cytosolic_calcium_concentration need to be set for the sorting.
        self._state_vars = sorted(self._model.state_vars,
                                  key=lambda state_var: state_var_key_order(model, state_var))

        self._modifiable_parameters = sorted(
            self._model.modifiable_parameters,
            key=lambda v: self._model.get_display_name(v, OXMETA, get_MultipleUsesAllowed_tags())
        )
        # Maps each modifiable parameter to its (stringified) index into mParameters
        self._modifiable_parameter_lookup = {p: str(i) for i, p in enumerate(self._modifiable_parameters)}

        # store indices of concentrations & probabilities
        self.concentrations = set(get_variables_transitively(self._model, (OXMETA, 'Concentration')))
        self.probabilities = set(get_variables_transitively(self._model,
                                                            (OXMETA, 'Probability'))) - not_quite_probabilities

        # Printing
        self._pre_print_hook()
        self._add_printers()
        self._formatted_state_vars, self._use_verify_state_variables = self._format_state_variables()

        # Use the position of the calcium concentration among the sorted state variables if it is a state variable,
        # otherwise fall back on the imported default index
        calcium_var = self._model.cytosolic_calcium_concentration_var
        calcium_is_state_var = calcium_var in self._model.state_vars
        if calcium_is_state_var:
            calcium_index = self._state_vars.index(calcium_var)
        else:
            calcium_index = CYTOSOLIC_CALCIUM_CONCENTRATION_INDEX

        # dict of variables to pass to the jinja2 templates
        # (entries are evaluated top to bottom; _format_y_derivatives also updates self._in_interface)
        self._vars_for_template = \
            {'base_class': '',
             'model_type': '',
             # indicate how to declare state vars and values
             'vector_decl': 'std::vector<double>&',
             'converter_version': cg.__version__,
             'model_name': self._model.name,
             'file_name': self.file_name,
             'class_name': kwargs.get('class_name', 'ModelFromCellMl'),
             'header_ext': kwargs.get('header_ext', '.hpp'),
             'dynamically_loadable': kwargs.get('dynamically_loadable', False),
             'use_model_factory': kwargs.get('use_model_factory', False),
             'cellml_base': kwargs.get('cellml_base', ''),
             'modifiers': self._format_modifiers(),
             'generation_date': TIME_STAMP,
             'use_get_intracellular_calcium_concentration': calcium_is_state_var,
             'membrane_voltage_index': MEMBRANE_VOLTAGE_INDEX,
             'cytosolic_calcium_concentration_index': calcium_index,
             'modifiable_parameters': self._format_modifiable_parameters(),
             'state_vars': self._formatted_state_vars,
             'use_verify_state_variables': self._use_verify_state_variables,
             'default_stimulus_equations': self._format_default_stimulus(),
             'ionic_vars': self._format_ionic_vars(),
             'y_derivatives': self._format_y_derivatives(),
             'y_derivative_equations': self._format_derivative_equations(self._derivative_equations),
             'free_variable': self._format_free_variable(),
             'ode_system_information': self._format_system_info(),
             'named_attributes': self._format_named_attributes(),
             'derived_quantities': self._format_derived_quant(),
             'derived_quantity_equations': self._format_derived_quant_eqs()}

    def _get_initial_value(self, var):
        """Returns the initial value of a variable if it has one, 0 otherwise"""
        # state vars have an initial value parameter defined
        if var in self._model.state_vars:
            return getattr(var, 'initial_value', 0)

        eqs = get_equations_for(self._model, (var,), filter_modifiable_parameters_lhs=False, optimise=False)
        # If there is a defining equation, there should be just 1 equation and it should be of the form var = value
        if len(eqs) == 1 and isinstance(eqs[0].rhs, Float):
            return eqs[0].rhs
        return 0

    def _get_stimulus(self):
        """ Get the stimulus equations as provided by the model"""
        return self._model.stimulus_equations

    def _get_extended_ionic_vars(self):
        """ Get the equations defining the ionic derivatives and all dependant equations"""
        return self._model.extended_ionic_vars

    def _get_derivative_equations(self):
        """ Get equations defining the derivatives including V (self._model.membrane_voltage_var)"""
        return self._model.derivative_equations

    def _get_derivative_eqs_excl_voltage(self):
        """ Get equations defining the derivatives excluding V (self._model.membrane_voltage_var)"""
        # start with derivatives without voltage and add all equations used
        voltage = self._model.membrane_voltage_var
        deriv_and_eqs = {deriv for deriv in self._model.y_derivatives if deriv.args[0] != voltage}

        eqs = set()
        num_derivatives = -1
        # Keep going until a pass no longer adds anything new to the set
        while num_derivatives < len(deriv_and_eqs):
            num_derivatives = len(deriv_and_eqs)
            eqs = {eq for eq in self._derivative_equations if eq.lhs in deriv_and_eqs}
            deriv_and_eqs.update(self._model.find_variables_and_derivatives([eq.rhs for eq in eqs]))
        return eqs

    def _get_derivative_eqs_voltage(self):
        """ Get equations defining the derivatives for V only (self._model.membrane_voltage_var)"""
        # start with derivatives for V only and add all equations used
        voltage = self._model.membrane_voltage_var
        derivatives = {deriv for deriv in self._model.y_derivatives if deriv.args[0] == voltage}

        eqs = set()
        num_derivatives = -1
        # Keep going until a pass no longer adds anything new to the set
        while num_derivatives < len(derivatives):
            num_derivatives = len(derivatives)
            eqs = {eq for eq in self._derivative_equations if eq.lhs in derivatives}
            for eq in eqs:
                derivatives.update(self._model.find_variables_and_derivatives((eq.rhs, )))
        return eqs

    def _get_derived_quant(self):
        """ Get all derived quantities

        Combines the variables tagged as derived-quantity with the ontology-annotated derived quantities.
        Stimulus currents are ignored for the latter and the result is sorted by display name"""
        tagged = set(self._model.get_variables_by_rdf((PYCMLMETA, 'derived-quantity'), 'yes', sort=False))
        # Get annotated derived quantities excluding stimulus current params
        annotated = {quant for quant in self._model.get_derived_quantities(sort=False)
                     if quant not in self._model.stimulus_params
                     and self._model.has_ontology_annotation(quant, OXMETA)}

        return sorted(tagged | annotated,
                      key=lambda v: self._model.get_display_name(v, OXMETA, get_MultipleUsesAllowed_tags()))

    def _get_derived_quant_eqs(self):
        """ Get the defining equations for derived quantities"""
        return get_equations_for(self._model, self._derived_quant)

    def _pre_print_hook(self):
        """ The method provides a hook for subclasses to be able to add additional computation
            before printing of the output starts"""
        pass

    def _add_printers(self, lookup_table_function=lambda e: None):
        """ Initialises Printers for outputting chaste code.

        ``lookup_table_function``
            Passed straight on to the main :class:`cg.ChastePrinter`; the default returns ``None`` for everything.
        """
        def print_variable(variable):
            # Print modifiable parameters as mParameters[index]
            if variable in self._model.modifiable_parameters:
                return self._print_modifiable_parameters(variable)
            # Print variables in interface as var_chaste_interface
            # (state variables, time, lhs of default_stimulus eqs, i_ionic and lhs of y_derivatives)
            return get_variable_name(variable, variable in self._in_interface)

        # Printer for printing chaste regular variable assignments (var_ prefix)
        self._printer = cg.ChastePrinter(print_variable,
                                         lambda deriv: get_variable_name(deriv),
                                         lookup_table_function)

        # Printer for printing variable in comments e.g. for ode system information
        self._name_printer = cg.ChastePrinter(lambda variable: get_variable_name(variable))

    def _format_modifier_call(self, var, printed_expr):
        """ Wrap an already printed expression in the ``->Calc(..., time)`` call of the modifier for ``var``"""
        return self._format_modifier(var) + '->Calc(' + printed_expr + ', ' + \
            self._printer.doprint(self._model.time_variable) + ')'

    def _print_rhs_with_modifiers(self, modifier, eq, modifiers_with_defining_eqs=None):
        """ Print an expression, wrapping it and the variables it uses in modifier calls where needed

        ``modifier``
            The variable the expression belongs to (e.g. the lhs of its equation); if it is one of
            ``self._modifiers`` the whole printed expression is wrapped in its modifier call.
        ``eq``
            The expression to print.
        ``modifiers_with_defining_eqs``
            Variables that must NOT be printed as modifiers inside the expression. Defaults to none.
        """
        # None sentinel rather than a mutable default argument shared between calls
        if modifiers_with_defining_eqs is None:
            modifiers_with_defining_eqs = frozenset()

        def print_variable(variable):
            # Make sure printer doesn't print variables as modifiers if they are state vars or eq lhs
            # as those are handled by the wrapping at the end of _print_rhs_with_modifiers
            if variable in self._modifiers and variable not in modifiers_with_defining_eqs:
                return self._format_modifier_call(variable, self._printer.doprint(variable))
            return self._printer.doprint(variable)

        modifier_printer = cg.ChastePrinter(print_variable,
                                            lambda deriv: self._printer.doprint(deriv),
                                            self._printer.lookup_table_function)

        printed_rhs = modifier_printer.doprint(eq)
        if modifier in self._modifiers:
            return self._format_modifier_call(modifier, printed_rhs)
        return printed_rhs

    def _format_modifier(self, var):
        """ Formatting of modifier for printing"""
        return 'mp_' + self._model.get_display_name(var, OXMETA, get_MultipleUsesAllowed_tags()) + '_modifier'

    def _format_modifiers(self):
        """ Format the modifiers for printing to chaste code"""
        return [{'name': self._model.modifier_names[param],
                 'modifier': self._format_modifier(param)}
                for param in self._modifiers]

    def _print_modifiable_parameters(self, variable):
        """ Print modifiable parameters in the correct format for the model type"""
        return 'mParameters[' + self._modifiable_parameter_lookup[variable] + ']'

    def _format_modifiable_parameters(self):
        """ Format the modifiable parameter for printing to chaste code"""
        return [{'units': self._model.units.format(self._model.units.evaluate_units(param)),
                 'comment_name': self._name_printer.doprint(param),
                 'name': self._model.get_display_name(param, OXMETA, get_MultipleUsesAllowed_tags()),
                 'initial_value': self._printer.doprint(self._get_initial_value(param))}
                for param in self._modifiable_parameters]

    def _format_rY_entry(self, index):
        """ Formatting of rY entry for printing"""
        return 'rY[' + str(index) + ']'

    def _format_rY_lookup(self, index, var, use_modifier=True):
        """ Formatting of rY lookup for printing

        The plain ``rY[index]`` entry is wrapped in the modifier call when ``use_modifier`` is set and ``var`` is
        a modifier, and additionally in the fixed-voltage switch when ``var`` is the membrane voltage.
        """
        entry = self._format_rY_entry(index)
        if use_modifier and var in self._modifiers:
            entry = self._format_modifier_call(var, entry)
        if var == self._model.membrane_voltage_var:
            entry = '(mSetVoltageDerivativeToZero ? this->mFixedVoltage : ' + entry + ')'
        return entry

    def _format_state_variables(self):
        """ Format the state variables for chaste output

        Returns a tuple ``(formatted_state_vars, use_verify_state_variables)``: a list with one dictionary per
        state variable (in the order of ``self._state_vars``) and a flag which is ``True`` if at least one
        state variable has a range-low or range-high annotation.
        """
        def get_range_annotation(subject, annotation_tag):
            """ Get the range-low or range-high annotation used to verify state variables.

            Returns the annotated value as a float, or '' if there is no cmeta id or no annotation. """
            if subject.cmeta_id is not None:
                range_annotation = self._model.get_rdf_annotations(subject='#' + subject.cmeta_id,
                                                                   predicate=(PYCMLMETA, annotation_tag))
                annotation = next(range_annotation, None)
                assert next(range_annotation, None) is None, 'Expecting 0 or 1 range annotation'
                if annotation is not None:
                    return float(annotation[2])
            return ''

        def used_symbols(equations):
            """ Collect the free symbols of the right hand sides of the given equations"""
            symbols = set()
            for eq in equations:
                symbols.update(eq.rhs.free_symbols)
            return symbols

        # Get all used variables for each group of equations, to be able to indicate if a state var is used
        ionic_var_variables = used_symbols(self._extended_ionic_vars)
        y_deriv_variables = used_symbols(self._derivative_equations)
        voltage_deriv_variables = used_symbols(self._derivative_eqs_voltage)
        deriv_excl_voltage_variables = used_symbols(self._derivative_eqs_excl_voltage)
        # State variables that are themselves derived quantities count as used as well
        derived_quant_variables = {quant for quant in self._derived_quant if quant in self._model.state_vars}
        derived_quant_variables.update(used_symbols(self._derived_quant_eqs))

        formatted_state_vars = \
            [{'var': self._printer.doprint(var),
              'annotated_var_name': self._model.get_display_name(var, OXMETA, get_MultipleUsesAllowed_tags()),
              'rY_lookup': self._format_rY_lookup(index, var),
              'rY_lookup_no_modifier': self._format_rY_lookup(index, var, use_modifier=False),
              'initial_value': str(self._get_initial_value(var)),
              'modifier': self._format_modifier(var) if var in self._modifiers else None,
              'units': self._model.units.format(self._model.units.evaluate_units(var)),
              'in_ionic': var in ionic_var_variables,
              'in_y_deriv': var in y_deriv_variables,
              'in_deriv_excl_voltage': var in deriv_excl_voltage_variables,
              'in_voltage_deriv': var in voltage_deriv_variables,
              'in_derived_quant': var in derived_quant_variables,
              'range_low': get_range_annotation(var, 'range-low'),
              'range_high': get_range_annotation(var, 'range-high'),
              'sympy_var': var,
              # position in self._state_vars, which is exactly the enumerate index
              'state_var_index': index,
              'is_concentration': var in self.concentrations,
              'is_probability': var in self.probabilities}
             for index, var in enumerate(self._state_vars)]

        use_verify_state_variables = any(entry['range_low'] != '' or entry['range_high'] != ''
                                         for entry in formatted_state_vars)
        return (formatted_state_vars, use_verify_state_variables)

    def _format_default_stimulus(self):
        """ Format eqs for stimulus_current for outputting to chaste code

        Returns a dictionary with the formatted equations under ``'equations'`` plus one entry per stimulus
        parameter, keyed by its display name and holding its printed name.
        """
        default_stim = {'equations':
                        [{'lhs': self._printer.doprint(eq.lhs),
                          'rhs': self._printer.doprint(eq.rhs),
                          'units': self._model.units.format(self._model.units.evaluate_units(eq.lhs)),
                          'lhs_modifiable': eq.lhs in self._model.modifiable_parameters}
                         for eq in self._stimulus_equations]}
        for param in self._model.stimulus_params:
            display_name = self._model.get_display_name(param, OXMETA, get_MultipleUsesAllowed_tags())
            default_stim[display_name] = self._printer.doprint(param)
        return default_stim

    def _format_ionic_vars(self):
        """ Format equations and dependant equations ionic derivatives"""
        # Format the state ionic variables
        return [{'lhs': self._printer.doprint(eq.lhs), 'rhs': self._printer.doprint(eq.rhs),
                 'units': self._model.units.format(self._model.units.evaluate_units(eq.lhs))}
                for eq in self._extended_ionic_vars]

    def _format_y_derivatives(self):
        """ Format y_derivatives for writing to chaste output

        Side effect: the y derivatives are added to the items to print as chaste_interface_...
        """
        self._in_interface.update(self._model.y_derivatives)
        return [self._printer.doprint(deriv) for deriv in self._model.y_derivatives]

    def _format_derivative_equations(self, derivative_equations):
        """ Format derivative equations for chaste output"""
        # Make sure printer doesn't print variables as modifiers if they are state vars or eq lhs
        # as those are handled by _print_rhs_with_modifiers
        modifiers_with_defining_eqs = {eq.lhs for eq in derivative_equations} | self._model.state_vars
        # NOTE(review): format_derivative_equation is not defined in this class as shown here;
        # presumably it is provided elsewhere (e.g. by the model-type subclasses) -- confirm.
        return [self.format_derivative_equation(eq, modifiers_with_defining_eqs)
                for eq in derivative_equations]

    def _format_free_variable(self):
        """ Format free variable for chaste output"""
        time_var = self._model.time_variable
        return {'name': self._model.get_display_name(time_var, OXMETA, get_MultipleUsesAllowed_tags()),
                'units': self._model.units.format(self._model.units.evaluate_units(time_var)),
                'system_name': self._model.name,
                'var_name': self._printer.doprint(time_var)}

    def _format_system_info(self):
        """ Format general ode system info for chaste output"""
        return [{'name': self._model.get_display_name(var, OXMETA, get_MultipleUsesAllowed_tags()),
                 'initial_value': str(self._get_initial_value(var)),
                 'units': self._model.units.format(self._model.units.evaluate_units(var)),
                 'rY_lookup': self._format_rY_entry(index)}
                for index, var in enumerate(self._state_vars)]

    def _format_named_attributes(self):
        """ Format named attributes for chaste output

        Returns a list of ``{'name': ..., 'value': ...}`` dictionaries sorted by name.
        """
        named_attrs = self._model.get_rdf_annotations(subject=self._model.rdf_identity,
                                                      predicate=(PYCMLMETA, 'named-attribute'))
        named_attributes = []
        # Only the object of each annotation triple is needed
        for _, _, attr in named_attrs:
            name = self._model.get_rdf_value(subject=attr, predicate=(PYCMLMETA, 'name'))
            value = self._model.get_rdf_value(subject=attr, predicate=(PYCMLMETA, 'value'))
            named_attributes.append({'name': name, 'value': value})

        named_attributes.sort(key=lambda a: a['name'])
        return named_attributes

    def _format_derived_quant(self):
        """ Format the derived quantities (units, printed variable and display name) for chaste output"""
        return [{'units': self._model.units.format(self._model.units.evaluate_units(quant)),
                 'var': self._printer.doprint(quant),
                 'name': self._model.get_display_name(quant, OXMETA, get_MultipleUsesAllowed_tags())}
                for quant in self._derived_quant]

    def _format_derived_quant_eqs(self):
        """ Format equations for derived quantities based on current settings"""
        # State vars and equation lhs are not printed as modifiers inside the rhs
        modifiers_with_defining_eqs = {eq.lhs for eq in self._derived_quant_eqs} | self._model.state_vars
        # NOTE(review): unlike the other formatters the units are passed through str() here -- confirm intended.
        return [{'lhs': self._printer.doprint(eq.lhs),
                 'rhs': self._print_rhs_with_modifiers(eq.lhs, eq.rhs, modifiers_with_defining_eqs),
                 'sympy_lhs': eq.lhs,
                 'units': self._model.units.format(str(self._model.units.evaluate_units(eq.lhs)))}
                for eq in self._derived_quant_eqs]

    def generate_chaste_code(self):
        """ Generates and stores chaste code

        Every template in ``self._templates`` is rendered and the result appended to ``self.generated_code``.
        """
        for templ in self._templates:
            template = cg.load_template(templ)
            self.generated_code.append(template.render(self._vars_for_template))

    def __exit__(self, type, value, traceback):
        """ Clean-up. Required to be able to use model in context (with).

        Defined for overwriting in sub-classes if state needs to be reset for subsequent code generation.
        Returns ``None``, so exceptions raised inside the ``with`` block are not suppressed.
        """
        pass