Changeset 453

Show
Ignore:
Timestamp:
08/23/10 14:25:44 (18 months ago)
Author:
mpereira
Message:

Added crosscorrelate() to analysis.py

Location:
trunk
Files:
1 added
3 modified

Legend:

Unmodified
Added
Removed
  • trunk/src/analysis.py

    r451 r453  
    2222import os, numpy 
    2323from NeuroTools import check_dependency 
    24  
    2524 
    2625def arrays_almost_equal(a, b, threshold): 
     
    4544 
    4645    """ 
    47     assert(x.ndim == y.ndim, "Inconsistent shape !") 
     46    assert x.ndim == y.ndim, "Inconsistent shape !" 
    4847#    assert(x.shape == y.shape, "Inconsistent shape !") 
    4948    if axis is None: 
     
    7978    return iFxy/varxy 
    8079 
     80from NeuroTools.plotting import get_display, set_labels 
     81 
     82HAVE_PYLAB = check_dependency('pylab') 
     83 
     84 
     85def crosscorrelate(sua1, sua2, lag=None, n_pred=1, predictor=None, 
     86                   display=False, kwargs={}): 
     87    """ 
     88    Calculates the cross-correlation between two vectors containing event times. 
     89    Returns (int, int_, norm). See below for details. 
     90     
     91    Adapted from original script written by Martin P. Nawrot for the FIND MATLAB 
     92    toolbox. 
     93    See FIND - a unified framework for neural data analysis, 
     94        Meier R, Egert U, Aertsen A, Nawrot MP; Neural Netw. 2008 Oct; 
     95        21(8):1085-93.  
     96      
     97    Inputs: 
     98        sua1      - array of event times. Can be either a column/row vector. 
     99        sua2      - array of event times. Can be either a column/row vector. 
     100                    If sua2 == sua1 the result is the 
     101                    autocorrelogram. 
     102        lag       - the max. lag for which relative event timing is considered 
     103                    with a max. difference of +/- lag. A default lag is computed 
     104                    from the inter-event interval of the longer of the two sua 
     105                    arrays  
     106        n_pred    - number of surrogate compilations for the predictor. This 
     107                    influences the total length of the predictor output array 
     108                    int_ 
     109        predictor - string array determines the type of bootstrap predictor to 
     110                    be used: 
     111                        shuffle - shuffles inter-event intervals of the longer 
     112                                  input array and calculates relative 
     113                                  differences with the shorter input array. 
     114                                  n_pred determines the number of repeated 
     115                                  shufflings, resulting differences are pooled 
     116                                  from all repeated shufflings 
     117        display   - if True the corresponding plots will be displayed. If False, 
     118                    int, int_ and norm will be returned. 
     119                    when display = True and n_pred > 1, the averaged predictor 
     120                    will be plotted. 
     121        kwargs    - arguments to be passed to numpy.histogram. 
     122 
     123    Outputs: 
     124        int  - accumulated differences of events in sua1 minus the events in 
     125               sua2. Thus positive values of int relate to events of sua2 that 
     126               lead events of sua1. Units are the same as the input arrays. 
     127        int_ - predictior: accumulated differences based on the prediction 
     128               method. The length of int_ is n_pred * length(int).  Units are 
     129               the same as the input arrays. 
     130        norm - normalization factor used to scale the bin heights in int and 
     131               int_. int/norm and int_/norm correspond to the linear 
     132               correlation coefficient. 
     133     
     134    Examples: 
     135        >> crosscorrelate(numpy_array1, numpy_array2) 
     136        >> crosscorrelate(spike_train1.spike_times, spike_train2.spike_times) 
     137        >> crosscorrelate(spike_train1.spike_times, spike_train2.spike_times, 
     138                          lag = 150.0) 
     139        >> crosscorrelate(spike_train1.spike_times, spike_train2.spike_times, 
     140                          display=True, kwargs={'bins':100}) 
     141             
     142    See also: 
     143        ccf 
     144    """     
     145    assert predictor is 'shuffle' or predictor is None, "predictor must be \ 
     146    either None or 'shuffle'. Other predictors are not yet implemented." 
     147     
     148    #Check whether sua1 and sua2 are SpikeTrains or arrays 
     149    sua = [] 
     150    for x in (sua1, sua2): 
     151        if x.ndim == 1: 
     152            sua.append(x) 
     153        elif x.ndim == 2 and (x.shape[0] == 1 or x.shape[1] == 1): 
     154            sua.append(x.ravel()) 
     155        else: 
     156            raise TypeError("sua1 and sua2 must be either instances of the \ 
     157                            SpikeTrain class or column/row vectors") 
     158    sua1 = sua[0] 
     159    sua2 = sua[1] 
     160     
     161    if sua1.size < sua2.size: 
     162        if lag is None: 
     163            lag = numpy.ceil(10*numpy.mean(numpy.diff(sua1))) 
     164        reverse = False 
     165    else: 
     166        if lag is None: 
     167            lag = numpy.ceil(20*numpy.mean(numpy.diff(sua2))) 
     168        sua1, sua2 = sua2, sua1 
     169        reverse = True 
     170             
     171    #construct predictor 
     172    if predictor is 'shuffle': 
     173        isi = numpy.diff(sua2) 
     174        sua2_ = numpy.array([]) 
     175        for ni in xrange(1,n_pred+1): 
     176            idx = numpy.random.permutation(isi.size-1) 
     177             
     178            sua2_ = numpy.append(sua2_, numpy.add(numpy.insert( 
     179                (numpy.cumsum(isi[idx])), 0, 0), sua2.min() + ( 
     180                numpy.random.exponential(isi.mean())))) 
     181             
     182    #calculate cross differences in spike times 
     183    int = numpy.array([]) 
     184    int_ = numpy.array([]) 
     185    for k in xrange(0, sua1.size): 
     186        int = numpy.append(int, sua1[k] - sua2[numpy.nonzero( 
     187            (sua2 > sua1[k] - lag) & (sua2 < sua1[k] + lag))]) 
     188    if predictor == 'shuffle': 
     189        for k in xrange(0, sua1.size): 
     190            int_ = numpy.append(int_, sua1[k] - sua2_[numpy.nonzero( 
     191                (sua2_ > sua1[k] - lag) & (sua2_ < sua1[k] + lag))]) 
     192     
     193    if reverse is True: 
     194        int = -int 
     195        int_ = -int_ 
     196         
     197    norm = numpy.sqrt(sua1.size * sua2.size) 
     198     
     199    # Plot the results if display=True 
     200    subplot = get_display(display) 
     201    if not subplot or not HAVE_PYLAB: 
     202        return int, int_, norm 
     203    else: 
     204        # Plot the cross-correlation 
     205        counts, bin_edges = numpy.histogram(int, **kwargs) 
     206        counts = counts / norm 
     207        xlabel = "Time" 
     208        ylabel = "Cross-correlation coefficient" 
     209        #NOTE: the x axis corresponds to the upper edge of each bin 
     210        subplot.plot(bin_edges[1:], counts, label='cross-correlation') 
     211        if predictor is None: 
     212            set_labels(subplot, xlabel, ylabel) 
     213            pylab.draw() 
     214        elif predictor is 'shuffle':             
     215            # Plot the predictor 
     216            norm_ = norm * n_pred 
     217            counts_, bin_edges_ = numpy.histogram(int_, **kwargs) 
     218            counts_ = counts_ / norm_ 
     219            subplot.plot(bin_edges_[1:], counts_, label='predictor') 
     220            subplot.legend() 
     221            pylab.draw() 
    81222 
    82223def _dict_max(D): 
  • trunk/src/signals/spikes.py

    r446 r453  
    2828import os, re, numpy 
    2929import logging 
    30 from NeuroTools import check_dependency, check_numpy_version, analysis 
     30from NeuroTools import check_dependency, check_numpy_version 
     31try: 
     32    from NeuroTools import analysis 
     33except ImportError: 
     34    pass 
     35 
    3136from NeuroTools.io import * 
    3237from NeuroTools.plotting import get_display, set_axis_limits, set_labels, SimpleMultiplot, progress_bar 
  • trunk/test/test_analysis.py

    r451 r453  
    77from numpy import pi, sin 
    88 
    9 from NeuroTools import analysis 
    10  
     9from NeuroTools import analysis, signals 
    1110 
    1211# test simple_frequency_spectrum 
     
    2726        self.exp_reversed = scipy.io.loadmat( 
    2827            'analysis/make_kernel/exp_reversed.mat') 
     28         
     29        #Used for the testcases of the crosscorrelate function 
     30        spk = signals.load_spikelist('analysis/crosscorrelate/spike_data') 
     31        self.spk0 = spk[0].spike_times 
     32        self.spk1 = spk[1].spike_times 
    2933 
    3034    def testSimpleFrequencySpectrum(self): 
     
    153157        self.assertEqual(true_m_idx, m_idx, 3) 
    154158         
     159    def testCrosscorrelateNoLag(self): 
     160            int, int_, norm = analysis.crosscorrelate(self.spk0, self.spk1) 
     161            #The following are output was generated with the FIND MATLAB toolbox 
     162            matlab_int = numpy.loadtxt('analysis/crosscorrelate/out_matlab_int') 
     163            numpy.testing.assert_array_almost_equal(int, matlab_int, 
     164                                                    decimal=3) 
     165            #The int_ output has a random component and for this reason the test 
     166            # cases are not as trivial 
     167             
     168    def testCrosscorrelateLag100(self): 
     169        """Test case with lag within the length of the input array 
     170        """ 
     171        int, int_, norm, = analysis.crosscorrelate(self.spk0, self.spk1, 
     172                                                   lag=100.0) 
     173        matlab_int = numpy.loadtxt('analysis/crosscorrelate/out_matlab_int_lag_100') 
     174        numpy.testing.assert_array_almost_equal(int, matlab_int, decimal = 3) 
     175             
     176    def testCrosscorrelateLag500(self): 
     177        """Test case with lag is higher than the trial length 
     178        """ 
     179        int, int_, norm = analysis.crosscorrelate(self.spk0, self.spk1, 
     180                                                  lag=500.0) 
     181        matlab_int = numpy.loadtxt('analysis/crosscorrelate/out_matlab_int_lag_500') 
     182        numpy.testing.assert_array_almost_equal(int, matlab_int, decimal = 3) 
     183             
    155184if __name__ == "__main__": 
    156185    unittest.main()