Changeset 1000 for trunk

Show
Ignore:
Timestamp:
10/30/11 15:45:23 (7 months ago)
Author:
apdavison
Message:

No longer store a global reference to the current simulator module in common. Now each class that needs it stores its own reference as the attribute _simulator. This will make it much easier to use more than one simulator at the same time, once Connector classes have been fixed up to use the simulator module only during connect(), not on creation

Location:
trunk
Files:
32 modified

Legend:

Unmodified
Added
Removed
  • trunk/src/brian/__init__.py

    r999 r1000  
    1414from pyNN.brian import simulator 
    1515from pyNN import common, recording, space, core, __doc__ 
    16 common.control.simulator = simulator 
    17 recording.simulator = simulator 
    1816from pyNN.random import * 
    1917from pyNN.recording import files 
     
    8987# ============================================================================== 
    9088 
    91 get_time_step = common.get_time_step 
    92 get_min_delay = common.get_min_delay 
    93 get_max_delay = common.get_max_delay 
    94 num_processes = common.num_processes 
    95 rank = common.rank 
     89get_current_time, get_time_step, get_min_delay, get_max_delay, \ 
     90            num_processes, rank = common.control.build_state_queries(simulator) 
    9691 
    9792# ============================================================================== 
     
    9994#   neurons. 
    10095# ============================================================================== 
     96 
     97class Assembly(common.Assembly): 
     98    _simulator = simulator 
     99 
     100 
     101class PopulationView(common.PopulationView): 
     102    _simulator = simulator 
     103    assembly_class = Assembly 
     104 
     105    def _get_view(self, selector, label=None): 
     106        return PopulationView(self, selector, label) 
     107 
    101108 
    102109class Population(common.Population, common.BasePopulation): 
     
    105112    term intended to include layers, columns, nuclei, etc., of cells. 
    106113    """ 
     114    _simulator = simulator 
    107115    recorder_class = Recorder 
     116    assembly_class = Assembly 
     117 
     118    def _get_view(self, selector, label=None): 
     119        return PopulationView(self, selector, label) 
    108120 
    109121    def _create_cells(self, cellclass, cellparams=None, n=1): 
     
    187199 
    188200 
    189 PopulationView = common.PopulationView 
    190 Assembly = common.Assembly 
    191  
    192201class Projection(common.Projection): 
    193202    """ 
     
    196205    parameters of those connections, including of plasticity mechanisms. 
    197206    """ 
     207    _simulator = simulator 
    198208     
    199209    def __init__(self, presynaptic_population, postsynaptic_population, method, 
  • trunk/src/brian/connectors.py

    r998 r1000  
    1010from pyNN.random import RandomDistribution 
    1111import numpy 
    12 from pyNN import random, common, core 
     12from pyNN import core 
    1313from pyNN.connectors import AllToAllConnector, \ 
    1414                            ProbabilisticConnector, \ 
     
    2424                            WeightGenerator, \ 
    2525                            DelayGenerator, \ 
    26                             ProbaGenerator 
     26                            ProbaGenerator, Connector 
     27from pyNN.brian import simulator 
     28 
     29Connector._simulator = simulator 
     30 
    2731 
    2832class FastProbabilisticConnector(ProbabilisticConnector): 
     
    125129                raise Exception('Expression for weights or delays is not supported for OneToOneConnector !') 
    126130            weights_generator = WeightGenerator(self.weights, local, projection, self.safe) 
    127             delays_generator  = DelayGenerator(self.delays, local, self.safe)                 
     131            delays_generator  = DelayGenerator(self.delays, local, kernel=projection._simulator.state, safe=self.safe)                 
    128132            weights           = weights_generator.get(N) 
    129133            delays            = delays_generator.get(N) 
  • trunk/src/brian/recording.py

    r978 r1000  
    2121class Recorder(recording.Recorder): 
    2222    """Encapsulates data and functions related to recording model variables.""" 
     23    _simulator = simulator 
    2324   
    2425    def __init__(self, variable, population=None, file=None): 
  • trunk/src/common/__init__.py

    r999 r1000  
    4949 
    5050from populations import IDMixin, BasePopulation, Population, PopulationView, Assembly, is_conductance 
    51 from projections import Projection, check_weight, check_delay, DEFAULT_WEIGHT 
     51from projections import Projection, check_weight, DEFAULT_WEIGHT 
    5252from procedural_api import build_create, build_connect, set, build_record, initialize 
    53 from control import setup, end, run, reset, get_current_time, get_time_step, \ 
    54                     get_min_delay, get_max_delay, num_processes, rank 
     53from control import setup 
  • trunk/src/common/control.py

    r999 r1000  
    4040    raise NotImplementedError 
    4141 
    42 def reset(): 
    43     """ 
    44     Reset the time to zero, neuron membrane potentials and synaptic weights to 
    45     their initial values, and delete any recorded data. The network structure 
    46     is not changed, nor is the specification of which neurons to record from. 
    47     """ 
    48     simulator.reset() 
     42def build_reset(simulator): 
     43    def reset(): 
     44        """ 
     45        Reset the time to zero, neuron membrane potentials and synaptic weights to 
     46        their initial values, and delete any recorded data. The network structure 
     47        is not changed, nor is the specification of which neurons to record from. 
     48        """ 
     49        simulator.reset() 
     50    return reset 
    4951 
    50 def get_current_time(): 
    51     """Return the current time in the simulation.""" 
    52     return simulator.state.t 
    53  
    54 def get_time_step(): 
    55     """Return the integration time step.""" 
    56     return simulator.state.dt 
    57  
    58 def get_min_delay(): 
    59     """Return the minimum allowed synaptic delay.""" 
    60     return simulator.state.min_delay 
    61  
    62 def get_max_delay(): 
    63     """Return the maximum allowed synaptic delay.""" 
    64     return simulator.state.max_delay 
    65  
    66 def num_processes(): 
    67     """Return the number of MPI processes.""" 
    68     return simulator.state.num_processes 
    69  
    70 def rank(): 
    71     """Return the MPI rank of the current node.""" 
    72     return simulator.state.mpi_rank 
     52def build_state_queries(simulator): 
     53    def get_current_time(): 
     54        """Return the current time in the simulation.""" 
     55        return simulator.state.t 
     56     
     57    def get_time_step(): 
     58        """Return the integration time step.""" 
     59        return simulator.state.dt 
     60     
     61    def get_min_delay(): 
     62        """Return the minimum allowed synaptic delay.""" 
     63        return simulator.state.min_delay 
     64     
     65    def get_max_delay(): 
     66        """Return the maximum allowed synaptic delay.""" 
     67        return simulator.state.max_delay 
     68     
     69    def num_processes(): 
     70        """Return the number of MPI processes.""" 
     71        return simulator.state.num_processes 
     72     
     73    def rank(): 
     74        """Return the MPI rank of the current node.""" 
     75        return simulator.state.mpi_rank 
     76     
     77    return get_current_time, get_time_step, get_min_delay, get_max_delay, num_processes, rank 
  • trunk/src/common/populations.py

    r999 r1000  
    166166            return self.all_cells[index] 
    167167        elif isinstance(index, (slice, list, numpy.ndarray)): 
    168             return PopulationView(self, index) 
     168            return self._get_view(index) 
    169169        elif isinstance(index, tuple): 
    170             return PopulationView(self, list(index)) 
     170            return self._get_view(list(index)) 
    171171        else: 
    172172            raise TypeError("indices must be integers, slices, lists, arrays or tuples, not %s" % type(index).__name__) 
     
    206206        """ 
    207207        assert isinstance(other, BasePopulation) 
    208         return Assembly(self, other) 
     208        return self.assembly_class(self, other) 
    209209 
    210210    def _get_cell_position(self, id): 
     
    248248        logger.debug("The %d cells recorded have indices %s" % (n, indices)) 
    249249        logger.debug("%s.sample(%s)", self.label, n) 
    250         return PopulationView(self, indices) 
     250        return self._get_view(indices) 
    251251 
    252252    def get(self, parameter_name, gather=False): 
     
    262262            values = [getattr(cell, parameter_name) for cell in self]  # list or array? 
    263263         
    264         if gather == True and control.num_processes() > 1: 
    265             all_values  = { control.rank(): values } 
    266             all_indices = { control.rank(): self.local_cells.tolist()} 
     264        if gather == True and self._simulator.state.num_processes > 1: 
     265            all_values  = { self._simulator.state.mpi_rank: values } 
     266            all_indices = { self._simulator.state.mpi_rank: self.local_cells.tolist()} 
    267267            all_values  = recording.gather_dict(all_values) 
    268268            all_indices = recording.gather_dict(all_indices) 
    269             if control.rank() == 0: 
     269            if self._simulator.state.mpi_rank == 0: 
    270270                values  = reduce(operator.add, all_values.values()) 
    271271                indices = reduce(operator.add, all_indices.values()) 
     
    582582        spike_counts = self.recorders['spikes'].count(gather, self.record_filter) 
    583583        total_spikes = sum(spike_counts.values()) 
    584         if control.rank() == 0 or not gather:  # should maybe use allgather, and get the numbers on all nodes 
     584        if self._simulator.state.mpi_rank == 0 or not gather:  # should maybe use allgather, and get the numbers on all nodes 
    585585            if len(spike_counts) > 0: 
    586586                return float(total_spikes)/len(spike_counts) 
     
    610610        result[:,0]   = cells 
    611611        result[:,1:4] = self.positions.T  
    612         if control.rank() == 0: 
     612        if self._simulator.state.mpi_rank == 0: 
    613613            file.write(result, {'population' : self.label}) 
    614614            file.close() 
     
    707707        (order in the Population), counting only cells on the local MPI node. 
    708708        """ 
    709         if control.num_processes() > 1: 
     709        if self._simulator.state.num_processes > 1: 
    710710            return self.local_cells.tolist().index(id)          # probably very slow 
    711711            #return numpy.nonzero(self.local_cells == id)[0][0] # possibly faster? 
     
    10981098            pindices = boundaries[1:].searchsorted(indices, side='right') 
    10991099            views = (self.populations[i][indices[pindices==i] - boundaries[i]] for i in numpy.unique(pindices)) 
    1100             return Assembly(*views) 
     1100            return self.__class__(*views) 
    11011101        else: 
    11021102            raise TypeError("indices must be integers, slices, lists, arrays, not %s" % type(index).__name__) 
     
    11101110        """ 
    11111111        if isinstance(other, BasePopulation): 
    1112             return Assembly(*(self.populations + [other])) 
     1112            return self.__class__(*(self.populations + [other])) 
    11131113        elif isinstance(other, Assembly): 
    1114             return Assembly(*(self.populations + other.populations)) 
     1114            return self.__class__(*(self.populations + other.populations)) 
    11151115        else: 
    11161116            raise TypeError("can only add a Population or another Assembly to an Assembly") 
     
    12031203        result[:,0]   = cells 
    12041204        result[:,1:4] = self.positions.T  
    1205         if control.rank() == 0: 
     1205        if self._simulator.state.mpi_rank == 0: 
    12061206            file.write(result, {'assembly' : self.label}) 
    12071207            file.close() 
     
    12491249        spike_counts = self.get_spike_counts() 
    12501250        total_spikes = sum(spike_counts.values()) 
    1251         if control.rank() == 0 or not gather:  # should maybe use allgather, and get the numbers on all nodes 
     1251        if self._simulator.state.mpi_rank() == 0 or not gather:  # should maybe use allgather, and get the numbers on all nodes 
    12521252            return float(total_spikes)/len(spike_counts) 
    12531253        else: 
     
    13011301                    'last_id'     : self.last_id} 
    13021302         
    1303         metadata['dt'] = control.simulator.state.dt # note that this has to run on all nodes (at least for NEST) 
     1303        metadata['dt'] = self._simulator.state.dt # note that this has to run on all nodes (at least for NEST) 
    13041304        data = numpy.zeros(format) 
    13051305        for pop in filenames.keys(): 
    13061306            if filenames[pop][1] is True: 
    13071307                name     = filenames[pop][0] 
    1308                 if gather==False and control.simulator.state.num_processes > 1: 
    1309                     name += '.%d' % control.simulator.state.mpi_rank             
     1308                if gather==False and self._simulator.state.num_processes > 1: 
     1309                    name += '.%d' % self._simulator.state.mpi_rank             
    13101310                p_file   = files.NumpyBinaryFile(name, mode='r')  
    13111311                tmp_data = p_file.read()                     
     
    13181318         
    13191319        if isinstance(file, basestring): 
    1320             if gather==False and control.simulator.state.num_processes > 1: 
    1321                 file += '.%d' % control.simulator.state.mpi_rank 
     1320            if gather==False and self._simulator.state.num_processes > 1: 
     1321                file += '.%d' % self._simulator.state.mpi_rank 
    13221322            file = files.StandardTextFile(file, mode='w') 
    13231323         
    1324         if control.simulator.state.mpi_rank == 0 or gather == False: 
     1324        if self._simulator.state.mpi_rank == 0 or gather == False: 
    13251325            file.write(data, metadata) 
    13261326            file.close() 
  • trunk/src/common/projections.py

    r999 r1000  
    1313from pyNN.recording import files 
    1414from populations import BasePopulation, Assembly, is_conductance 
    15 from control import get_min_delay, get_max_delay, rank, num_processes 
    1615 
    1716logger = logging.getLogger("PyNN") 
     
    4645 
    4746 
    48 def check_delay(delay): 
    49     if delay is None: 
    50         delay = get_min_delay() 
    51     # If the delay is too small , we have to throw an error 
    52     if delay < get_min_delay() or delay > get_max_delay(): 
    53         raise errors.ConnectionError("delay (%s) is out of range [%s,%s]" % \ 
    54                                      (delay, get_min_delay(), get_max_delay())) 
    55     return delay 
    56  
    57  
    5847class Projection(object): 
    5948    """ 
     
    301290            all_lines = { rank(): lines } 
    302291            all_lines = recording.gather_dict(all_lines) 
    303             if rank() == 0: 
     292            if self._simulator.state.mpi_rank == 0: 
    304293                lines = reduce(operator.add, all_lines.values()) 
    305         elif num_processes() > 1: 
    306             file.rename('%s.%d' % (file.name, rank())) 
     294        elif self._simulator.state.num_processes > 1: 
     295            file.rename('%s.%d' % (file.name, self._simulator.state.mpi_rank)) 
    307296         
    308297        logger.debug("--- Projection[%s].__saveConnections__() ---" % self.label) 
  • trunk/src/connectors.py

    r999 r1000  
    178178    """Generator for synaptic delays. %s""" % ConnectionAttributeGenerator.__doc__ 
    179179 
    180     def __init__(self, source, local_mask, safe=True): 
     180    def __init__(self, source, local_mask, kernel, safe=True): 
    181181        ConnectionAttributeGenerator.__init__(self, source, local_mask, safe) 
    182         self.min_delay = common.get_min_delay() 
    183         self.max_delay = common.get_max_delay() 
     182        assert hasattr(kernel, "min_delay") 
     183        self.kernel = kernel 
    184184         
    185185    def check(self, delay): 
    186         all_negative = (delay<=self.max_delay).all() 
    187         all_positive = (delay>=self.min_delay).all()# If the delay is too small , we have to throw an error 
     186        min_delay = self.kernel.min_delay 
     187        max_delay = self.kernel.max_delay 
     188        all_negative = (delay<=max_delay).all() 
     189        all_positive = (delay>=min_delay).all()# If the delay is too small , we have to throw an error 
    188190        if not (all_negative and all_positive): 
    189             raise errors.ConnectionError("delay (%s) is out of range [%s,%s]" % (delay, common.get_min_delay(), common.get_max_delay())) 
     191            raise errors.ConnectionError("delay (%s) is out of range [%s,%s]" % (delay, min_delay, max_delay)) 
    190192        return delay     
    191193 
     
    232234        self.safe    = safe 
    233235        self.verbose = verbose 
    234         min_delay    = common.get_min_delay() 
     236        min_delay    = self._simulator.state.min_delay 
    235237        if delays is None: 
    236238            self.delays = min_delay 
     
    252254    def progression(self, count): 
    253255        self.prog.update_amount(count) 
    254         if self.verbose and common.rank() == 0:            
     256        if self.verbose and self._simulator.state.mpi_rank == 0:            
    255257            print self.prog, "\r", 
    256258            sys.stdout.flush() 
     
    293295        self.N                 = projection.post.size 
    294296        self.weights_generator = WeightGenerator(weights, self.local, projection, safe) 
    295         self.delays_generator  = DelayGenerator(delays, self.local, safe) 
     297        self.delays_generator  = DelayGenerator(delays, self.local, kernel=projection._simulator.state, safe=safe) 
    296298        self.probas_generator  = ProbaGenerator(RandomDistribution('uniform', (0,1), rng=self.rng), self.local) 
    297299        self._distance_matrix  = None 
     
    468470        proba_generator = ProbaGenerator(self.d_expression, connector.local) 
    469471        self.progressbar(len(projection.pre)) 
    470         if (common.num_processes() > 1) and (self.n_connections is not None): 
     472        if (projection._simulator.state.num_processes > 1) and (self.n_connections is not None): 
    471473            raise Exception("n_connections not implemented yet for this connector in parallel !") 
    472474 
     
    498500        """ 
    499501        # needs extending for dynamic synapses. 
    500         Connector.__init__(self, 0., common.get_min_delay(), safe=safe, verbose=verbose) 
     502        Connector.__init__(self, 0.0, self._simulator.state.min_delay, safe=safe, verbose=verbose) 
    501503        self.conn_list  = numpy.array(conn_list)                
    502504         
     
    550552                         distributed simulations. 
    551553        """ 
    552         Connector.__init__(self, 0., common.get_min_delay(), safe=safe, verbose=verbose) 
     554        Connector.__init__(self, 0.0, self._simulator.state.min_delay, safe=safe, verbose=verbose) 
    553555         
    554556        if isinstance(file, basestring): 
     
    560562        """Connect-up a Projection.""" 
    561563        if self.distributed: 
    562             self.file.rename("%s.%d" % (self.file.name, common.rank()))         
     564            self.file.rename("%s.%d" % (self.file.name, projection._simulator.state.mpi_rank))         
    563565        self.conn_list = self.file.read() 
    564566        FromListConnector.connect(self, projection) 
     
    614616        local             = numpy.ones(len(projection.post), bool) 
    615617        weights_generator = WeightGenerator(self.weights, local, projection, self.safe) 
    616         delays_generator  = DelayGenerator(self.delays, local, self.safe) 
     618        delays_generator  = DelayGenerator(self.delays, local, kernel=projection._simulator.state, safe=self.safe) 
    617619        distance_matrix   = DistanceMatrix(projection.post.positions, self.space) 
    618620        candidates        = projection.post.all_cells 
     
    700702        local             = numpy.ones(len(projection.pre), bool) 
    701703        weights_generator = WeightGenerator(self.weights, local, projection, self.safe) 
    702         delays_generator  = DelayGenerator(self.delays, local, self.safe) 
     704        delays_generator  = DelayGenerator(self.delays, local, kernel=projection._simulator.state, safe=self.safe) 
    703705        distance_matrix   = DistanceMatrix(projection.pre.positions, self.space)               
    704706        candidates        = projection.pre.all_cells  
     
    767769                raise Exception('Expression for weights or delays is not supported for OneToOneConnector !') 
    768770            weights_generator = WeightGenerator(self.weights, local, projection, self.safe) 
    769             delays_generator  = DelayGenerator(self.delays, local, self.safe)                 
     771            delays_generator  = DelayGenerator(self.delays, local, kernel=projection._simulator.state, safe=self.safe)                 
    770772            weights           = weights_generator.get(N) 
    771773            delays            = delays_generator.get(N) 
     
    871873            self.rng = projection.rng 
    872874        self.weights_generator = WeightGenerator(self.weights, local, projection, self.safe) 
    873         self.delays_generator  = DelayGenerator(self.delays, local, self.safe) 
     875        self.delays_generator  = DelayGenerator(self.delays, local, kernel=projection._simulator.state, safe=self.safe) 
    874876        self.probas_generator  = ProbaGenerator(RandomDistribution('uniform',(0,1), rng=self.rng), local) 
    875877        self.distance_matrix   = DistanceMatrix(projection.post.positions, self.space, local) 
     
    893895            """ 
    894896            """ 
    895             min_delay = common.get_min_delay() 
     897            min_delay = self._simulator.state.min_delay 
    896898            Connector.__init__(self, None, None, safe=safe, verbose=verbose) 
    897899            self.cset = cset 
     
    904906                self.delays = delays 
    905907                if delays is None: 
    906                     self.delays = common.get_min_delay() 
     908                    self.delays = self._simulator.state.min_delay 
    907909            else: 
    908910                assert cset.arity == 2, 'must specify mask or connection-set with arity 2' 
  • trunk/src/moose/__init__.py

    r999 r1000  
    1717from pyNN.moose import simulator 
    1818from pyNN import common, recording, core 
    19 common.control.simulator = simulator 
    20 recording.simulator = simulator 
    2119 
    2220from pyNN.connectors import FixedProbabilityConnector, AllToAllConnector, OneToOneConnector 
     
    5856    return get_current_time() 
    5957 
    60 reset = common.reset 
     58reset = common.control.build_reset(simulator) 
    6159 
    6260initialize = common.initialize 
     
    6664# ============================================================================== 
    6765 
    68 get_current_time = common.get_current_time 
    69 get_time_step = common.get_time_step 
    70 get_min_delay = common.get_min_delay 
    71 get_max_delay = common.get_max_delay 
    72 num_processes = common.num_processes 
    73 rank = common.rank 
     66get_current_time, get_time_step, get_min_delay, get_max_delay, \ 
     67            num_processes, rank = common.control.build_state_queries(simulator) 
    7468 
    7569# ============================================================================== 
     
    7872# ============================================================================== 
    7973 
     74class Assembly(common.Assembly): 
     75    _simulator = simulator 
     76 
     77 
     78class PopulationView(common.PopulationView): 
     79    _simulator = simulator 
     80    assembly_class = Assembly 
     81 
     82    def _get_view(self, selector, label=None): 
     83        return PopulationView(self, selector, label) 
     84 
     85 
    8086class Population(common.Population): 
    8187    """ 
     
    8389    term intended to include layers, columns, nuclei, etc., of cells. 
    8490    """ 
     91    _simulator = simulator 
    8592    recorder_class = Recorder 
     93    assembly_class = Assembly 
     94 
     95    def _get_view(self, selector, label=None): 
     96        return PopulationView(self, selector, label) 
    8697 
    8798    def _create_cells(self, cellclass, cellparams, n): 
     
    118129    parameters of those connections, including of plasticity mechanisms. 
    119130    """ 
    120  
     131    _simulator = simulator 
    121132    nProj = 0 
    122133 
  • trunk/src/multisim.py

    r999 r1000  
    88 
    99from multiprocessing import Process, Queue 
    10 from pyNN import common, recording 
    1110 
    1211def run_simulation(network_model, sim, parameters, input_queue, output_queue): 
     
    1716    """ 
    1817    print "Running simulation with %s" % sim.__name__ 
    19     common.control.simulator = sim.simulator 
    20     recording.simulator = sim.simulator 
    2118    network = network_model(sim, parameters) 
    2219    print "Network constructed with %s." % sim.__name__ 
  • trunk/src/nemo/__init__.py

    r999 r1000  
    1414from pyNN.nemo import simulator 
    1515from pyNN import common, recording, space, core, __doc__ 
    16 common.control.simulator = simulator 
    17 recording.simulator = simulator 
    1816from pyNN.random import * 
    1917from pyNN.recording import files 
     
    5351    simulator.spikes_array_list = [] 
    5452    simulator.recorder_lise     = [] 
    55     return rank() 
     53    return simulator.state.mpi_rank 
    5654 
    5755def end(compatible_output=True): 
     
    6462    electrodes.current_sources = [] 
    6563 
    66 def get_current_time(): 
    67     """Return the current time in the simulation.""" 
    68     return simulator.state.t 
    6964     
    7065def run(simtime):     
    7166    """Run the simulation for simtime ms.""" 
    7267    simulator.state.run(simtime) 
    73     return get_current_time() 
     68    return simulator.state.t 
    7469 
    7570reset      = simulator.reset 
     
    8075# ============================================================================== 
    8176 
    82 get_time_step = common.get_time_step 
    83 get_min_delay = common.get_min_delay 
    84 get_max_delay = common.get_max_delay 
    85 num_processes = common.num_processes 
    86 rank = common.rank 
     77get_current_time, get_time_step, get_min_delay, get_max_delay, \ 
     78            num_processes, rank = common.control.build_state_queries(simulator) 
    8779 
    8880# ============================================================================== 
     
    9183# ============================================================================== 
    9284 
     85class Assembly(common.Assembly): 
     86    _simulator = simulator 
     87 
     88 
     89class PopulationView(common.PopulationView): 
     90    _simulator = simulator 
     91    assembly_class = Assembly 
     92     
     93    def _get_view(self, selector, label=None): 
     94        return PopulationView(self, selector, label) 
     95 
     96 
    9397class Population(common.Population, common.BasePopulation): 
    9498    """ 
     
    96100    term intended to include layers, columns, nuclei, etc., of cells. 
    97101    """ 
     102    _simulator = simulator 
    98103    recorder_class = Recorder 
     104    assembly_class = Assembly 
     105 
     106    def _get_view(self, selector, label=None): 
     107        return PopulationView(self, selector, label) 
    99108 
    100109    def _create_cells(self, cellclass, cellparams=None, n=1): 
     
    135144    def _set_initial_value_array(self, variable, value): 
    136145        if not hasattr(value, "__len__"): 
    137             value = value*numpy.ones((len(self),))         
    138         
    139 PopulationView = common.PopulationView 
    140 Assembly = common.Assembly 
     146            value = value*numpy.ones((len(self),)) 
     147 
    141148 
    142149class Projection(common.Projection): 
     
    146153    parameters of those connections, including of plasticity mechanisms. 
    147154    """ 
     155    _simulator = simulator 
    148156     
    149157    def __init__(self, presynaptic_population, postsynaptic_population, method, 
  • trunk/src/nemo/connectors.py

    r957 r1000  
    2424                            WeightGenerator, \ 
    2525                            DelayGenerator, \ 
    26                             ProbaGenerator 
     26                            ProbaGenerator, \ 
     27                            Connector 
     28from pyNN.nemo import simulator 
     29 
     30Connector._simulator = simulator 
  • trunk/src/nemo/recording.py

    r957 r1000  
    1616class Recorder(recording.Recorder): 
    1717    """Encapsulates data and functions related to recording model variables.""" 
     18    _simulator = simulator 
    1819   
    1920    def __init__(self, variable, population=None, file=None): 
  • trunk/src/nest/__init__.py

    r999 r1000  
    1111from pyNN.nest import simulator 
    1212from pyNN import common, recording, errors, space, __doc__ 
    13 common.control.simulator = simulator 
    14 recording.simulator = simulator 
    1513 
    1614if recording.MPI and (nest.Rank() != recording.mpi_comm.rank): 
     
    154152    return get_current_time() 
    155153 
    156 reset = common.reset 
     154reset = common.control.build_reset(simulator) 
    157155 
    158156initialize = common.initialize 
     
    162160# ============================================================================== 
    163161 
    164 get_current_time = common.get_current_time 
    165 get_time_step = common.get_time_step 
    166 get_min_delay = common.get_min_delay 
    167 get_max_delay = common.get_max_delay 
    168 num_processes = common.num_processes 
    169 rank = common.rank 
     162get_current_time, get_time_step, get_min_delay, get_max_delay, \ 
     163            num_processes, rank = common.control.build_state_queries(simulator) 
     164 
    170165 
    171166# ============================================================================== 
     
    173168#   neurons. 
    174169# ============================================================================== 
     170 
     171class Assembly(common.Assembly): 
     172    _simulator = simulator 
     173 
     174 
     175class PopulationView(common.PopulationView): 
     176    _simulator = simulator 
     177    assembly_class = Assembly 
     178 
     179    def _get_view(self, selector, label=None): 
     180        return PopulationView(self, selector, label) 
     181 
    175182 
    176183class Population(common.Population): 
     
    179186    term intended to include layers, columns, nuclei, etc., of cells. 
    180187    """ 
     188    _simulator = simulator 
    181189    recorder_class = Recorder 
     190    assembly_class = Assembly 
     191 
     192    def _get_view(self, selector, label=None): 
     193        return PopulationView(self, selector, label) 
    182194 
    183195    def _create_cells(self, cellclass, cellparams, n): 
     
    280292        nest.SetStatus(self.local_cells.tolist(), variable, value) 
    281293 
    282 PopulationView = common.PopulationView 
    283 Assembly = common.Assembly 
    284  
    285294 
    286295class Projection(common.Projection): 
     
    290299    parameters of those connections, including of plasticity mechanisms. 
    291300    """ 
    292  
     301    _simulator = simulator 
    293302    nProj = 0 
    294303 
  • trunk/src/nest/connectors.py

    r998 r1000  
    1313                            FromListConnector, FromFileConnector, WeightGenerator, \ 
    1414                            DelayGenerator, ProbaGenerator, DistanceMatrix, CSAConnector 
    15 from pyNN.common import rank, num_processes 
     15from pyNN.nest import simulator 
    1616import numpy 
    1717from pyNN.space import Space 
     18 
     19Connector._simulator = simulator 
    1820 
    1921 
     
    3537        self.local_long[idx]   = True 
    3638        self.weights_generator = WeightGenerator(weights, self.local_long, projection, safe) 
    37         self.delays_generator  = DelayGenerator(delays, self.local_long, safe) 
     39        self.delays_generator  = DelayGenerator(delays, self.local_long, kernel=projection._simulator.state, safe=safe) 
    3840        self.probas_generator  = ProbaGenerator(random.RandomDistribution('uniform',(0,1), rng=self.rng), self.local_long) 
    3941        self.distance_matrix   = DistanceMatrix(projection.pre.positions, self.space, self.local) 
  • trunk/src/nest/recording.py

    r991 r1000  
    55import warnings 
    66import nest 
    7 from pyNN import recording, common, errors 
     7from pyNN import recording, errors 
    88from pyNN.nest import simulator 
    99 
     
    3131        device_parameters = {"withgid": True, "withtime": True} 
    3232        if self.type is 'multimeter': 
    33             device_parameters["interval"] = common.get_time_step() 
     33            device_parameters["interval"] = simulator.state.dt 
    3434        else: 
    3535            device_parameters["precise_times"] = True 
     
    237237class Recorder(recording.Recorder): 
    238238    """Encapsulates data and functions related to recording model variables.""" 
    239      
     239    _simulator = simulator 
    240240    scale_factors = {'spikes': 1, 
    241241                     'v': 1, 
  • trunk/src/neuron/__init__.py

    r999 r1000  
    1212from pyNN.random import * 
    1313from pyNN.neuron import simulator 
    14 from pyNN import common, core, recording as base_recording, space, __doc__ 
    15 common.control.simulator = simulator 
    16 base_recording.simulator = simulator 
     14from pyNN import common, core, space, __doc__ 
    1715 
    1816from pyNN.neuron.standardmodels.cells import * 
     
    8381    return get_current_time() 
    8482     
    85 reset = common.reset 
     83reset = common.control.build_reset(simulator) 
    8684 
    8785initialize = common.initialize 
     
    9189# ============================================================================== 
    9290 
    93 get_current_time = common.get_current_time 
    94 get_time_step = common.get_time_step 
    95 get_min_delay = common.get_min_delay 
    96 get_max_delay = common.get_max_delay 
    97 num_processes = common.num_processes 
    98 rank = common.rank 
     91get_current_time, get_time_step, get_min_delay, get_max_delay, \ 
     92            num_processes, rank = common.control.build_state_queries(simulator) 
    9993 
    10094 
     
    10498# ============================================================================== 
    10599 
     100class Assembly(common.Assembly): 
     101    _simulator = simulator 
     102 
     103 
     104class PopulationView(common.PopulationView): 
     105    _simulator = simulator 
     106    assembly_class = Assembly 
     107     
     108    def _get_view(self, selector, label=None): 
     109        return PopulationView(self, selector, label) 
     110     
     111 
    106112class Population(common.Population): 
    107113    """ 
     
    109115    term intended to include layers, columns, nuclei, etc., of cells. 
    110116    """ 
     117    _simulator = simulator 
    111118    recorder_class = Recorder 
     119    assembly_class = Assembly 
    112120     
    113121    def __init__(self, size, cellclass, cellparams=None, structure=None, 
     
    116124        common.Population.__init__(self, size, cellclass, cellparams, structure, label) 
    117125        simulator.initializer.register(self) 
     126 
     127    def _get_view(self, selector, label=None): 
     128        return PopulationView(self, selector, label) 
    118129 
    119130    def _create_cells(self, cellclass, cellparams, n): 
     
    157168 
    158169 
    159 PopulationView = common.PopulationView 
    160 Assembly = common.Assembly 
    161  
    162170class Projection(common.Projection): 
    163171    """ 
     
    166174    parameters of those connections, including of plasticity mechanisms. 
    167175    """ 
    168      
     176    _simulator = simulator 
    169177    nProj = 0 
    170178     
  • trunk/src/neuron/connectors.py

    r957 r1000  
    88""" 
    99 
     10from pyNN.neuron import simulator 
    1011from pyNN.connectors import AllToAllConnector, \ 
    1112                            OneToOneConnector, \ 
     
    1718                            FixedNumberPostConnector, \ 
    1819                            SmallWorldConnector, \ 
    19                             CSAConnector 
     20                            CSAConnector, \ 
     21                            Connector 
    2022 
    21  
    22  
    23  
     23Connector._simulator = simulator 
  • trunk/src/neuron/recording.py

    r957 r1000  
    1717class Recorder(recording.Recorder): 
    1818    """Encapsulates data and functions related to recording model variables.""" 
    19          
     19    _simulator = simulator 
     20     
    2021    def _record(self, new_ids): 
    2122        """Add the cells in `new_ids` to the set of recorded cells.""" 
  • trunk/src/pcsim/__init__.py

    r999 r1000  
    2222from pyNN import common, recording, errors, space, core, __doc__ 
    2323from pyNN.pcsim import simulator 
    24 common.control.simulator = simulator 
    25 recording.simulator = simulator 
    2624import os.path 
    2725import types 
     
    206204    return simulator.state.t 
    207205 
    208 reset = common.reset 
     206reset = common.control.build_reset(simulator) 
    209207 
    210208initialize = common.initialize 
    211209 
    212 get_current_time = common.get_current_time 
    213 get_time_step = common.get_time_step 
    214 get_min_delay = common.get_min_delay 
    215 get_max_delay = common.get_max_delay 
    216 num_processes = common.num_processes 
    217 rank = common.rank 
     210get_current_time, get_time_step, get_min_delay, get_max_delay, \ 
     211            num_processes, rank = common.control.build_state_queries(simulator) 
    218212 
    219213 
     
    222216#   neurons. 
    223217# ============================================================================== 
     218 
     219class Assembly(common.Assembly): 
     220    _simulator = simulator 
     221     
     222     
     223class PopulationView(common.PopulationView): 
     224    _simulator = simulator 
     225    assembly_class = Assembly 
     226     
     227    def _get_view(self, selector, label=None): 
     228        return PopulationView(self, selector, label) 
     229 
    224230 
    225231class Population(common.Population): 
     
    228234    term intended to include layers, columns, nuclei, etc., of cells. 
    229235    """ 
     236    _simulator = simulator 
    230237    recorder_class = Recorder 
     238    assembly_class = Assembly 
    231239     
    232240    def __init__(self, size, cellclass, cellparams=None, structure=None, 
     
    234242        __doc__ = common.Population.__doc__ 
    235243        common.Population.__init__(self, size, cellclass, cellparams, structure, label) 
     244     
     245    def _get_view(self, selector, label=None): 
     246        return PopulationView(self, selector, label) 
    236247     
    237248    def _create_cells(self, cellclass, cellparams, n): 
     
    406417            obj = simulator.net.object(self.pcsim_population[i]) 
    407418            if obj: apply( obj, methodname, (), arguments) 
    408          
    409 PopulationView = common.PopulationView 
    410 Assembly = common.Assembly 
     419 
    411420 
    412421class Projection(common.Projection, WDManager): 
     
    416425    parameters of those connections, including of plasticity mechanisms. 
    417426    """ 
    418      
     427    _simulator = simulator 
    419428    nProj = 0 
    420429    synapse_target_ids = { 'excitatory': 1, 'inhibitory': 2 } 
  • trunk/src/pcsim/connectors.py

    r957 r1000  
    1616                            FixedNumberPreConnector, \ 
    1717                            FixedNumberPostConnector, \ 
    18                             SmallWorldConnector 
     18                            SmallWorldConnector, \ 
     19                            Connector 
     20from pyNN.pcsim import simulator 
    1921 
    2022 
    21  
    22  
     23Connector._simulator = simulator 
  • trunk/src/pcsim/recording.py

    r957 r1000  
    1313class Recorder(recording.Recorder): 
    1414    """Encapsulates data and functions related to recording model variables.""" 
     15    _simulator = simulator 
    1516     
    1617    fieldnames = {'v': 'Vm', 
  • trunk/src/recording/__init__.py

    r957 r1000  
    144144            filename = file 
    145145            #rename_existing(filename) 
    146             if gather==False and simulator.state.num_processes > 1: 
    147                 filename += '.%d' % simulator.state.mpi_rank 
     146            if gather==False and self._simulator.state.num_processes > 1: 
     147                filename += '.%d' % self._simulator.state.mpi_rank 
    148148        else: 
    149149            filename = file.name 
     
    155155        metadata = self.metadata 
    156156        logger.debug("data has size %s" % str(data.size)) 
    157         if simulator.state.mpi_rank == 0 or gather == False: 
     157        if self._simulator.state.mpi_rank == 0 or gather == False: 
    158158            if compatible_output: 
    159159                data = self._make_compatible(data) 
     
    178178                'label': self.population.label, 
    179179            }) 
    180         metadata['dt'] = simulator.state.dt # note that this has to run on all nodes (at least for NEST) 
     180        metadata['dt'] = self._simulator.state.dt # note that this has to run on all nodes (at least for NEST) 
    181181        if not hasattr(self, '_data_size'): 
    182182            self.get() 
     
    229229        else: 
    230230            raise Exception("Only implemented for spikes.") 
    231         if gather and simulator.state.num_processes > 1: 
     231        if gather and self._simulator.state.num_processes > 1: 
    232232            N = gather_dict(N) 
    233233        return N 
  • trunk/test/system/scenarios.py

    r998 r1000  
    55import numpy 
    66from pyNN.utility import init_logging, assert_arrays_equal, assert_arrays_almost_equal, sort_by_column 
    7  
    8  
    9 def set_simulator(sim): 
    10     common.simulator = sim.simulator 
    11     recording.simulator = sim.simulator 
    127 
    138 
     
    2823    Balanced network of integrate-and-fire neurons. 
    2924    """ 
    30     set_simulator(sim) 
    3125    cell_params = { 
    3226        'tau_m': 20.0, 'tau_syn_E': 5.0, 'tau_syn_I': 10.0, 'v_rest': -60.0, 
     
    9589    API. 
    9690    """ 
    97     set_simulator(sim) 
    9891    cell_params = { 
    9992        'tau_m': 10.0, 'tau_syn_E': 2.0, 'tau_syn_I': 5.0, 'v_rest': -60.0, 
     
    155148    once (except neuron[0], which never reaches threshold). 
    156149    """ 
    157     set_simulator(sim) 
    158150    n = 100 
    159151    t_start = 25.0 
     
    203195    connections should be potentiated more. 
    204196    """ 
    205     set_simulator(sim) 
    206197 
    207198    init_logging(logfile=None, debug=True) 
     
    301292        pylab.rcParams['interactive'] = interactive 
    302293     
    303     set_simulator(sim) 
    304294    sim.setup(timestep=dt) 
    305295     
     
    410400@register(exclude=['pcsim', 'moose']) 
    411401def test_EIF_cond_alpha_isfa_ista(sim): 
    412     set_simulator(sim) 
    413402    sim.setup(timestep=0.01, min_delay=0.1, max_delay=4.0) 
    414403    ifcell = sim.create(sim.EIF_cond_alpha_isfa_ista, 
     
    452441    from pyNN.utility import init_logging 
    453442    init_logging(logfile=None, debug=True) 
    454     set_simulator(sim) 
    455443    dt = 0.1 
    456444    tstop = 100.0 
  • trunk/test/unittests/test_assembly.py

    r999 r1000  
    77     
    88 
     9class MockSimulator(object): 
     10    class MockState(object): 
     11        mpi_rank = 1 
     12        num_processes = 2 
     13    state = MockState() 
     14 
    915def test_create_with_zero_populations(): 
    1016    a = Assembly() 
     
    1319 
    1420class MockPopulation(BasePopulation): 
     21    _simulator = MockSimulator 
    1522    size = 10 
    16     local_cells = numpy.arange(1,10,2) 
    17     all_cells = numpy.arange(10) 
    18     _mask_local = numpy.arange(10)%2 == 1 
     23    local_cells = numpy.arange(_simulator.state.mpi_rank,10,_simulator.state.num_processes) 
     24    all_cells = numpy.arange(size) 
     25    _mask_local = numpy.arange(size)%_simulator.state.num_processes == _simulator.state.mpi_rank 
    1926    initialize = Mock() 
    2027    positions = numpy.arange(3*size).reshape(3,size) 
     
    179186def test_save_positions(): 
    180187    import os 
    181     orig_rank = common.rank 
    182     common.rank = lambda: 0 
     188    Assembly._simulator = MockSimulator 
     189    Assembly._simulator.state.mpi_rank = 0 
    183190    p1 = MockPopulation() 
    184191    p2 = MockPopulation() 
     
    194201    assert_equal(output_file.write.call_args[0][1], {'assembly': a.label}) 
    195202    # arguably, the first column should contain indices, not ids. 
    196     common.rank = orig_rank 
     203    del Assembly._simulator 
  • trunk/test/unittests/test_basepopulation.py

    r999 r1000  
    11from pyNN import common, errors, random, standardmodels, recording 
    2 from pyNN.common import populations, control 
     2from pyNN.common import populations 
    33from nose.tools import assert_equal, assert_raises 
    44import numpy 
     
    99builtin_open = open 
    1010id_map = {'larry': 0, 'curly': 1, 'moe': 2, 'joe': 3, 'william': 4, 'jack': 5, 'averell': 6} 
     11 
     12 
     13class MockSimulator(object): 
     14    class MockState(object): 
     15        mpi_rank = 1 
     16        num_processes = 3 
     17    state = MockState() 
    1118 
    1219class MockStandardCell(standardmodels.StandardCellType): 
     
    1926 
    2027class MockPopulation(populations.BasePopulation): 
     28    _simulator = MockSimulator 
    2129    size = 13 
    2230    all_cells = numpy.arange(100, 113) 
     
    2735    celltype = MockStandardCell({}) 
    2836    initial_values = {"foo": core.LazyArray(numpy.array((98, 100, 102)), shape=(3,))} 
     37    assembly_class = populations.Assembly 
     38 
     39    def _get_view(self, selector, label=None): 
     40        return populations.PopulationView(self, selector, label) 
    2941 
    3042    def id_to_index(self, id): 
     
    6173    p = MockPopulation() 
    6274    pv = p[3:9] 
    63     populations.PopulationView.assert_called_with(p, slice(3,9,None)) 
     75    populations.PopulationView.assert_called_with(p, slice(3,9,None), None) 
    6476    populations.PopulationView = orig_PV 
    6577 
     
    6981    p = MockPopulation() 
    7082    pv = p[range(3,9)] 
    71     populations.PopulationView.assert_called_with(p, range(3,9)) 
     83    populations.PopulationView.assert_called_with(p, range(3,9), None) 
    7284    populations.PopulationView = orig_PV 
    7385 
     
    7789    p = MockPopulation() 
    7890    pv = p[(3,5,7)] 
    79     populations.PopulationView.assert_called_with(p, [3,5,7]) 
     91    populations.PopulationView.assert_called_with(p, [3,5,7], None) 
    8092    populations.PopulationView = orig_PV 
    8193 
     
    175187 
    176188def test_get_with_gather(): 
    177     np_orig = control.num_processes 
    178     rank_orig = control.rank 
     189    np_orig = MockPopulation._simulator.state.num_processes 
     190    rank_orig = MockPopulation._simulator.state.mpi_rank 
    179191    gd_orig = recording.gather_dict 
    180     control.num_processes = lambda: 2 
    181     control.rank = lambda: 0 
     192    MockPopulation._simulator.state.num_processes = 2 
     193    MockPopulation._simulator.state.mpi_rank = 0 
    182194    def mock_gather_dict(D): # really hacky 
    183195        assert isinstance(D[0], list) 
     
    191203                        numpy.arange(10.0, 23.0)) 
    192204     
    193     control.num_processes = np_orig 
    194     control.rank = rank_orig 
     205    MockPopulation._simulator.state.num_processes = np_orig 
     206    MockPopulation._simulator.state.mpi_rank = rank_orig 
    195207    recording.gather_dict = gd_orig 
    196208 
     
    437449     
    438450def test_meanSpikeCount(): 
    439     orig_rank = control.rank 
    440     control.rank = lambda: 0 
     451    orig_rank = MockPopulation._simulator.state.mpi_rank 
     452    MockPopulation._simulator.state.mpi_rank = 0 
    441453    p = MockPopulation() 
    442454    p.recorders = {'spikes': Mock()} 
    443455    p.recorders['spikes'].count = Mock(return_value={0: 2, 1: 5}) 
    444456    assert_equal(p.meanSpikeCount(), 3.5) 
    445     control.rank = orig_rank 
     457    MockPopulation._simulator.state.mpi_rank = orig_rank 
    446458 
    447459def test_meanSpikeCount_on_slave_node(): 
    448     orig_rank = control.rank 
    449     control.rank = lambda: 1 
     460    orig_rank = MockPopulation._simulator.state.mpi_rank 
     461    MockPopulation._simulator.state.mpi_rank = 1 
    450462    p = MockPopulation() 
    451463    p.recorders = {'spikes': Mock()} 
    452464    p.recorders['spikes'].count = Mock(return_value={0: 2, 1: 5}) 
    453465    assert p.meanSpikeCount() is numpy.NaN 
    454     control.rank = orig_rank 
     466    MockPopulation._simulator.state.mpi_rank = orig_rank 
    455467     
    456468def test_inject(): 
     
    469481def test_save_positions(): 
    470482    import os 
    471     orig_rank = control.rank 
    472     control.rank = lambda: 0 
     483    orig_rank = MockPopulation._simulator.state.mpi_rank 
     484    MockPopulation._simulator.state.mpi_rank = 0 
    473485    p = MockPopulation() 
    474486    p.all_cells = numpy.array([34, 45, 56, 67]) 
     
    480492    assert_equal(output_file.write.call_args[0][1], {'population': p.label}) 
    481493    # arguably, the first column should contain indices, not ids. 
    482     control.rank = orig_rank 
     494    MockPopulation._simulator.state.mpi_rank = orig_rank 
  • trunk/test/unittests/test_connectors.py

    r999 r1000  
    6969 
    7070class MockProjection(object): 
     71    _simulator = MockSimulator 
    7172     
    7273    def __init__(self, pre, post): 
     
    102103 
    103104 
     105def setup(): 
     106    connectors.Connector._simulator = MockSimulator 
     107    connectors.ConnectionAttributeGenerator._simulator = MockSimulator 
     108 
     109def teardown(): 
     110    del connectors.Connector._simulator 
     111    del connectors.ConnectionAttributeGenerator._simulator 
     112 
    104113class TestOneToOneConnector(object): 
    105114 
    106115    def setup(self): 
    107         common.control.simulator = MockSimulator 
    108116        self.prj = MockProjection(MockPre(5), MockPost(numpy.array([0,1,0,1,0], dtype=bool))) 
    109117 
     
    129137 
    130138    def setup(self): 
    131         common.control.simulator = MockSimulator 
    132139        self.prj = MockProjection(MockPre(4), MockPost(numpy.array([0,1,0,1,0], dtype=bool))) 
    133140 
     
    182189        C = connectors.AllToAllConnector(weights=0.1, delays=None) 
    183190        assert_equal(C.weights, 0.1) 
    184         assert_equal(C.delays, common.get_min_delay()) 
     191        assert_equal(C.delays, C._simulator.state.min_delay) 
    185192        assert C.safe 
    186193        assert C.allow_self_connections 
     
    202209 
    203210    def setup(self): 
    204         common.control.simulator = MockSimulator 
    205211        self.prj = MockProjection(MockPre(4), 
    206212                                  MockPost(numpy.array([0,1,0,1,0], dtype=bool))) 
     
    222228 
    223229    def setup(self): 
    224         common.control.simulator = MockSimulator 
    225230        self.prj = MockProjection(MockPre(4), 
    226231                                  MockPost(numpy.array([0,1,0,1,0], dtype=bool))) 
     
    243248     
    244249    def setup(self): 
    245         common.control.simulator = MockSimulator 
    246250        self.prj = MockProjection(MockPre(4), 
    247251                                  MockPost(numpy.array([0,1,0,1,0], dtype=bool))) 
     
    282286     
    283287    def setup(self): 
    284         common.control.simulator = MockSimulator 
    285288        self.prj = MockProjection(MockPre(4), 
    286289                                  MockPost(numpy.array([0,1,0,1,0], dtype=bool))) 
     
    321324     
    322325    def setup(self): 
    323         common.control.simulator = MockSimulator 
    324326        self.prj = MockProjection(MockPre(4), 
    325327                                  MockPost(numpy.array([0,1,0,1,0], dtype=bool))) 
     
    352354     
    353355    def setup(self): 
    354         common.control.simulator = MockSimulator 
    355356        self.prj = MockProjection(MockPre(4), 
    356357                                  MockPost(numpy.array([0,1,0,1,0], dtype=bool))) 
  • trunk/test/unittests/test_population.py

    r999 r1000  
    11from pyNN import errors, random, standardmodels, space 
    2 from pyNN.common import control, populations 
     2from pyNN.common import populations 
    33from nose.tools import assert_equal, assert_raises 
    44import numpy 
     
    66from pyNN.utility import assert_arrays_equal 
    77 
     8 
     9class MockSimulator(object): 
     10    class MockState(object): 
     11        mpi_rank = 1 
     12        num_processes = 3 
     13    state = MockState() 
    814 
    915class MockID(int, populations.IDMixin): 
     
    1521 
    1622class MockPopulation(populations.Population): 
     23    _simulator = MockSimulator 
    1724    recorder_class = Mock() 
    1825    initialize = Mock() 
     26     
     27    def _get_view(self, selector, label=None): 
     28        return populations.PopulationView(self, selector, label) 
    1929     
    2030    def _create_cells(self, cellclass, cellparams, size): 
     
    107117 
    108118def test_id_to_local_index(): 
    109     orig_np = control.num_processes 
    110     control.num_processes = lambda: 5 
     119    orig_np = MockPopulation._simulator.state.num_processes 
     120    MockPopulation._simulator.state.num_processes = 5 
    111121    p = MockPopulation(11, MockStandardCell) 
    112122    # every 5th cell, starting with the 4th, is on this node. 
     
    114124    assert_equal(p.id_to_local_index(p[8]), 1) 
    115125     
    116     control.num_processes = lambda: 1 
     126    MockPopulation._simulator.state.num_processes = 1 
    117127    # only one node 
    118128    assert_equal(p.id_to_local_index(p[3]), 3) 
    119129    assert_equal(p.id_to_local_index(p[8]), 8) 
    120     control.num_processes = orig_np 
     130    MockPopulation._simulator.state.num_processes = orig_np 
    121131 
    122132def test_id_to_local_index_with_invalid_id(): 
    123     orig_np = control.num_processes 
    124     control.num_processes = lambda: 5 
     133    orig_np = MockPopulation._simulator.state.num_processes 
     134    MockPopulation._simulator.state.num_processes = 5 
    125135    p = MockPopulation(11, MockStandardCell) 
    126136    # every 5th cell, starting with the 4th, is on this node. 
    127137    assert_raises(ValueError, p.id_to_local_index, p[0]) 
    128     control.num_processes = orig_np 
     138    MockPopulation._simulator.state.num_processes = orig_np 
    129139 
    130140# test structure property 
  • trunk/test/unittests/test_projection.py

    r999 r1000  
    77from pyNN.utility import assert_arrays_equal 
    88 
    9 orig_rank = common.rank 
    10 orig_np = common.num_processes 
     9 
     10class MockSimulator(object): 
     11    class MockState(object): 
     12        mpi_rank = 1 
     13        num_processes = 3 
     14    state = MockState() 
     15 
    1116 
    1217def setup(): 
    13     common.rank = lambda: 1 
    14     common.num_processes = lambda: 3 
     18    common.Projection._simulator = MockSimulator 
     19 
    1520 
    1621def teardown(): 
    17     common.rank = orig_rank 
    18     common.num_processes = orig_np 
     22    del common.Projection._simulator 
     23     
    1924 
    2025class MockStandardCell(standardmodels.StandardCellType): 
  • trunk/test/unittests/test_recording.py

    r922 r1000  
    99if MPI: 
    1010    mpi_comm = recording.mpi_comm 
     11 
     12def setup(): 
     13    recording.Recorder._simulator = MockSimulator(mpi_rank=0) 
     14     
     15def teardown(): 
     16    del recording.Recorder._simulator 
    1117 
    1218#def test_rename_existing(): 
     
    116122 
    117123def test_write__with_filename__compatible_output__gather__onroot(): 
    118     recording.simulator = MockSimulator(mpi_rank=0) 
    119124    orig_metadata = recording.Recorder.metadata 
    120125    recording.Recorder.metadata = {'a': 2, 'b':3} 
  • trunk/test/unittests/test_simulation_control.py

    r999 r1000  
    3232         
    3333def test_end(): 
    34     assert_raises(NotImplementedError, common.end) 
     34    assert_raises(NotImplementedError, common.control.end) 
    3535         
    3636def test_run(): 
    37     assert_raises(NotImplementedError, common.run, 10.0) 
     37    assert_raises(NotImplementedError, common.control.run, 10.0) 
    3838                
    3939def test_reset(): 
    40     common.control.simulator = MockSimulator() 
    41     common.reset() 
    42     assert common.control.simulator.reset_called 
     40    simulator = MockSimulator() 
     41    reset = common.control.build_reset(simulator) 
     42    reset() 
     43    assert simulator.reset_called 
    4344     
    4445def test_initialize(): 
     
    4647    common.initialize(p, 'v', -65.0) 
    4748    assert p.initializations == [('v', -65.0)] 
    48      
     49 
     50 
    4951def test_current_time(): 
    50     common.control.simulator = MockSimulator() 
    51     common.get_current_time() 
    52     assert_equal(common.control.simulator.state.accesses, ['t']) 
     52    simulator = MockSimulator() 
     53    get_current_time, get_time_step, get_min_delay, get_max_delay, num_processes, rank = common.control.build_state_queries(simulator) 
     54    get_current_time() 
     55    assert_equal(simulator.state.accesses, ['t']) 
    5356     
    5457def test_time_step(): 
    55     common.control.simulator = MockSimulator() 
    56     common.get_time_step() 
    57     assert_equal(common.control.simulator.state.accesses, ['dt']) 
     58    simulator = MockSimulator() 
     59    get_current_time, get_time_step, get_min_delay, get_max_delay, num_processes, rank = common.control.build_state_queries(simulator) 
     60    get_time_step() 
     61    assert_equal(simulator.state.accesses, ['dt']) 
    5862     
    5963def test_min_delay(): 
    60     common.control.simulator = MockSimulator() 
    61     common.get_min_delay() 
    62     assert_equal(common.control.simulator.state.accesses, ['min_delay']) 
     64    simulator = MockSimulator() 
     65    get_current_time, get_time_step, get_min_delay, get_max_delay, num_processes, rank = common.control.build_state_queries(simulator) 
     66    get_min_delay() 
     67    assert_equal(simulator.state.accesses, ['min_delay']) 
    6368 
    6469def test_max_delay(): 
    65     common.control.simulator = MockSimulator() 
    66     common.get_max_delay() 
    67     assert_equal(common.control.simulator.state.accesses, ['max_delay']) 
     70    simulator = MockSimulator() 
     71    get_current_time, get_time_step, get_min_delay, get_max_delay, num_processes, rank = common.control.build_state_queries(simulator) 
     72    get_max_delay() 
     73    assert_equal(simulator.state.accesses, ['max_delay']) 
    6874     
    6975def test_num_processes(): 
    70     common.control.simulator = MockSimulator() 
    71     common.num_processes() 
    72     assert_equal(common.control.simulator.state.accesses, ['num_processes']) 
     76    simulator = MockSimulator() 
     77    get_current_time, get_time_step, get_min_delay, get_max_delay, num_processes, rank = common.control.build_state_queries(simulator) 
     78    num_processes() 
     79    assert_equal(simulator.state.accesses, ['num_processes']) 
    7380     
    7481def test_rank(): 
    75     common.control.simulator = MockSimulator() 
    76     common.rank() 
    77     assert_equal(common.control.simulator.state.accesses, ['mpi_rank']) 
     82    simulator = MockSimulator() 
     83    get_current_time, get_time_step, get_min_delay, get_max_delay, num_processes, rank = common.control.build_state_queries(simulator) 
     84    rank() 
     85    assert_equal(simulator.state.accesses, ['mpi_rank']) 
  • trunk/test/unittests/test_utility_functions.py

    r999 r1000  
    7575    # need to check that a log message was created 
    7676    assert_equal(4.3, common.check_weight(4.3, 'excitatory', is_conductance=None)) 
    77      
    78 def test_check_delay(): 
    79     assert_equal(common.check_delay(None), MIN_DELAY) 
    80     assert_equal(common.check_delay(2*MIN_DELAY), 2*MIN_DELAY) 
    81     assert_raises(errors.ConnectionError, common.check_delay, 0.5*MIN_DELAY) 
    82     assert_raises(errors.ConnectionError, common.check_delay, 2.0*MAX_DELAY)