| 1 | | from __future__ import absolute_import |
| 2 | | import subprocess |
| 3 | | import neuron |
| 4 | | from pyNN.models import BaseCellType |
| 5 | | import nineml.abstraction_layer as nineml |
| 6 | | import logging |
| 7 | | import os |
| 8 | | from itertools import chain |
| 9 | | |
| 10 | | h = neuron.h |
| 11 | | logger = logging.getLogger("PyNN") |
| 12 | | |
| 13 | | NMODL_DIR = "nineml_mechanisms" |
| 14 | | |
| 15 | | class NineMLCell(object): |
| 16 | | |
| 17 | | def __init__(self, **parameters): |
| 18 | | self.type = parameters.pop("type") |
| 19 | | self.source_section = h.Section() |
| 20 | | self.source = getattr(h, self.type.model_name)(0.5, sec=self.source_section) |
| 21 | | for param, value in parameters.items(): |
| 22 | | setattr(self.source, param, value) |
| 23 | | # for recording |
| 24 | | self.spike_times = h.Vector(0) |
| 25 | | self.traces = {} |
| 26 | | self.recording_time = False |
| 27 | | |
| 28 | | def __getattr__(self, name): |
| 29 | | try: |
| 30 | | return self.__getattribute__(name) |
| 31 | | except AttributeError: |
| 32 | | if name in self.type.synapse_types: |
| 33 | | return self.source # source is also target |
| 34 | | else: |
| 35 | | raise AttributeError("'NineMLCell' object has no attribute or synapse type '%s'" % name) |
| 36 | | |
| 37 | | def record(self, active): |
| 38 | | if active: |
| 39 | | rec = h.NetCon(self.source, None) |
| 40 | | rec.record(self.spike_times) |
| 41 | | else: |
| 42 | | self.spike_times = h.Vector(0) |
| 43 | | |
| 44 | | def memb_init(self): |
| 45 | | # this is a bit of a hack |
| 46 | | for var in self.type.recordable: |
| 47 | | if hasattr(self, "%s_init" % var): |
| 48 | | initial_value = getattr(self, "%s_init" % var) |
| 49 | | logger.debug("Initialising %s to %g" % (var, initial_value)) |
| 50 | | setattr(self.source, var, initial_value) |
| 51 | | |
| 52 | | |
| 53 | | class NineMLCellType(BaseCellType): |
| 54 | | model = NineMLCell |
| 55 | | |
| 56 | | def __init__(self, parameters): |
| 57 | | BaseCellType.__init__(self, parameters) |
| 58 | | self.parameters["type"] = self |
| 59 | | |
| 60 | | |
| 61 | | def _compile_nmodl(nineml_component): |
| 62 | | if not os.path.exists(NMODL_DIR): |
| 63 | | os.makedirs(NMODL_DIR) |
| 64 | | cwd = os.getcwd() |
| 65 | | os.chdir(NMODL_DIR) |
| 66 | | xml_file = "%s.xml" % nineml_component.name |
| 67 | | logger.debug("Writing NineML component to %s" % xml_file) |
| 68 | | nineml_component.write(xml_file) |
| 69 | | nineml2nmodl = __import__("9ml2nmodl") |
| 70 | | nineml2nmodl.write_nmodl(xml_file) |
| 71 | | p = subprocess.check_call(["nrnivmodl"]) |
| 72 | | os.chdir(cwd) |
| 73 | | neuron.load_mechanisms(NMODL_DIR) |
| 74 | | |
| 75 | | |
| 76 | | class _build_nineml_celltype(type): |
| 77 | | """ |
| 78 | | Metaclass for building NineMLCellType subclasses |
| 79 | | """ |
| 80 | | def __new__(cls, name, bases, dct): |
| 81 | | assert len(dct["synapse_models"]) == 1, "For now, can't handle multiple synapse models" |
| 82 | | combined_model = join(dct["neuron_model"], |
| 83 | | dct["synapse_models"].values()[0], |
| 84 | | dct["port_map"], |
| 85 | | name=name) |
| 86 | | dct["combined_model"] = combined_model |
| 87 | | dct["default_parameters"] = dict((name, 1.0) |
| 88 | | for name in combined_model.parameters) |
| 89 | | dct["default_initial_values"] = dict((name, 0.0) |
| 90 | | for name in combined_model.state_variables) |
| 91 | | dct["synapse_types"] = dct["synapse_models"].keys() #really need an ordered dict |
| 92 | | dct["injectable"] = True # need to determine this. How?? |
| 93 | | dct["recordable"] = [port.name for port in combined_model.analog_ports] + ['spikes'] |
| 94 | | dct["standard_receptor_type"] = (dct["synapse_types"] == ('excitatory', 'inhibitory')) |
| 95 | | dct["conductance_based"] = True # how to determine this?? |
| 96 | | dct["model_name"] = name |
| 97 | | logger.debug("Creating class '%s' with bases %s and dictionary %s" % (name, bases, dct)) |
| 98 | | _compile_nmodl(combined_model) |
| 99 | | return type.__new__(cls, name, bases, dct) |
| 100 | | |
| 101 | | |
| 102 | | |
| 103 | | def nineml_cell_type(name, neuron_model, port_map={}, **synapse_models): |
| 104 | | """ |
| 105 | | Return a new NineMLCellType subclass. |
| 106 | | """ |
| 107 | | return _build_nineml_celltype(name, (NineMLCellType,), |
| 108 | | {'neuron_model': neuron_model, |
| 109 | | 'synapse_models': synapse_models, |
| 110 | | 'port_map': port_map}) |
| 111 | | |
| 112 | | |
| 113 | | def join(c1, c2, port_map=[], name=None): |
| 114 | | """Create a NineML component by joining the two given components.""" |
| 115 | | logger.debug("Joining components %s and %s with port map %s" % (c1, c2, port_map)) |
| 116 | | logger.debug("New component will have name '%s'" % name) |
| 117 | | bindings = [] # TODO: combine bindings from c1 and c2 |
| 118 | | all_ports = c1.ports_map.copy() |
| 119 | | all_ports.update(c2.ports_map) |
| 120 | | for port_name, port in all_ports.items(): |
| 121 | | if isinstance(port, nineml.EventPort): |
| 122 | | all_ports.pop(port_name) |
| 123 | | for name1, name2 in port_map: |
| 124 | | assert name1 in c1.ports_map |
| 125 | | assert name2 in c2.ports_map |
| 126 | | all_ports.pop(name1) |
| 127 | | if name1 != name2: |
| 128 | | #c2.substitute(name2, name1) # need to implement this |
| 129 | | all_ports.pop(name2) |
| 130 | | |
| 131 | | port1 = c1.ports_map[name1] |
| 132 | | port2 = c2.ports_map[name2] |
| 133 | | assert port1.mode != port2.mode |
| 134 | | if port1.mode == 'send': |
| 135 | | send_port = port1 |
| 136 | | port_name = name1 |
| 137 | | else: |
| 138 | | send_port = port2 |
| 139 | | port_name = name2 |
| 140 | | if send_port.expr: |
| 141 | | func_args = c1.non_parameter_symbols.union(c2.non_parameter_symbols).intersection(send_port.expr.names) |
| 142 | | lhs = "%s(%s)" % (port_name, ",".join(func_args)) |
| 143 | | bindings.append(nineml.Binding(lhs, send_port.expr.rhs)) |
| 144 | | for eq in chain(c1.equations, c2.equations): |
| 145 | | if port_name in eq.names: |
| 146 | | eq.rhs = eq.rhs_name_transform({port_name: lhs}) |
| 147 | | regime_map = {} |
| 148 | | for r1 in c1.regimes: |
| 149 | | regime_map[r1.name] = {} |
| 150 | | for r2 in c2.regimes: |
| 151 | | kwargs = {'name': "%s_AND_%s" % (r1.name, r2.name)} |
| 152 | | new_regime = nineml.Regime(*r1.nodes.union(r2.nodes), **kwargs) |
| 153 | | regime_map[r1.name][r2.name] = new_regime |
| 154 | | transitions = [] |
| 155 | | for r1 in c1.regimes: |
| 156 | | for r2 in c2.regimes: |
| 157 | | for t in r1.transitions: |
| 158 | | new_transition = nineml.Transition(*t.nodes, |
| 159 | | from_=regime_map[r1.name][r2.name], |
| 160 | | to=regime_map[t.to.name][r2.name], |
| 161 | | condition=t.condition) |
| 162 | | transitions.append(new_transition) |
| 163 | | for t in r2.transitions: |
| 164 | | new_transition = nineml.Transition(*t.nodes, |
| 165 | | from_=regime_map[r1.name][r2.name], |
| 166 | | to=regime_map[r1.name][t.to.name], |
| 167 | | condition=t.condition) |
| 168 | | transitions.append(new_transition) |
| 169 | | regimes = [] |
| 170 | | for d in regime_map.values(): |
| 171 | | regimes.extend(d.values()) |
| 172 | | name = name or "%s__%s" % (c1.name, c2.name) |
| 173 | | return nineml.Component(name, |
| 174 | | regimes=regimes, |
| 175 | | transitions=transitions, |
| 176 | | ports=all_ports.values(), |
| 177 | | bindings=bindings) |
| 178 | | |
| 179 | | |
| 180 | | |
| 181 | | |
| 182 | | |