Changeset 713 for trunk/src/common.py

Show
Ignore:
Timestamp:
02/17/10 22:45:18 (2 years ago)
Author:
apdavison
Message:

Moved standard model (cells and synapses) machinery out of the common module into its own module

Files:
1 modified

Legend:

Unmodified
Added
Removed
  • trunk/src/common.py

    r712 r713  
    99 
    1010Utility functions and classes: 
    11     is_listlike() 
    12     is_number() 
    1311    is_conductance() 
    1412    check_weight() 
    1513    check_delay() 
    16     distance() 
    1714     
    1815Accessing individual neurons: 
    1916    IDMixin 
    20  
    21 Standard cells/parameter translation: 
    22     build_translations()     
    23     StandardModelType 
    24     StandardCellType 
    25     ModelNotAvailable 
    2617     
    2718Common API implementation/base classes: 
     
    4637    Population 
    4738    Projection 
    48      
    49   4. Specification of synaptic plasticity: 
    50     SynapseDynamics 
    51     ShortTermPlasticityMechanism 
    52     STDPMechanism 
    53     STDPWeightDependence 
    54     STDPTimingDependence 
    5539 
    5640$Id$ 
    5741""" 
    5842 
    59 import types, copy, sys 
    6043import numpy 
    6144import logging 
    6245from math import * 
    6346import operator 
    64 from pyNN import random, utility, recording, errors 
     47from pyNN import random, utility, recording, errors, standardmodels, core 
    6548from string import Template 
    6649 
     
    7962# ============================================================================== 
    8063 
    81 def is_listlike(obj): 
    82     return hasattr(obj, "__len__") and not isinstance(obj, basestring) 
    83  
    84 def is_number(obj): 
    85     return isinstance(obj, float) or isinstance(obj, int) or isinstance(obj, long) or isinstance(obj, numpy.float64) 
    8664 
    8765def is_conductance(target_cell): 
     
    10381    if weight is None: 
    10482        weight = DEFAULT_WEIGHT 
    105     if is_listlike(weight): 
     83    if core.is_listlike(weight): 
    10684        weight = numpy.array(weight) 
    10785        nan_filter = (1-numpy.isnan(weight)).astype(bool) # weight arrays may contain NaN, which should be ignored 
     
    11189        if not (all_negative or all_positive): 
    11290            raise errors.InvalidWeightError("Weights must be either all positive or all negative") 
    113     elif is_number(weight): 
     91    elif numpy.isscalar(weight): 
    11492        all_positive = weight >= 0 
    11593        all_negative = weight < 0 
     
    224202     
    225203    def is_standard_cell(self): 
    226         return (type(self.cellclass) == type and issubclass(self.cellclass, StandardCellType)) 
     204        return (type(self.cellclass) == type and issubclass(self.cellclass, standardmodels.StandardCellType)) 
    227205         
    228206    def _set_position(self, pos): 
     
    267245        """Inject current from a current source object into the cell.""" 
    268246        current_source.inject_into([self]) 
    269          
    270  
    271 # ============================================================================== 
    272 #   Standard cells 
    273 # ============================================================================== 
    274  
    275 def build_translations(*translation_list): 
    276     """ 
    277     Build a translation dictionary from a list of translations/transformations. 
    278     """ 
    279     translations = {} 
    280     for item in translation_list: 
    281         assert 2 <= len(item) <= 4, "Translation tuples must have between 2 and 4 items. Actual content: %s" % str(item) 
    282         pynn_name = item[0] 
    283         sim_name = item[1] 
    284         if len(item) == 2: # no transformation 
    285             f = pynn_name 
    286             g = sim_name 
    287         elif len(item) == 3: # simple multiplicative factor 
    288             scale_factor = item[2] 
    289             f = "float(%g)*%s" % (scale_factor, pynn_name) 
    290             g = "%s/float(%g)" % (sim_name, scale_factor) 
    291         elif len(item) == 4: # more complex transformation 
    292             f = item[2] 
    293             g = item[3] 
    294         translations[pynn_name] = {'translated_name': sim_name, 
    295                                    'forward_transform': f, 
    296                                    'reverse_transform': g} 
    297     return translations 
    298  
    299 class StandardModelType(object): 
    300     """Base class for standardized cell model and synapse model classes.""" 
    301      
    302     translations = {} 
    303     default_parameters = {} 
    304      
    305     def __init__(self, parameters): 
    306         self.parameters = self.__class__.checkParameters(parameters, with_defaults=True) 
    307         self.parameters = self.__class__.translate(self.parameters) 
    308      
    309     @classmethod 
    310     def checkParameters(cls, supplied_parameters, with_defaults=False): 
    311         """ 
    312         Returns a parameter dictionary, checking that each 
    313         supplied_parameter is in the default_parameters and 
    314         converts to the type of the latter. 
    315  
    316         If with_defaults==True, parameters not in 
    317         supplied_parameters are in the returned dictionary 
    318         as in default_parameters. 
    319  
    320         """ 
    321         default_parameters = cls.default_parameters 
    322         if with_defaults: 
    323             parameters = copy.copy(default_parameters) 
    324         else: 
    325             parameters = {} 
    326         if supplied_parameters: 
    327             for k in supplied_parameters.keys(): 
    328                 if default_parameters.has_key(k): 
    329                     err_msg = "For %s in %s, expected %s, got %s (%s)" % \ 
    330                               (k, cls.__name__, type(default_parameters[k]), 
    331                                type(supplied_parameters[k]), supplied_parameters[k]) 
    332                     # same type 
    333                     if type(supplied_parameters[k]) == type(default_parameters[k]):  
    334                         parameters[k] = supplied_parameters[k] 
    335                     # float and something that can be converted to a float 
    336                     elif type(default_parameters[k]) == types.FloatType:  
    337                         try: 
    338                             parameters[k] = float(supplied_parameters[k])  
    339                         except (ValueError, TypeError): 
    340                             raise errors.InvalidParameterValueError(err_msg) 
    341                     # list and something that can be transformed to a list 
    342                     elif type(default_parameters[k]) == types.ListType: 
    343                         try: 
    344                             parameters[k] = list(supplied_parameters[k]) 
    345                         except TypeError: 
    346                             raise errors.InvalidParameterValueError(err_msg) 
    347                     else: 
    348                         raise errors.InvalidParameterValueError(err_msg) 
    349                 else: 
    350                     raise errors.NonExistentParameterError(k, cls) 
    351         return parameters 
    352      
    353     @classmethod 
    354     def translate(cls, parameters): 
    355         """Translate standardized model parameters to simulator-specific parameters.""" 
    356         parameters = cls.checkParameters(parameters, with_defaults=False) 
    357         native_parameters = {} 
    358         for name in parameters: 
    359             D = cls.translations[name] 
    360             pname = D['translated_name'] 
    361             if is_listlike(cls.default_parameters[name]): 
    362                 parameters[name] = numpy.array(parameters[name]) 
    363             try: 
    364                 pval = eval(D['forward_transform'], globals(), parameters) 
    365             except NameError, errmsg: 
    366                 raise NameError("Problem translating '%s' in %s. Transform: '%s'. Parameters: %s. %s" \ 
    367                                 % (pname, cls.__name__, D['forward_transform'], parameters, errmsg)) 
    368             except ZeroDivisionError: 
    369                 pval = 1e30 # this is about the highest value hoc can deal with 
    370             native_parameters[pname] = pval 
    371         return native_parameters 
    372      
    373     @classmethod 
    374     def reverse_translate(cls, native_parameters): 
    375         """Translate simulator-specific model parameters to standardized parameters.""" 
    376         standard_parameters = {} 
    377         for name,D  in cls.translations.items(): 
    378             if is_listlike(cls.default_parameters[name]): 
    379                 tname = D['translated_name'] 
    380                 native_parameters[tname] = numpy.array(native_parameters[tname]) 
    381             try: 
    382                 standard_parameters[name] = eval(D['reverse_transform'], {}, native_parameters) 
    383             except NameError, errmsg: 
    384                 raise NameError("Problem translating '%s' in %s. Transform: '%s'. Parameters: %s. %s" \ 
    385                                 % (name, cls.__name__, D['reverse_transform'], native_parameters, errmsg)) 
    386         return standard_parameters 
    387  
    388     @classmethod 
    389     def simple_parameters(cls): 
    390         """Return a list of parameters for which there is a one-to-one 
    391         correspondance between standard and native parameter values.""" 
    392         return [name for name in cls.translations if cls.translations[name]['forward_transform'] == name] 
    393  
    394     @classmethod 
    395     def scaled_parameters(cls): 
    396         """Return a list of parameters for which there is a unit change between 
    397         standard and native parameter values.""" 
    398         return [name for name in cls.translations if "float" in cls.translations[name]['forward_transform']] 
    399      
    400     @classmethod 
    401     def computed_parameters(cls): 
    402         """Return a list of parameters whose values must be computed from 
    403         more than one other parameter.""" 
    404         return [name for name in cls.translations if name not in cls.simple_parameters()+cls.scaled_parameters()] 
    405          
    406     def update_parameters(self, parameters): 
    407         """ 
    408         update self.parameters with those in parameters  
    409         """ 
    410         self.parameters.update(self.translate(parameters)) 
    411          
    412     def describe(self, template='standard'): 
    413         return str(self) 
    414      
    415  
    416 class StandardCellType(StandardModelType): 
    417     """Base class for standardized cell model classes.""" 
    418  
    419     recordable = ['spikes', 'v', 'gsyn'] 
    420     synapse_types = ('excitatory', 'inhibitory') 
    421     conductance_based = True # over-ride for cells with current-based synapses 
    422     always_local = False # over-ride for NEST spike sources 
    423  
    424  
    425 class ModelNotAvailable(object): 
    426     """Not available for this simulator.""" 
    427      
    428     def __init__(self, *args, **kwargs): 
    429         raise NotImplementedError("The %s model is not available for this simulator." % self.__class__.__name__) 
     247 
    430248 
    431249# ============================================================================== 
     
    519337    # should refactor to eliminate this repetition 
    520338    logger.debug("connecting %s to %s on host %d" % (source, target, rank())) 
    521     if not is_listlike(source): 
     339    if not core.is_listlike(source): 
    522340        source = [source] 
    523     if not is_listlike(target): 
     341    if not core.is_listlike(target): 
    524342        target = [target] 
    525343    delay = check_delay(delay) 
     
    599417          e.g., (10,10) will create a two-dimensional population of size 10x10. 
    600418        cellclass should either be a standardized cell class (a class inheriting 
    601         from common.StandardCellType) or a string giving the name of the 
     419        from common.standardmodels.StandardCellType) or a string giving the name of the 
    602420        simulator-specific model that makes up the population. 
    603421        cellparams should be a dict which is passed to the neuron model 
     
    612430            assert isinstance(dims, tuple), "`dims` must be an integer or a tuple. You have supplied a %s" % type(dims) 
    613431        self.label = label or 'population%d' % Population.nPop          
    614         if isinstance(cellclass, type) and issubclass(cellclass, StandardCellType): 
     432        if isinstance(cellclass, type) and issubclass(cellclass, standardmodels.StandardCellType): 
    615433            self.celltype = cellclass(cellparams) 
    616434        else: 
     
    853671    def can_record(self, variable): 
    854672        """Determine whether `variable` can be recorded from this population.""" 
    855         if isinstance(self.celltype, StandardCellType): 
     673        if isinstance(self.celltype, standardmodels.StandardCellType): 
    856674            return (variable in self.celltype.recordable) 
    857675        else: 
     
    1110928                 connecting the neurons. 
    1111929         
    1112         synapse_dynamics - a `SynapseDynamics` object specifying which 
     930        synapse_dynamics - a `standardmodels.SynapseDynamics` object specifying which 
    1113931                 synaptic plasticity mechanisms to use. 
    1114932         
     
    1139957        self.long_term_plasticity_mechanism = None 
    1140958        if self.synapse_dynamics: 
    1141             assert isinstance(self.synapse_dynamics, SynapseDynamics), \ 
    1142               "The synapse_dynamics argument, if specified, must be a SynapseDynamics object, not a %s" % type(synapse_dynamics) 
     959            assert isinstance(self.synapse_dynamics, standardmodels.SynapseDynamics), \ 
     960              "The synapse_dynamics argument, if specified, must be a standardmodels.SynapseDynamics object, not a %s" % type(synapse_dynamics) 
    1143961            if self.synapse_dynamics.fast: 
    1144                 assert isinstance(self.synapse_dynamics.fast, ShortTermPlasticityMechanism) 
     962                assert isinstance(self.synapse_dynamics.fast, standardmodels.ShortTermPlasticityMechanism) 
    1145963                if hasattr(self.synapse_dynamics.fast, 'native_name'): 
    1146964                    self.short_term_plasticity_mechanism = self.synapse_dynamics.fast.native_name 
     
    1149967                self._short_term_plasticity_parameters = self.synapse_dynamics.fast.parameters.copy() 
    1150968            if self.synapse_dynamics.slow: 
    1151                 assert isinstance(self.synapse_dynamics.slow, STDPMechanism) 
     969                assert isinstance(self.synapse_dynamics.slow, standardmodels.STDPMechanism) 
    1152970                assert 0 <= self.synapse_dynamics.slow.dendritic_delay_fraction <= 1.0 
    1153971                td = self.synapse_dynamics.slow.timing_dependence 
     
    12341052        Set parameters of the synapse dynamics to values taken from rand_distr 
    12351053        """ 
    1236         self.setSynapseDynamics(param, rand_distr.next(len(self))) 
     1054        self.setstandardmodels.SynapseDynamics(param, rand_distr.next(len(self))) 
    12371055     
    12381056    # --- Methods for writing/reading information to/from file. ---------------- 
     
    12691087        """ 
    12701088        if gather: 
    1271             logger.error("getSynapseDynamics() with gather=True not yet implemented") 
     1089            logger.error("getstandardmodels.SynapseDynamics() with gather=True not yet implemented") 
    12721090        return self.connection_manager.get(parameter_name, format, offset=(self.pre.first_id, self.post.first_id)) 
    12731091     
     
    13741192            return Template(template).substitute(context) 
    13751193 
    1376          
    13771194# ============================================================================== 
    1378 #   Synapse Dynamics classes 
    1379 # ============================================================================== 
    1380  
    1381 class SynapseDynamics(object): 
    1382     """ 
    1383     For specifying synapse short-term (faciliation, depression) and long-term 
    1384     (STDP) plasticity. To be passed as the `synapse_dynamics` argument to 
    1385     `Projection.__init__()` or `connect()`. 
    1386     """ 
    1387      
    1388     def __init__(self, fast=None, slow=None): 
    1389         """ 
    1390         Create a new specification for a dynamic synapse, combining a `fast` 
    1391         component (short-term facilitation/depression) and a `slow` component 
    1392         (long-term potentiation/depression). 
    1393         """ 
    1394         self.fast = fast 
    1395         self.slow = slow 
    1396      
    1397     def describe(self, template='standard'): 
    1398         """ 
    1399         Return a human-readable description of the synaptic properties. 
    1400         """ 
    1401         if template == 'standard': 
    1402             lines = ["Short-term plasticity mechanism: $fast", 
    1403                      "Long-term plasticity mechanism: $slow"] 
    1404             template = "\n".join(lines) 
    1405         context = {'fast': self.fast and self.fast.describe() or 'None', 
    1406                    'slow': self.slow and self.slow.describe() or 'None'} 
    1407         if template == None: 
    1408             return context 
    1409         else: 
    1410             return Template(template).substitute(context) 
    1411          
    1412          
    1413 class ShortTermPlasticityMechanism(StandardModelType): 
    1414     """Abstract base class for models of short-term synaptic dynamics.""" 
    1415      
    1416     def __init__(self): 
    1417         raise NotImplementedError 
    1418  
    1419  
    1420 class STDPMechanism(object): 
    1421     """Specification of STDP models.""" 
    1422      
    1423     def __init__(self, timing_dependence=None, weight_dependence=None, 
    1424                  voltage_dependence=None, dendritic_delay_fraction=1.0): 
    1425         """ 
    1426         Create a new specification for an STDP mechanism, by combining a 
    1427         weight-dependence, a timing-dependence, and, optionally, a voltage- 
    1428         dependence. 
    1429          
    1430         For point neurons, the synaptic delay `d` can be interpreted either as 
    1431         occurring purely in the pre-synaptic axon + synaptic cleft, in which 
    1432         case the synaptic plasticity mechanism 'sees' the post-synaptic spike 
    1433         immediately and the pre-synaptic spike after a delay `d` 
    1434         (`dendritic_delay_fraction = 0`) or as occurring purely in the post- 
    1435         synaptic dendrite, in which case the pre-synaptic spike is seen 
    1436         immediately, and the post-synaptic spike after a delay `d` 
    1437         (`dendritic_delay_fraction = 1`), or as having both pre- and post- 
    1438         synaptic components (`dendritic_delay_fraction` between 0 and 1). 
    1439          
    1440         In a future version of the API, we will allow the different 
    1441         components of the synaptic delay to be specified separately in 
    1442         milliseconds. 
    1443         """ 
    1444         self.timing_dependence = timing_dependence 
    1445         self.weight_dependence = weight_dependence 
    1446         self.voltage_dependence = voltage_dependence 
    1447         self.dendritic_delay_fraction = dendritic_delay_fraction 
    1448          
    1449     def describe(self): 
    1450         """ 
    1451         Return a human-readable description of the STDP mechanism. 
    1452         """ 
    1453         return "STDP mechanism (this description needs to be filled out)." 
    1454  
    1455  
    1456 class STDPWeightDependence(StandardModelType): 
    1457     """Abstract base class for models of STDP weight dependence.""" 
    1458      
    1459     def __init__(self): 
    1460         raise NotImplementedError 
    1461  
    1462  
    1463 class STDPTimingDependence(StandardModelType): 
    1464     """Abstract base class for models of STDP timing dependence (triplets, etc)""" 
    1465      
    1466     def __init__(self): 
    1467         raise NotImplementedError 
    1468  
    1469  
    1470 # ==============================================================================