Changeset 949 for trunk

Show
Ignore:
Timestamp:
04/20/11 19:27:45 (13 months ago)
Author:
pierre
Message:

Try to speed up the Nemo SpikeSourcePoisson? and SpikeSourceArray?. However, this is still slow, I guess du to communication between CPU and GPU.

Location:
trunk/src
Files:
3 modified

Legend:

Unmodified
Added
Removed
  • trunk/src/nemo/simulator.py

    r946 r949  
    8686        if self.verbose: 
    8787            self.progressbar(simtime) 
     88 
     89        poissons_sources = [] 
     90        arrays_sources   = [] 
     91 
     92        for source in spikes_array_list: 
     93            if isinstance(source.celltype, SpikeSourcePoisson):         
     94                poissons_sources.append(source) 
     95            if isinstance(source.celltype, SpikeSourceArray): 
     96                arrays_sources.append(source) 
    8897         
    8998        for t in numpy.arange(0, simtime, self.dt): 
    9099            spikes   = [] 
    91100            currents = []  
    92             for source in spikes_array_list: 
    93                 if isinstance(source.celltype, SpikeSourcePoisson): 
    94                     if source.player.do_spike(t): 
    95                         spikes += [source] 
    96                 elif isinstance(source.celltype, SpikeSourceArray): 
    97                     if source.player.next_spike == t: 
    98                         source.player.update()                     
    99                         spikes += [source] 
    100  
     101            for source in poissons_sources: 
     102                if source.player.do_spike(t): 
     103                    spikes += [source] 
     104            for source in arrays_sources: 
     105                if source.player.next_spike == t: 
     106                    source.player.update()                     
     107                    spikes += [source] 
     108             
    101109            #for currents in current_sources: 
    102110            #    currents. 
  • trunk/src/nemo/standardmodels/cells.py

    r946 r949  
    6060            self.rate_Hz     = self.rate * self.precision/1000. 
    6161            self.stop_time   = self.start + self.duration 
     62            self.buffer      = 1000            
     63            self.do_spikes   = self.rng.rand(self.buffer) < self.rate_Hz 
     64            self.idx         = 0 
    6265     
    6366        def do_spike(self, t): 
    6467            if (t > self.stop_time) or (t < self.start): 
    6568                return False 
    66             return (self.rng.rand() < self.rate_Hz) 
     69            else: 
     70                if self.idx == (self.buffer - 1): 
     71                    self.do_spikes = self.rng.rand(self.buffer) < self.rate_Hz 
     72                    self.idx       = 0 
     73                self.idx += 1 
     74                return self.do_spikes[self.idx] 
    6775         
    6876        def reset(self, rate=None, start=None, duration=None, precision=1): 
     
    8795 
    8896    class spike_player(object): 
     97 
     98        precision = 1 
    8999         
    90         def __init__(self, spike_times=[], precision=1.): 
     100        def __init__(self, spike_times=[], precision=1): 
    91101            self.spike_times = precision * numpy.round(spike_times/precision)         
    92102            self.spike_times = numpy.unique(numpy.sort(self.spike_times)) 
     
    104114            self.cursor += 1 
    105115 
    106         def reset(self, spike_times): 
     116        def reset(self, spike_times, precision): 
    107117            self.spike_times = precision * numpy.round(spike_times/precision) 
    108118            self.spike_times = numpy.unique(numpy.sort(self.spike_times)) 
  • trunk/src/neuron/nineml.py

    r947 r949  
    1 from __future__ import absolute_import 
    2 import subprocess 
    3 import neuron 
    4 from pyNN.models import BaseCellType 
    5 import nineml.abstraction_layer as nineml 
    6 import logging 
    7 import os 
    8 from itertools import chain 
    9  
    10 h = neuron.h 
    11 logger = logging.getLogger("PyNN") 
    12  
    13 NMODL_DIR = "nineml_mechanisms" 
    14  
    15 class NineMLCell(object): 
    16      
    17     def __init__(self, **parameters): 
    18         self.type = parameters.pop("type") 
    19         self.source_section = h.Section() 
    20         self.source = getattr(h, self.type.model_name)(0.5, sec=self.source_section) 
    21         for param, value in parameters.items(): 
    22             setattr(self.source, param, value) 
    23         # for recording 
    24         self.spike_times = h.Vector(0) 
    25         self.traces = {} 
    26         self.recording_time = False 
    27      
    28     def __getattr__(self, name): 
    29         try: 
    30             return self.__getattribute__(name) 
    31         except AttributeError: 
    32             if name in self.type.synapse_types: 
    33                 return self.source # source is also target 
    34             else: 
    35                 raise AttributeError("'NineMLCell' object has no attribute or synapse type '%s'" % name) 
    36  
    37     def record(self, active): 
    38         if active: 
    39             rec = h.NetCon(self.source, None) 
    40             rec.record(self.spike_times) 
    41         else: 
    42             self.spike_times = h.Vector(0) 
    43  
    44     def memb_init(self): 
    45         # this is a bit of a hack 
    46         for var in self.type.recordable: 
    47             if hasattr(self, "%s_init" % var): 
    48                 initial_value = getattr(self, "%s_init" % var) 
    49                 logger.debug("Initialising %s to %g" % (var, initial_value)) 
    50                 setattr(self.source, var, initial_value) 
    51  
    52  
    53 class NineMLCellType(BaseCellType): 
    54     model = NineMLCell 
    55      
    56     def __init__(self, parameters): 
    57         BaseCellType.__init__(self, parameters) 
    58         self.parameters["type"] = self 
    59  
    60  
    61 def _compile_nmodl(nineml_component): 
    62     if not os.path.exists(NMODL_DIR): 
    63         os.makedirs(NMODL_DIR) 
    64     cwd = os.getcwd() 
    65     os.chdir(NMODL_DIR) 
    66     xml_file = "%s.xml" % nineml_component.name 
    67     logger.debug("Writing NineML component to %s" % xml_file) 
    68     nineml_component.write(xml_file) 
    69     nineml2nmodl = __import__("9ml2nmodl") 
    70     nineml2nmodl.write_nmodl(xml_file) 
    71     p = subprocess.check_call(["nrnivmodl"]) 
    72     os.chdir(cwd) 
    73     neuron.load_mechanisms(NMODL_DIR) 
    74  
    75  
    76 class _build_nineml_celltype(type): 
    77     """ 
    78     Metaclass for building NineMLCellType subclasses 
    79     """ 
    80     def __new__(cls, name, bases, dct): 
    81         assert len(dct["synapse_models"]) == 1, "For now, can't handle multiple synapse models" 
    82         combined_model = join(dct["neuron_model"], 
    83                               dct["synapse_models"].values()[0], 
    84                               dct["port_map"], 
    85                               name=name) 
    86         dct["combined_model"] = combined_model 
    87         dct["default_parameters"] = dict((name, 1.0) 
    88                                       for name in combined_model.parameters) 
    89         dct["default_initial_values"] = dict((name, 0.0) 
    90                                           for name in combined_model.state_variables) 
    91         dct["synapse_types"] = dct["synapse_models"].keys() #really need an ordered dict 
    92         dct["injectable"] = True # need to determine this. How?? 
    93         dct["recordable"] = [port.name for port in combined_model.analog_ports] + ['spikes'] 
    94         dct["standard_receptor_type"] = (dct["synapse_types"] == ('excitatory', 'inhibitory')) 
    95         dct["conductance_based"] = True # how to determine this?? 
    96         dct["model_name"] = name 
    97         logger.debug("Creating class '%s' with bases %s and dictionary %s" % (name, bases, dct)) 
    98         _compile_nmodl(combined_model) 
    99         return type.__new__(cls, name, bases, dct) 
    100      
    101      
    102  
    103 def nineml_cell_type(name, neuron_model, port_map={}, **synapse_models): 
    104     """ 
    105     Return a new NineMLCellType subclass. 
    106     """ 
    107     return _build_nineml_celltype(name, (NineMLCellType,), 
    108                                   {'neuron_model': neuron_model, 
    109                                    'synapse_models': synapse_models, 
    110                                    'port_map': port_map}) 
    111      
    112  
    113 def join(c1, c2, port_map=[], name=None): 
    114     """Create a NineML component by joining the two given components.""" 
    115     logger.debug("Joining components %s and %s with port map %s" % (c1, c2, port_map)) 
    116     logger.debug("New component will have name '%s'" % name) 
    117     bindings = [] # TODO: combine bindings from c1 and c2 
    118     all_ports = c1.ports_map.copy() 
    119     all_ports.update(c2.ports_map) 
    120     for port_name, port in all_ports.items(): 
    121         if isinstance(port, nineml.EventPort): 
    122             all_ports.pop(port_name) 
    123     for name1, name2 in port_map: 
    124         assert name1 in c1.ports_map 
    125         assert name2 in c2.ports_map 
    126         all_ports.pop(name1) 
    127         if name1 != name2: 
    128             #c2.substitute(name2, name1) # need to implement this 
    129             all_ports.pop(name2) 
    130          
    131         port1 = c1.ports_map[name1] 
    132         port2 = c2.ports_map[name2] 
    133         assert port1.mode != port2.mode 
    134         if port1.mode == 'send': 
    135             send_port = port1 
    136             port_name = name1 
    137         else: 
    138             send_port = port2 
    139             port_name = name2 
    140         if send_port.expr: 
    141             func_args = c1.non_parameter_symbols.union(c2.non_parameter_symbols).intersection(send_port.expr.names) 
    142             lhs = "%s(%s)" % (port_name, ",".join(func_args)) 
    143             bindings.append(nineml.Binding(lhs, send_port.expr.rhs)) 
    144             for eq in chain(c1.equations, c2.equations): 
    145                 if port_name in eq.names: 
    146                     eq.rhs = eq.rhs_name_transform({port_name: lhs}) 
    147     regime_map = {} 
    148     for r1 in c1.regimes: 
    149         regime_map[r1.name] = {} 
    150         for r2 in c2.regimes: 
    151             kwargs = {'name': "%s_AND_%s" % (r1.name, r2.name)} 
    152             new_regime = nineml.Regime(*r1.nodes.union(r2.nodes), **kwargs) 
    153             regime_map[r1.name][r2.name] = new_regime 
    154     transitions = [] 
    155     for r1 in c1.regimes: 
    156         for r2 in c2.regimes: 
    157             for t in r1.transitions: 
    158                 new_transition = nineml.Transition(*t.nodes, 
    159                                                    from_=regime_map[r1.name][r2.name], 
    160                                                    to=regime_map[t.to.name][r2.name], 
    161                                                    condition=t.condition) 
    162                 transitions.append(new_transition) 
    163             for t in r2.transitions: 
    164                 new_transition = nineml.Transition(*t.nodes, 
    165                                                    from_=regime_map[r1.name][r2.name], 
    166                                                    to=regime_map[r1.name][t.to.name], 
    167                                                    condition=t.condition) 
    168                 transitions.append(new_transition) 
    169     regimes = [] 
    170     for d in regime_map.values(): 
    171         regimes.extend(d.values()) 
    172     name = name or "%s__%s" % (c1.name, c2.name) 
    173     return nineml.Component(name, 
    174                             regimes=regimes, 
    175                             transitions=transitions, 
    176                             ports=all_ports.values(), 
    177                             bindings=bindings) 
    178  
    179  
    180  
    181          
    182