Changeset 964

Show
Ignore:
Timestamp:
05/24/11 00:12:23 (12 months ago)
Author:
emuller
Message:

Factoring out 9ML CG generic code from neuron backend in preparation for NEST support, as discussed with Andrew.

Location:
trunk
Files:
4 modified

Legend:

Unmodified
Added
Removed
  • trunk/examples/nineml_neuron.py

    r957 r964  
    1616from pyNN.utility import init_logging 
    1717 
     18import pyNN.common, pyNN.recording 
     19pyNN.common.simulator = sim.simulator 
     20pyNN.recording.simulator = sim.simulator 
     21 
    1822from copy import deepcopy 
    1923 
    2024init_logging(None, debug=True) 
    21 sim.setup(timestep=0.1, min_delay=0.1) 
     25sim.setup(timestep=0.1, min_delay=0.1, max_delay=2.0) 
    2226 
    2327celltype_cls = nineml_cell_type("if_cond_exp", 
     
    3842    'gL': 50.0, 
    3943    't_ref': 5.0, 
    40     'excitatory_tau': 2.0, 
    41     'inhibitory_tau': 5.0, 
     44    'excitatory_tau': 1.5, 
     45    'inhibitory_tau': 10.0, 
    4246    'excitatory_E': 0.0, 
    4347    'inhibitory_E': -70.0, 
     
    5458input = sim.Population(2, sim.SpikeSourcePoisson, {'rate': 100}) 
    5559 
    56 connector = sim.OneToOneConnector(weights=1.0, delays=0.5) 
     60connector = sim.OneToOneConnector(weights=1.0)#, delays=0.5) 
    5761conn = [sim.Projection(input[0:1], cells, connector, target='excitatory'), 
    5862        sim.Projection(input[1:2], cells, connector, target='inhibitory')] 
     
    6973cells.recorders['inhibitory_g'].write("Results/nineml_neuron.g_inh", filter=[cells[0]]) 
    7074 
     75t = cells.recorders['V'].get()[:,1] 
     76v = cells.recorders['V'].get()[:,2] 
     77g_exc = cells.recorders['excitatory_g'].get()[:,2] 
     78g_inh = cells.recorders['inhibitory_g'].get()[:,2] 
     79 
     80 
     81#plot(t,v) 
     82 
     83 
     84 
    7185sim.end() 
  • trunk/setup.py

    r963 r964  
    4343    version = "0.8.0dev", 
    4444    package_dir={'pyNN': 'src'}, 
    45     packages = ['pyNN','pyNN.nest', 'pyNN.pcsim', 'pyNN.neuron', 'pyNN.brian', 
     45    packages = ['pyNN','pyNN.nest', 'pyNN.pcsim', 'pyNN.neuron', 'pyNN.nineml', 
     46                'pyNN.brian', 
    4647                'pyNN.recording', 'pyNN.standardmodels', 'pyNN.descriptions', 
    4748                'pyNN.nest.standardmodels', 'pyNN.pcsim.standardmodels', 
  • trunk/src/neuron/nineml.py

    r963 r964  
    2323import neuron 
    2424from pyNN.models import BaseCellType 
    25 import nineml.abstraction_layer as nineml 
     25from pyNN.nineml.cells import _build_nineml_celltype 
    2626import logging 
    2727import os 
    28 import re 
    2928from itertools import chain 
    3029 
     
    9998 
    10099 
    101 def _add_prefix(synapse_model, prefix, port_map): 
    102     """ 
    103     Add a prefix to all variables in `synapse_model`, except for variables with 
    104     receive ports and specified in `port_map`. 
    105     """ 
    106     synapse_model.__cache__ = {} 
    107     exclude = [] 
    108     new_port_map = [] 
    109     for name1, name2 in port_map: 
    110         if synapse_model.ports_map[name2].mode == 'recv': 
    111             exclude.append(name2) 
    112             new_port_map.append((name1, name2)) 
    113         else: 
    114             new_port_map.append((name1, prefix + '_' + name2)) 
    115     synapse_model.add_prefix(prefix + '_', exclude=exclude) 
    116     return new_port_map 
    117  
    118  
    119 class _build_nineml_celltype(type): 
    120     """ 
    121     Metaclass for building NineMLCellType subclasses 
    122     """ 
    123     def __new__(cls, name, bases, dct): 
    124         # join the neuron and synapse components into a single component 
    125         combined_model = dct["neuron_model"] 
    126         for label in dct["synapse_models"].keys(): 
    127             port_map = dct["port_map"][label] 
    128             port_map = _add_prefix(dct["synapse_models"][label], label, port_map) 
    129             dct["weight_variables"][label] = label + "_" + dct["weight_variables"][label] 
    130             combined_model = join(combined_model, 
    131                                   dct["synapse_models"][label], 
    132                                   port_map, 
    133                                   name=name) 
    134         dct["combined_model"] = combined_model 
    135         # set class attributes required for a PyNN cell type class 
    136         dct["default_parameters"] = dict((name, 1.0) 
    137                                       for name in combined_model.parameters) 
    138         dct["default_initial_values"] = dict((name, 0.0) 
    139                                           for name in combined_model.state_variables) 
    140         dct["synapse_types"] = dct["synapse_models"].keys() #really need an ordered dict 
    141         dct["injectable"] = True # need to determine this. How?? 
    142         dct["recordable"] = [port.name for port in combined_model.analog_ports] + ['spikes', 'regime'] 
    143         dct["standard_receptor_type"] = (dct["synapse_types"] == ('excitatory', 'inhibitory')) 
    144         dct["conductance_based"] = True # how to determine this?? 
    145         dct["model_name"] = name 
    146         logger.debug("Creating class '%s' with bases %s and dictionary %s" % (name, bases, dct)) 
    147         # generate and compile NMODL code, then load the mechanism into NEUORN 
    148         _compile_nmodl(combined_model, dct["weight_variables"]) # weight variables should really be stored within combined_model 
    149         return type.__new__(cls, name, bases, dct) 
    150      
    151      
    152  
    153100def nineml_cell_type(name, neuron_model, port_map={}, weight_variables={}, **synapse_models): 
    154101    """ 
     
    159106                                   'synapse_models': synapse_models, 
    160107                                   'port_map': port_map, 
    161                                    'weight_variables': weight_variables}) 
     108                                   'weight_variables': weight_variables, 
     109                                   'builder': _compile_nmodl}) 
    162110 
    163111 
    164 def join(c1, c2, port_map=[], name=None): 
    165     """Create a NineML component by joining the two given components.""" 
    166     logger.debug("Joining components %s and %s with port map %s" % (c1, c2, port_map)) 
    167     logger.debug("New component will have name '%s'" % name) 
    168     # combine bindings from c1 and c2 
    169     bindings = {} 
    170     for b in chain(c1.bindings, c2.bindings): 
    171         bindings[b.name] = b 
    172     # combine ports (some will later be removed) 
    173     all_ports = c1.ports_map.copy() 
    174     all_ports.update(c2.ports_map) 
    175     # event ports do not be passed to the constructor, as they are attached to transitions 
    176     for port_name, port in all_ports.items(): 
    177         if isinstance(port, nineml.EventPort): 
    178             all_ports.pop(port_name) 
    179     # connect ports. 
    180     # currently, when ports are connected they disappear. It might be better to 
    181     # explicitly keep the ports in the new component but mark them as connected 
    182     for name1, name2 in port_map: 
    183         assert name1 in c1.ports_map, "%s is not in %s" % (name1, c1.ports_map.keys()) 
    184         assert name2 in c2.ports_map, "%s is not in %s" % (name2, c2.ports_map.keys()) 
    185  
    186         port1 = c1.ports_map[name1] 
    187         port2 = c2.ports_map[name2] 
    188         assert port1.mode != port2.mode 
    189         if port1.mode == 'send': 
    190             send_port = port1 
    191             recv_port = port2 
    192             send_port_name = name1 
    193             recv_port_name = name2 
    194         else: 
    195             send_port = port2 
    196             recv_port = port1 
    197             send_port_name = name2 
    198             recv_port_name = name1 
    199         # when connecting ports in which the send port has an expression, need 
    200         # to create a binding for this expression in the new component 
    201         if send_port.expr: 
    202             func_args = c1.non_parameter_symbols.union(c2.non_parameter_symbols).intersection(send_port.expr.names) 
    203             lhs = "%s(%s)" % (send_port_name, ",".join(func_args)) 
    204             send_binding = nineml.Binding(lhs, send_port.expr.rhs) 
    205             bindings[send_binding.name] = send_binding 
    206             for eq in chain(c1.equations, c2.equations): 
    207                 if send_port_name in eq.names: 
    208                     eq.rhs = eq.rhs_name_transform({send_port_name: lhs})             
    209             if recv_port.mode == 'reduce': 
    210                 # need to retain reduce ports as they can be connected to in a future join 
    211                 if recv_port_name in bindings: 
    212                     # this reduce port has already been connected to, so combine using its reduce_op 
    213                     reduce_binding = bindings[recv_port_name] 
    214                     func_args = func_args.union(reduce_binding.args) 
    215                     lhs = "%s(%s)" % (recv_port_name, ",".join(func_args)) 
    216                     rhs = recv_port.reduce_op.join([reduce_binding.rhs, send_binding.lhs]) 
    217                 else: 
    218                     # this is the first time this reduce port has been connected to 
    219                     lhs = "%s(%s)" % (recv_port_name, ",".join(func_args)) 
    220                     rhs = send_binding.lhs 
    221                 bindings[recv_port_name] = nineml.Binding(lhs, rhs) 
    222                 recv_port.connected = True 
    223             else: 
    224                 all_ports.pop(name1) 
    225         else: 
    226             if recv_port.mode == 'reduce': 
    227                 raise NotImplementedError 
    228             else: 
    229                 all_ports.pop(name1) 
    230  
    231         if name1 != name2: 
    232             #c2.substitute(name2, name1) # need to implement this. Currently this all only works if name1 == name2 
    233                                          # probably needs to happen sooner in the function 
    234             all_ports.pop(name2) 
    235  
    236     # where parameters have become bindings due to connecting ports, replace 
    237     # bare names with function calls in the equations 
    238     for bname, binding in bindings.items(): 
    239         for eq in chain(c1.equations, c2.equations): 
    240             if bname in eq.names: 
    241                 print "#### replacing %s by %s" % (bname, binding.lhs) 
    242                 pattern = re.compile(r'%s(\([\w\, ]*\))?' % bname) 
    243                 m = pattern.search(eq.rhs) 
    244                 if m: 
    245                     eq.rhs = pattern.sub(binding.lhs, eq.rhs) 
    246                 else: 
    247                     eq.rhs = eq.rhs_name_transform({bname: binding.lhs}) 
    248  
    249     # create new regimes from all possible combinations of the regimes from the 
    250     # two components 
    251     regime_map = {} 
    252     for r1 in c1.regimes: 
    253         regime_map[r1.name] = {} 
    254         for r2 in c2.regimes: 
    255             if r1.name == r2.name: 
    256                 new_name = r1.name 
    257             else: 
    258                 new_name = "%s_AND_%s" % (r1.name, r2.name) 
    259             kwargs = {'name': new_name} 
    260             new_regime = nineml.Regime(*r1.nodes.union(r2.nodes), **kwargs) 
    261             regime_map[r1.name][r2.name] = new_regime 
    262     # create transitions between all the new regimes 
    263     transitions = [] 
    264     for r1 in c1.regimes: 
    265         for r2 in c2.regimes: 
    266             for t in r1.transitions: 
    267                 new_transition = nineml.Transition(*t.nodes, 
    268                                                    from_=regime_map[r1.name][r2.name], 
    269                                                    to=regime_map[t.to.name][r2.name], 
    270                                                    condition=t.condition) 
    271                 transitions.append(new_transition) 
    272             for t in r2.transitions: 
    273                 new_transition = nineml.Transition(*t.nodes, 
    274                                                    from_=regime_map[r1.name][r2.name], 
    275                                                    to=regime_map[r1.name][t.to.name], 
    276                                                    condition=t.condition) 
    277                 transitions.append(new_transition) 
    278  
    279     regimes = [] 
    280     for d in regime_map.values(): 
    281         regimes.extend(d.values()) 
    282     name = name or "%s__%s" % (c1.name, c2.name) 
    283     return nineml.Component(name, 
    284                             regimes=regimes, 
    285                             transitions=transitions, 
    286                             ports=all_ports.values(), 
    287                             bindings=bindings.values()) 
  • trunk/src/nineml/cells.py

    r957 r964  
    66""" 
    77 
     8from __future__ import absolute_import 
    89from pyNN import standardmodels 
    9 import pyNN.cells 
     10import pyNN.standardmodels.cells as cells 
    1011import nineml.user_layer as nineml 
    11 from utility import build_parameter_set, catalog_url, map_random_distribution_parameters 
     12from pyNN.nineml.utility import build_parameter_set, catalog_url, map_random_distribution_parameters 
     13 
     14from pyNN.models import BaseCellType 
     15import nineml.abstraction_layer as nineml 
     16import logging 
     17import os 
     18import re 
     19from itertools import chain 
     20 
     21logger = logging.getLogger("PyNN") 
    1222 
    1323 
     
    4858 
    4959 
    50 class IF_curr_exp(pyNN.cells.IF_curr_exp, CellTypeMixin): 
     60class IF_curr_exp(cells.IF_curr_exp, CellTypeMixin): 
    5161    """Leaky integrate and fire model with fixed threshold and 
    5262    decaying-exponential post-synaptic current. (Separate synaptic currents for 
     
    7888 
    7989 
    80 class IF_cond_exp(pyNN.cells.IF_cond_exp, CellTypeMixin): 
     90class IF_cond_exp(cells.IF_cond_exp, CellTypeMixin): 
    8191    
    8292    translations = standardmodels.build_translations( 
     
    107117 
    108118 
    109 class IF_cond_alpha(pyNN.cells.IF_cond_exp, CellTypeMixin): 
     119class IF_cond_alpha(cells.IF_cond_exp, CellTypeMixin): 
    110120    
    111121    translations = standardmodels.build_translations( 
     
    136146     
    137147 
    138 class SpikeSourcePoisson(pyNN.cells.SpikeSourcePoisson, CellTypeMixin): 
     148class SpikeSourcePoisson(cells.SpikeSourcePoisson, CellTypeMixin): 
    139149     
    140150    translations = standardmodels.build_translations( 
     
    145155    spiking_mechanism_definition_url = "%s/neurons/poisson_spike_source.xml" % catalog_url 
    146156    spiking_mechanism_parameter_names = ("onset", "frequency", "duration") 
     157 
     158 
     159 
     160# Neuron Models derived from a 9ML AL definition 
     161 
     162class NineMLCellType(BaseCellType): 
     163    #model = NineMLCell 
     164     
     165    def __init__(self, parameters): 
     166        BaseCellType.__init__(self, parameters) 
     167        self.parameters["type"] = self 
     168 
     169 
     170def unimplemented_builder(*args, **kwargs): 
     171    raise NotImplementedError, "TODO: 9ML neuron builder" 
     172 
     173def nineml_cell_type(name, neuron_model, port_map={}, weight_variables={}, **synapse_models): 
     174    """ 
     175    Return a new NineMLCellType subclass. 
     176    """ 
     177    return _build_nineml_celltype(name, (NineMLCellType,), 
     178                                  {'neuron_model': neuron_model, 
     179                                   'synapse_models': synapse_models, 
     180                                   'port_map': port_map, 
     181                                   'weight_variables': weight_variables, 
     182                                   'builder': unimplemented_builder}) 
     183 
     184# Helpers for Neuron Models derived from a 9ML AL definition 
     185 
     186 
     187def _add_prefix(synapse_model, prefix, port_map): 
     188    """ 
     189    Add a prefix to all variables in `synapse_model`, except for variables with 
     190    receive ports and specified in `port_map`. 
     191    """ 
     192    synapse_model.__cache__ = {} 
     193    exclude = [] 
     194    new_port_map = [] 
     195    for name1, name2 in port_map: 
     196        if synapse_model.ports_map[name2].mode == 'recv': 
     197            exclude.append(name2) 
     198            new_port_map.append((name1, name2)) 
     199        else: 
     200            new_port_map.append((name1, prefix + '_' + name2)) 
     201    synapse_model.add_prefix(prefix + '_', exclude=exclude) 
     202    return new_port_map 
     203 
     204 
     205class _build_nineml_celltype(type): 
     206    """ 
     207    Metaclass for building NineMLCellType subclasses 
     208    """ 
     209    def __new__(cls, name, bases, dct): 
     210        # join the neuron and synapse components into a single component 
     211        combined_model = dct["neuron_model"] 
     212        for label in dct["synapse_models"].keys(): 
     213            port_map = dct["port_map"][label] 
     214            port_map = _add_prefix(dct["synapse_models"][label], label, port_map) 
     215            dct["weight_variables"][label] = label + "_" + dct["weight_variables"][label] 
     216            combined_model = join(combined_model, 
     217                                  dct["synapse_models"][label], 
     218                                  port_map, 
     219                                  name=name) 
     220        dct["combined_model"] = combined_model 
     221        # set class attributes required for a PyNN cell type class 
     222        dct["default_parameters"] = dict((name, 1.0) 
     223                                      for name in combined_model.parameters) 
     224        dct["default_initial_values"] = dict((name, 0.0) 
     225                                          for name in combined_model.state_variables) 
     226        dct["synapse_types"] = dct["synapse_models"].keys() #really need an ordered dict 
     227        dct["injectable"] = True # need to determine this. How?? 
     228        dct["recordable"] = [port.name for port in combined_model.analog_ports] + ['spikes', 'regime'] 
     229        dct["standard_receptor_type"] = (dct["synapse_types"] == ('excitatory', 'inhibitory')) 
     230        dct["conductance_based"] = True # how to determine this?? 
     231        dct["model_name"] = name 
     232        logger.debug("Creating class '%s' with bases %s and dictionary %s" % (name, bases, dct)) 
     233        # generate and compile NMODL code, then load the mechanism into NEUORN 
     234        dct["builder"](combined_model, dct["weight_variables"]) # weight variables should really be stored within combined_model 
     235        return type.__new__(cls, name, bases, dct) 
     236     
     237     
     238 
     239def join(c1, c2, port_map=[], name=None): 
     240    """Create a NineML component by joining the two given components.""" 
     241    logger.debug("Joining components %s and %s with port map %s" % (c1, c2, port_map)) 
     242    logger.debug("New component will have name '%s'" % name) 
     243    # combine bindings from c1 and c2 
     244    bindings = {} 
     245    for b in chain(c1.bindings, c2.bindings): 
     246        bindings[b.name] = b 
     247    # combine ports (some will later be removed) 
     248    all_ports = c1.ports_map.copy() 
     249    all_ports.update(c2.ports_map) 
     250    # event ports do not be passed to the constructor, as they are attached to transitions 
     251    for port_name, port in all_ports.items(): 
     252        if isinstance(port, nineml.EventPort): 
     253            all_ports.pop(port_name) 
     254    # connect ports. 
     255    # currently, when ports are connected they disappear. It might be better to 
     256    # explicitly keep the ports in the new component but mark them as connected 
     257    for name1, name2 in port_map: 
     258        assert name1 in c1.ports_map, "%s is not in %s" % (name1, c1.ports_map.keys()) 
     259        assert name2 in c2.ports_map, "%s is not in %s" % (name2, c2.ports_map.keys()) 
     260 
     261        port1 = c1.ports_map[name1] 
     262        port2 = c2.ports_map[name2] 
     263        assert port1.mode != port2.mode 
     264        if port1.mode == 'send': 
     265            send_port = port1 
     266            recv_port = port2 
     267            send_port_name = name1 
     268            recv_port_name = name2 
     269        else: 
     270            send_port = port2 
     271            recv_port = port1 
     272            send_port_name = name2 
     273            recv_port_name = name1 
     274        # when connecting ports in which the send port has an expression, need 
     275        # to create a binding for this expression in the new component 
     276        if send_port.expr: 
     277            func_args = c1.non_parameter_symbols.union(c2.non_parameter_symbols).intersection(send_port.expr.names) 
     278            lhs = "%s(%s)" % (send_port_name, ",".join(func_args)) 
     279            send_binding = nineml.Binding(lhs, send_port.expr.rhs) 
     280            bindings[send_binding.name] = send_binding 
     281            for eq in chain(c1.equations, c2.equations): 
     282                if send_port_name in eq.names: 
     283                    eq.rhs = eq.rhs_name_transform({send_port_name: lhs})             
     284            if recv_port.mode == 'reduce': 
     285                # need to retain reduce ports as they can be connected to in a future join 
     286                if recv_port_name in bindings: 
     287                    # this reduce port has already been connected to, so combine using its reduce_op 
     288                    reduce_binding = bindings[recv_port_name] 
     289                    func_args = func_args.union(reduce_binding.args) 
     290                    lhs = "%s(%s)" % (recv_port_name, ",".join(func_args)) 
     291                    rhs = recv_port.reduce_op.join([reduce_binding.rhs, send_binding.lhs]) 
     292                else: 
     293                    # this is the first time this reduce port has been connected to 
     294                    lhs = "%s(%s)" % (recv_port_name, ",".join(func_args)) 
     295                    rhs = send_binding.lhs 
     296                bindings[recv_port_name] = nineml.Binding(lhs, rhs) 
     297                recv_port.connected = True 
     298            else: 
     299                all_ports.pop(name1) 
     300        else: 
     301            if recv_port.mode == 'reduce': 
     302                raise NotImplementedError 
     303            else: 
     304                all_ports.pop(name1) 
     305 
     306        if name1 != name2: 
     307            #c2.substitute(name2, name1) # need to implement this. Currently this all only works if name1 == name2 
     308                                         # probably needs to happen sooner in the function 
     309            all_ports.pop(name2) 
     310 
     311    # where parameters have become bindings due to connecting ports, replace 
     312    # bare names with function calls in the equations 
     313    for bname, binding in bindings.items(): 
     314        for eq in chain(c1.equations, c2.equations): 
     315            if bname in eq.names: 
     316                print "#### replacing %s by %s" % (bname, binding.lhs) 
     317                pattern = re.compile(r'%s(\([\w\, ]*\))?' % bname) 
     318                m = pattern.search(eq.rhs) 
     319                if m: 
     320                    eq.rhs = pattern.sub(binding.lhs, eq.rhs) 
     321                else: 
     322                    eq.rhs = eq.rhs_name_transform({bname: binding.lhs}) 
     323 
     324    # create new regimes from all possible combinations of the regimes from the 
     325    # two components 
     326    regime_map = {} 
     327    for r1 in c1.regimes: 
     328        regime_map[r1.name] = {} 
     329        for r2 in c2.regimes: 
     330            if r1.name == r2.name: 
     331                new_name = r1.name 
     332            else: 
     333                new_name = "%s_AND_%s" % (r1.name, r2.name) 
     334            kwargs = {'name': new_name} 
     335            new_regime = nineml.Regime(*r1.nodes.union(r2.nodes), **kwargs) 
     336            regime_map[r1.name][r2.name] = new_regime 
     337    # create transitions between all the new regimes 
     338    transitions = [] 
     339    for r1 in c1.regimes: 
     340        for r2 in c2.regimes: 
     341            for t in r1.transitions: 
     342                new_transition = nineml.Transition(*t.nodes, 
     343                                                   from_=regime_map[r1.name][r2.name], 
     344                                                   to=regime_map[t.to.name][r2.name], 
     345                                                   condition=t.condition) 
     346                transitions.append(new_transition) 
     347            for t in r2.transitions: 
     348                new_transition = nineml.Transition(*t.nodes, 
     349                                                   from_=regime_map[r1.name][r2.name], 
     350                                                   to=regime_map[r1.name][t.to.name], 
     351                                                   condition=t.condition) 
     352                transitions.append(new_transition) 
     353 
     354    regimes = [] 
     355    for d in regime_map.values(): 
     356        regimes.extend(d.values()) 
     357    name = name or "%s__%s" % (c1.name, c2.name) 
     358    return nineml.Component(name, 
     359                            regimes=regimes, 
     360                            transitions=transitions, 
     361                            ports=all_ports.values(), 
     362                            bindings=bindings.values()) 
     363 
     364 
     365 
     366 
     367 
     368 
     369 
     370 
     371