Source code for chaste_codegen.cvode_chaste_model

from functools import partial

from sympy import (
    Derivative,
    Eq,
    Function,
    Matrix,
)

from chaste_codegen._jacobian import format_jacobian, get_jacobian
from chaste_codegen._partial_eval import partial_eval
from chaste_codegen._rdf import OXMETA, get_MultipleUsesAllowed_tags
from chaste_codegen.chaste_model import ChasteModel


[docs]class CvodeChasteModel(ChasteModel): """ Holds template and information specific for the CVODE model type""" def __init__(self, model, file_name, **kwargs): self._use_data_clamp = kwargs.get('cvode_data_clamp', False) # store if data clamp is needed self._use_analytic_jacobian = kwargs.get('use_analytic_jacobian', False) # store if jacobians are needed super().__init__(model, file_name, **kwargs) self._templates = ['cvode_model.hpp', 'cvode_model.cpp'] if self._use_data_clamp: self._vars_for_template['base_class'] = 'AbstractCvodeCellWithDataClamp' else: self._vars_for_template['base_class'] = 'AbstractCvodeCell' if self._use_data_clamp: self._vars_for_template['model_type'] = 'CvodeCellWithDataClamp' elif self._use_analytic_jacobian: self._vars_for_template['model_type'] = 'AnalyticCvode' else: self._vars_for_template['model_type'] = 'NumericCvode' self._vars_for_template['vector_decl'] = "N_Vector" # indicate how to declare state vars and values if self._use_analytic_jacobian: # get deriv eqs and substitute in all variables other than state vars self._derivative_equations = \ partial_eval(self._derivative_equations, self._model.y_derivatives, keep_multiple_usages=False) self._jacobian_equations, self._jacobian_matrix = get_jacobian(self._state_vars, self._derivative_equations) self._formatted_state_vars = self._update_state_vars() self._vars_for_template['jacobian_equations'], self._vars_for_template['jacobian_entries'] = \ self._print_jacobian() else: self._vars_for_template['jacobian_equations'], self._vars_for_template['jacobian_entries'] = \ [], Matrix() def _add_data_clamp_to_model(self): """ Add add membrane_data_clamp_current_conductance and membrane_data_clamp_current to the model""" self._membrane_data_clamp_current_conductance = \ self._model.add_variable(name='membrane_data_clamp_current_conductance', units=self._model.conversion_units.get_unit('dimensionless')) self.dataclamp_eq = Eq(self._membrane_data_clamp_current_conductance, 0.0) self._model.add_equation(self.dataclamp_eq) # add membrane_data_clamp_current self._membrane_data_clamp_current = \ self._model.add_variable(name='membrane_data_clamp_current', units=self._model.conversion_units.get_unit('uA_per_cm2')) # add clamp current equation self._in_interface.add(self._membrane_data_clamp_current) clamp_current = self._membrane_data_clamp_current_conductance * \ (self._model.membrane_voltage_var - Function('GetExperimentalVoltageAtTimeT')(self._model.time_variable)) self._membrane_data_clamp_current_eq = Eq(self._membrane_data_clamp_current, clamp_current) self._model.add_equation(self._membrane_data_clamp_current_eq) # Add data clamp current as modifiable parameter and re-sort self._model.modifiable_parameters.add(self._membrane_data_clamp_current_conductance) def _get_derivative_equations(self): """ Get equations defining the derivatives including V and add in membrane_data_clamp_current""" derivative_equations = super()._get_derivative_equations() if self._use_data_clamp: def find_ionic_var(eq, ionic_var, deqs): """Finds ionic_var on the rhs of eq, recursing through defining equations if necessary""" if ionic_var in eq.rhs.free_symbols: return deqs.index(eq) else: found_eq = None for var in eq.rhs.free_symbols: def_eqs = filter(lambda e: e.lhs == var, deqs) def_eq = next(def_eqs, None) if def_eq is not None and next(def_eqs, None) is None: # exactly 1 found_eq = find_ionic_var(def_eq, ionic_var, deqs) if found_eq is not None: break return found_eq # make a copy of the list of derivative eq, so that the underlying model can be reused derivative_equations = list([eq for eq in derivative_equations]) self._add_data_clamp_to_model() # piggy-backs on the analysis that finds ionic currents, in order to add in data clamp currents # Find dv/dt deriv_eq_only = filter(lambda eq: isinstance(eq.lhs, Derivative) and eq.lhs.args[0] == self._model.membrane_voltage_var, derivative_equations) dvdt = next(deriv_eq_only, None) assert dvdt is not None and next(deriv_eq_only, None) is None, 'Expecting exactly 1 dv/dt equation' current_index = None # We need to add data_clamp to the equation with the correct sign # This is achieved by substitution the first of the ionic currents # by (ionic_current + data_clamp_current) ionic_var = self._model.ionic_vars[0] current_index = find_ionic_var(dvdt, ionic_var, derivative_equations) if current_index is not None: eq = derivative_equations[current_index] rhs = eq.rhs.xreplace({ionic_var: (ionic_var + self._membrane_data_clamp_current)}) derivative_equations[current_index] = Eq(eq.lhs, rhs) derivative_equations.insert(current_index, self._membrane_data_clamp_current_eq) return derivative_equations def _get_derived_quant(self): """ Get all derived quantities, adds membrane_data_clamp_current and its defining equation""" derived_quant = super()._get_derived_quant() if self._use_data_clamp: # Add membrane_data_clamp_current to modifiable parameters # (this was set in _get_modifiable_parameters as it's also needed in _get_derivative_equations) derived_quant.append(self._membrane_data_clamp_current) derived_quant.sort(key=lambda q: self._model.get_display_name(q, OXMETA, get_MultipleUsesAllowed_tags())) return derived_quant def _format_derivative_equations(self, derivative_equations): """Format derivative equations for chaste output and add is_data_clamp_current flag""" formatted_eqs = super()._format_derivative_equations(derivative_equations) if self._use_data_clamp: for eq in formatted_eqs: eq['is_data_clamp_current'] = eq['sympy_lhs'] == self._membrane_data_clamp_current return formatted_eqs def _format_derived_quant_eqs(self): """ Format equations for derived quantities and add is_data_clamp_current flag""" formatted_eqs = super()._format_derived_quant_eqs() if self._use_data_clamp: for eq in formatted_eqs: eq['is_data_clamp_current'] = eq['sympy_lhs'] == self._membrane_data_clamp_current return formatted_eqs def _print_jacobian(self): modifiers_with_defining_eqs = set((eq[0] for eq in self._jacobian_equations)) | self._model.state_vars return format_jacobian(self._jacobian_equations, self._jacobian_matrix, self._printer, partial(self._print_rhs_with_modifiers, modifiers_with_defining_eqs=modifiers_with_defining_eqs)) def _print_modifiable_parameters(self, variable): return 'NV_Ith_S(mParameters, ' + self._modifiable_parameter_lookup[variable] + ')' def _format_rY_entry(self, index): return 'NV_Ith_S(rY, ' + str(index) + ')' def _update_state_vars(self): jacobian_symbols = set() for eq in self._jacobian_equations: jacobian_symbols.update(eq[1].free_symbols) for en in self._jacobian_matrix: jacobian_symbols.update(en.free_symbols) formatted_state_vars = self._formatted_state_vars for sv in formatted_state_vars: sv['in_jacobian'] = sv['sympy_var'] in jacobian_symbols return formatted_state_vars def __exit__(self, type, value, traceback): """ Clean-up, removed the data clamp if used, so that the model object can be re-used""" if self._use_data_clamp: # Remove modifiable parameter self._model.modifiable_parameters.remove(self._membrane_data_clamp_current_conductance) # Remove equation from model self._model.remove_equation(self.dataclamp_eq) self._model.remove_equation(self._membrane_data_clamp_current_eq) # Remove variable from model self._model.remove_variable(self._membrane_data_clamp_current_conductance) self._model.remove_variable(self._membrane_data_clamp_current)