Changeset 779

Show
Ignore:
Timestamp:
08/04/10 00:04:30 (18 months ago)
Author:
apdavison
Message:

Wrote some more Connector unit tests.

Location:
trunk
Files:
2 modified

Legend:

Unmodified
Added
Removed
  • trunk/src/connectors.py

    r778 r779  
    177177 
    178178class DistanceMatrix(object): 
     179    # should probably move to space module 
    179180     
    180181    def __init__(self, B, space, mask=None): 
  • trunk/test/unittests/connectortests.py

    r701 r779  
    11from mpi4py import MPI 
    22import numpy 
    3 from pyNN.random import NumpyRNG, RandomDistribution 
    4 from pyNN.connectors2 import FixedProbabilityConnector, AllToAllConnector 
    5 from pyNN import common 
     3from pyNN.random import AbstractRNG, RandomDistribution 
     4from pyNN.connectors import FixedProbabilityConnector, AllToAllConnector, DistanceMatrix, \ 
     5                            ConnectionAttributeGenerator 
     6from pyNN import common, errors 
     7from pyNN.space import Space 
     8import unittest 
    69 
    710mpi_comm = MPI.COMM_WORLD 
     11 
     12def assert_arrays_almost_equal(a, b, threshold, msg=''): 
     13    if a.shape != b.shape: 
     14        raise unittest.TestCase.failureException("Shape mismatch: a.shape=%s, b.shape=%s" % (a.shape, b.shape)) 
     15    if not (abs(a-b) < threshold).all(): 
     16        err_msg = "%s != %s" % (a, b) 
     17        err_msg += "\nlargest difference = %g" % abs(a-b).max() 
     18        if msg: 
     19            err_msg += "\nOther information: %s" % msg 
     20        raise unittest.TestCase.failureException(err_msg) 
    821 
    922class MockSimulatorModule(object): 
    1023    class State(object): 
    1124        min_delay = 0.1 
     25        max_delay = 10.0 
    1226    state = State() 
    1327 
    1428common.simulator = MockSimulatorModule() 
    1529 
     30class MockRNG(AbstractRNG): 
     31     
     32    def __init__(self, seed): 
     33        self.current = seed 
     34         
     35    def next(self, n=1, distribution='uniform', parameters=[], mask_local=None): 
     36        start = self.current 
     37        self.current += n 
     38        return numpy.arange(start, self.current)     
     39 
     40class MockConnectionManager(object): 
     41        def __init__(self): 
     42            self.n = 0 
     43            self.weights = [] 
     44            self.delays = [] 
     45             
     46        def connect(self, src, targets, weights, delays): 
     47            assert len(targets) == len(weights) == len(delays) > 0, "%d %d %d" % (len(targets), len(weights), len(delays)) 
     48            self.n += len(targets) 
     49            self.weights.extend(weights) 
     50            self.delays.extend(delays) 
     51             
    1652class MockProjection(object): 
    1753     
    18     class MockConnectionManager(object): 
    19         def connect(self, src, targets, weights, delays): 
    20             assert len(targets) == len(weights) == len(delays), "%d %d %d" % (len(targets), len(weights), len(delays)) 
    21     connection_manager = MockConnectionManager() 
    22      
    23     def __init__(self, pre, post, rng): 
     54    def __init__(self, pre, post, rng, synapse_type): 
    2455        self.pre = pre 
    2556        self.post = post 
    2657        self.rng = rng 
     58        self.synapse_type = synapse_type 
     59        self.connection_manager = MockConnectionManager() 
     60 
     61    @property 
     62    def n(self): 
     63        return self.connection_manager.n 
    2764 
    2865class MockID(int): 
     66     
     67    def __init__(self, id): 
     68        self.id = id 
    2969     
    3070    @property 
    3171    def position(self): 
    32         return numpy.array([float(self), 0.0, 0.0]) 
    33          
     72        return self.parent.positions.T[self.id]         
    3473 
    3574class MockPopulation(object): 
     
    3776    def __init__(self, n): 
    3877        self.all_cells = numpy.array([MockID(i) for i in range(n)], dtype=MockID) 
     78        for i in range(n): 
     79            self.all_cells[i].parent = self 
    3980        self.positions = numpy.array([(i, 0.0, 0.0) for i in self.all_cells], dtype=float).T 
    4081        self._mask_local = numpy.array([i%mpi_comm.size == mpi_comm.rank for i in range(n)]) 
     
    4889        return self.all_cells.size 
    4990     
    50 p1 = MockPopulation(100) 
    51 p2 = MockPopulation(100) 
    52  
    53 rng = NumpyRNG(8569552) 
    54  
    55 weight_sources = [0.1, 
    56                   RandomDistribution('uniform', (0,1), rng), 
    57                   numpy.arange(0.0, 1.0, 1e-4).reshape(100,100), 
    58                   "exp(-(d*d)/1e4)"] 
    59  
    60 for weight_source in weight_sources: 
    61     connector = FixedProbabilityConnector(p_connect=0.3, 
    62                                           allow_self_connections=True, 
    63                                           weights=weight_source, 
    64                                           delays=0.2) 
    65     prj = MockProjection(p1, p2, rng) 
    66     connector.connect(prj) 
    67     connector = AllToAllConnector(allow_self_connections=True, 
    68                                   weights=weight_source, 
    69                                   delays=0.2) 
    70     prj = MockProjection(p1, p2, rng) 
    71     connector.connect(prj) 
     91    def __len__(self): 
     92        return self.all_cells.size 
     93         
     94 
     95class AllToAllConnectorTest(unittest.TestCase): 
     96     
     97    def setUp(self): 
     98        self.p1 = MockPopulation(17) 
     99        self.p2 = MockPopulation(13) 
     100        self.rng = MockRNG(0) 
     101         
     102     
     103    def test_create_with_delays_None(self): 
     104        connector = AllToAllConnector(weights=0.1, delays=None) 
     105        self.assertEqual(connector.weights, 0.1) 
     106        self.assertEqual(connector.delays, common.get_min_delay()) 
     107        self.assert_(connector.safe) 
     108        self.assert_(connector.allow_self_connections) 
     109         
     110    def test_create_with_delays_too_small(self): 
     111        self.assertRaises(errors.ConnectionError, 
     112                          AllToAllConnector, allow_self_connections=True, 
     113                          delays=0.0) 
     114 
     115    def test_create_with_list_delays_too_small(self): 
     116        self.assertRaises(errors.ConnectionError, 
     117                          AllToAllConnector, allow_self_connections=True, 
     118                          delays=[1.0, 1.0, 0.0]) 
     119 
     120    def test_connect_with_single_weight(self): 
     121        connector = AllToAllConnector(allow_self_connections=True, 
     122                                      weights=0.1) 
     123        prj = MockProjection(self.p1, self.p2, self.rng, "excitatory") 
     124        connector.connect(prj) 
     125        self.assertEqual(prj.n, self.p1.size*self.p2.size) 
     126        self.assertEqual(prj.connection_manager.weights, [[0.1]]*prj.n) 
     127         
     128    def test_connect_with_no_self_connections(self): 
     129        connector = AllToAllConnector(allow_self_connections=False, 
     130                                      weights=0.1) 
     131        prj = MockProjection(self.p1, self.p1, self.rng, "excitatory") 
     132        connector.connect(prj) 
     133        self.assertEqual(prj.n, self.p1.size*(self.p1.size-1)) 
     134        self.assertEqual(prj.connection_manager.weights, [[0.1]]*prj.n) 
     135         
     136    def test_connect_with_random_weights(self): 
     137        connector = AllToAllConnector(weights=RandomDistribution("uniform", [0.3, 0.4], rng=self.rng)) 
     138        prj = MockProjection(self.p1, self.p2, self.rng, "excitatory") 
     139        connector.connect(prj) 
     140        self.assertEqual(prj.n, self.p1.size*self.p2.size) 
     141        assert_arrays_almost_equal(numpy.array(prj.connection_manager.weights), numpy.arange(0, prj.n), 1e-6) 
     142 
     143    def test_connect_with_weight_array(self): 
     144        w_in = numpy.arange(self.p1.size*self.p2.size, 0.0, -1.0).reshape(self.p1.size, self.p2.size) 
     145        connector = AllToAllConnector(weights=w_in) 
     146        prj = MockProjection(self.p1, self.p2, self.rng, "excitatory") 
     147        connector.connect(prj) 
     148        self.assertEqual(prj.n, self.p1.size*self.p2.size) 
     149        assert_arrays_almost_equal(numpy.array(prj.connection_manager.weights), w_in.flatten(), 1e-6) 
     150 
     151    def test_connect_with_distance_dependent_weights_simple(self): 
     152        self.p1.positions = numpy.zeros((3,self.p1.size)) 
     153        connector = AllToAllConnector(weights="d*d") 
     154        prj = MockProjection(self.p1, self.p2, self.rng, "excitatory") 
     155        connector.connect(prj) 
     156        self.assertEqual(prj.n, self.p1.size*self.p2.size) 
     157        w_expected = numpy.array([w*w for w in range(self.p2.size)]*self.p1.size, float).flatten() 
     158        assert_arrays_almost_equal(numpy.array(prj.connection_manager.weights), 
     159                                   w_expected, 1e-6) 
     160 
     161    def test_connect_with_distance_dependent_weights(self): 
     162        # 3D position values, not just zeros and ones 
     163        self.fail() 
     164 
     165    def test_connect_with_distance_dependent_weights_and_delays(self): 
     166        # check how many times DistanceMatrix.as_array() gets called 
     167        self.fail() 
     168 
     169 
     170class FixedProbabilityConnectorTest(unittest.TestCase): 
     171 
     172    def test(self): 
     173        self.fail() 
     174         
     175class DistanceDependentProbabilityConnectorTest(unittest.TestCase): 
     176 
     177    def test(self): 
     178        self.fail() 
     179 
     180class FixedNumberPreConnectorTest(unittest.TestCase): 
     181 
     182    def test(self): 
     183        self.fail() 
     184         
     185class FixedNumberPostConnectorTest(unittest.TestCase): 
     186 
     187    def test(self): 
     188        self.fail() 
     189 
     190class OneToOneConnectorTest(unittest.TestCase): 
     191 
     192    def test(self): 
     193        self.fail() 
     194 
     195class SmallWorldConnectorTest(unittest.TestCase): 
     196 
     197    def test(self): 
     198        self.fail() 
     199 
     200class CSAConnectorTest(unittest.TestCase): 
     201 
     202    def test(self): 
     203        self.fail() 
     204 
     205class FromListConnectorTest(unittest.TestCase): 
     206 
     207    def test(self): 
     208        self.fail() 
     209 
     210class FromFileConnectorTest(unittest.TestCase): 
     211 
     212    def test(self): 
     213        self.fail() 
     214 
     215class TestDistanceMatrix(unittest.TestCase): 
     216     
     217    def test_really_simple0(self): 
     218        A = numpy.zeros((3,)) 
     219        B = numpy.zeros((3,5)) 
     220        D = DistanceMatrix(B, Space()) 
     221        D.set_source(A) 
     222        assert_arrays_almost_equal(D.as_array(), 
     223                                   numpy.zeros((5,), float), 
     224                                   1e-12) 
     225 
     226    def test_really_simple1(self): 
     227        A = numpy.ones((3,)) 
     228        B = numpy.zeros((3,5)) 
     229        D = DistanceMatrix(B, Space()) 
     230        D.set_source(A) 
     231        assert_arrays_almost_equal(D.as_array(), 
     232                                   numpy.sqrt(3*numpy.ones((5,), float)), 
     233                                   1e-12) 
     234 
     235 
     236class TestConnectionAttributeGenerator(unittest.TestCase): 
     237     
     238    def test_extract_with_simple_d_expr(self): 
     239        B = numpy.zeros((3,5)) 
     240        dist = DistanceMatrix(B, Space()) 
     241        gen = ConnectionAttributeGenerator( 
     242                  source="d*d", 
     243                  local_mask=None, 
     244                  safe=False) 
     245        dist.set_source(numpy.zeros((3,))) 
     246        assert_arrays_almost_equal(gen.extract(5, dist, sub_mask=None), 
     247                                   numpy.zeros((5,), float), 
     248                                   1e-12) 
     249        dist.set_source(numpy.array([1,0,0])) 
     250        assert_arrays_almost_equal(gen.extract(5, dist, sub_mask=None), 
     251                                   numpy.ones((5,), float), 
     252                                   1e-12) 
     253        dist.set_source(numpy.array([2,0,0])) 
     254        assert_arrays_almost_equal(gen.extract(5, dist, sub_mask=None), 
     255                                   4*numpy.ones((5,), float), 
     256                                   1e-12) 
     257 
     258     
     259# ============================================================================== 
     260if __name__ == "__main__": 
     261    unittest.main()