Changeset 431

Show
Ignore:
Timestamp:
08/26/09 15:10:23 (3 years ago)
Author:
pierre
Message:

Add the interval notion to the AnalogSignal? object, now debugged. Start to fix the AnalogSignalList? objects...

Location:
branches/interval/src/signals
Files:
3 modified

Legend:

Unmodified
Added
Removed
  • branches/interval/src/signals/analogs.py

    r430 r431  
    3131""" 
    3232 
    33 import os, re, numpy 
     33import os, re, numpy, copy 
    3434from NeuroTools import check_dependency, check_numpy_version, analysis 
    3535from NeuroTools.io import * 
     
    7878    def __init__(self, signal, dt, t_start=0, t_stop=None, interval=None): 
    7979        #logging.debug("Creating AnalogSignal. len(signal)=%d, dt=%g, t_start=%g, t_stop=%s" % (len(signal), dt, t_start, t_stop)) 
    80         self.signal  = signal 
    81         self.dt      = float(dt) 
     80        self.signal   = signal 
     81        self.dt       = float(dt) 
    8282        self.interval = self.create_intervals(t_start, t_stop, interval) 
    8383        self.init_times() 
     
    8888        if self.t_start >= self.t_stop: 
    8989            raise Exception("Incompatible time interval for the creation of the AnalogSignal. t_start=%s, t_stop=%s" % (self.t_start, self.t_stop)) 
     90 
     91        self.times  = self.interval.time_axis() 
     92        idx         = self.interval.idx_slice_times(self.times) 
     93        self.signal = self.signal[idx] 
    9094 
    9195    def create_intervals(self, t_start=None, t_stop=None, interval=None): 
     
    9599            # interval is fully defined by the user 
    96100            if type(interval) == Interval: 
    97                 interval_out = interval.copy() 
     101                interval_out = interval.copy() 
    98102            else : 
    99                 interval_out = Interval(interval)  
     103                interval_out = Interval(interval) 
    100104        else: 
    101105            if t_start is None and t_stop is None : 
     
    116120                    else : 
    117121                        interval_out = Interval(t_start) 
     122         
    118123        return interval_out 
    119124 
     
    135140        Return the duration of the SpikeTrain 
    136141        """ 
    137         return self.t_stop - self.t_start 
     142        return self.interval.total_duration() 
    138143 
    139144    def __str__(self): 
     
    156161        Return a copy of the AnalogSignal object 
    157162        """ 
    158         return AnalogSignal(self.signal, self.dt, self.t_start, self.t_stop) 
     163        return copy.deepcopy(self) 
    159164 
    160165    def time_axis(self, normalized=False): 
     
    169174            norm = 0. 
    170175 
    171         time_pavement = numpy.arange(self.t_start-norm, self.t_stop-norm, self.dt) 
    172         return SpikeTrain(time_pavement, interval=self.interval).spike_times 
     176        time_pavement = numpy.arange(self.t_start, self.t_stop, self.dt) - norm 
     177        return self.interval.slice_times(time_pavement) 
    173178     
    174179    def time_offset(self, offset): 
     
    239244        assert t_stop <= self.t_stop 
    240245        assert t_stop > t_start 
    241          
    242         t = self.time_axis() 
    243         i_start = int((t_start-self.t_start)/self.dt) 
    244         i_stop = int((t_stop-self.t_start)/self.dt) 
    245         signal = self.signal[i_start:i_stop] 
    246  
    247         return AnalogSignal(signal, self.dt, t_start, t_stop) 
     246        interval = Interval((t_start, t_stop)) 
     247        return self.interval_slice(interval) 
    248248 
    249249    def interval_slice(self, interval): 
     
    261261            time_slice 
    262262        """ 
    263         signal_list = [] 
    264  
    265         if type(interval) is not Interval : 
    266             interval = Interval(interval) 
    267  
    268         for itv in interval.interval_data : 
    269             signal_list.append(self.signal[(itv[0]-self.t_start)/self.dt:(itv[1]-self.t_start)/self.dt]) 
    270  
    271         return AnalogSignal(signal_list, self.dt, interval) 
    272  
    273     def threshold_detection(self, threshold=None, format=None,sign='above'): 
     263        interval = interval.intersect(self.interval) 
     264        idx      = interval.idx_slice_times(self.times) 
     265        return AnalogSignal(self.signal[idx], self.dt, interval=interval) 
     266 
     267    def threshold_detection(self, threshold=None, format=None, sign='above'): 
    274268        """ 
    275269        Returns the times when the analog signal crosses a threshold. 
     
    353347        # recalculate everything into timesteps, is more stable against rounding errors 
    354348        # and subsequent cutouts with different sizes 
    355         events = numpy.floor(numpy.array(events)/self.dt) 
    356         t_min_l = numpy.floor(t_min/self.dt) 
    357         t_max_l = numpy.floor(t_max/self.dt) 
    358         t_start = numpy.floor(self.t_start/self.dt) 
    359         t_stop = numpy.floor(self.t_stop/self.dt) 
    360          
    361         for spike in events: 
    362             if ((spike-t_min_l) >= t_start) and ((spike+t_max_l) < t_stop): 
    363                 spike = spike - t_start 
    364                 if average: 
    365                     result += self.signal[(spike-t_min_l):(spike+t_max_l)] 
    366                 else: 
    367                     result.append(self.signal[(spike-t_min_l):(spike+t_max_l)]) 
    368                 Nspikes += 1 
     349        t_min_l  = numpy.floor(t_min/self.dt) 
     350        t_max_l  = numpy.floor(t_max/self.dt) 
     351        result   = numpy.zeros((t_min_l+t_max_l), numpy.float32) 
     352        events_interval = copy.copy(self.interval) 
     353        events_interval.offset_start(t_min, colapse=True) 
     354        events_interval.offset_stop(-t_max, colapse=True) 
     355        events = events_interval.slice_times(events) 
     356        assert len(events) > 0, "the PSTH windows [event-t_min, event+t_max] should be included within the spike train interval" 
     357        for ev in events: 
     358            ev = numpy.floor((ev - self.interval.t_start())/self.dt) 
     359            if average: 
     360                result += self.signal[(spike-t_min_l):(spike+t_max_l)] 
     361            else: 
     362                result.append(self.signal[(spike-t_min_l):(spike+t_max_l)]) 
     363            Nspikes += 1 
    369364        if average: 
    370365            result = result/Nspikes 
     
    415410        assert (t_min >= 0) and (t_max >= 0), "t_min and t_max should be greater than 0" 
    416411        assert len(events) > 0, "events should not be empty and should contained at least one element" 
    417          
     412        events_interval = copy.copy(self.interval) 
     413        events_interval.offset_start(t_min, colapse=True) 
     414        events_interval.offset_stop(-t_max, colapse=True) 
     415        events = events_interval.slice_times(events) 
     416        assert len(events) > 0, "the PSTH windows [event-t_min, event+t_max] should be included within the spike train interval" 
    418417        result = {} 
    419         for index, spike in enumerate(events): 
    420             if ((spike-t_min) >= self.t_start) and ((spike+t_max) < self.t_stop): 
    421                 spike = spike - self.t_start 
    422                 t_start_new = (spike-t_min) 
    423                 t_stop_new = (spike+t_max) 
    424                 result[index] = self.time_slice(t_start_new, t_stop_new) 
     418        for index, ev in enumerate(events): 
     419            t_start_new = (ev-t_min) 
     420            t_stop_new  = (ev+t_max) 
     421            result[index] = self.time_slice(t_start_new, t_stop_new) 
    425422        return result 
    426          
    427              
     423 
     424 
     425 
    428426class AnalogSignalList(object): 
    429427    """ 
     
    537535        if len(self) > 0: 
    538536            errmsgs = [] 
    539             print numpy.shape(val.signal) 
    540537            val = val.interval_slice(self.interval) 
    541             print numpy.shape(val.signal) 
    542538            if not val.interval.is_equal(self.interval) : 
    543539                raise Exception("the provided signal doesn't cover the AnalogSignalList interval") 
  • branches/interval/src/signals/intervals.py

    r427 r431  
    55    from interval import *  
    66 
    7 import numpy 
     7import numpy, copy 
    88 
    99class Interval(object): 
     
    5050    def __iter__(self): 
    5151        return iter(self.sub_intervals) 
     52     
     53    def __contains__(self, other): 
     54        return numpy.all(numpy.any(x.inf <= y.inf and y.sup <= x.sup for x in self.interval_data) for y in other) 
    5255     
    5356    def __getslice__(self, i, j): 
     
    102105        Return a copy of the SpikeTrain object 
    103106        """ 
    104         return Interval(self.sub_intervals) 
     107        return copy.deepcopy(self) 
    105108 
    106109    def offset_start(self, shift, from_stop=False, colapse=False) : 
     
    163166 
    164167    def slice_times(self, times) : 
    165         if times.__class__.__name__ == 'SpikeTrain' : 
    166             times = numpy.array(times.spike_times) 
    167         else : 
    168             times = numpy.array(times) 
    169168        spikes_selector = numpy.zeros(len(times), dtype=numpy.bool) 
     169        times = numpy.array(times) 
    170170        if HAVE_INTERVAL: 
    171171            for itv in self.interval_data : 
     
    174174            spikes_selector = (times >= self.t_start()) & (times <= self.t_stop()) 
    175175        return numpy.extract(spikes_selector, times) 
     176     
     177    def time_axis(self, dt=0.1): 
     178        result = numpy.array([], float) 
     179        if HAVE_INTERVAL: 
     180            for itv in self.interval_data : 
     181                result = numpy.concatenate((result, numpy.arange(itv[0],itv[1],dt))) 
     182        else: 
     183            return numpy.arange(self.t_start(),self.t_stop(),dt) 
     184        return result 
    176185 
    177186    def idx_slice_times(self, times): 
    178187        spikes_selector = numpy.zeros(len(times), dtype=numpy.bool) 
    179188        if HAVE_INTERVAL: 
    180             for itv in self.interval_data : 
    181                 spikes_selector = spikes_selector + (times > itv[0])*(times <= itv[1]) 
     189            for itv in self.interval_data: 
     190                spikes_selector = spikes_selector + (times >= itv[0])*(times < itv[1]) 
    182191        else: 
    183192            spikes_selector = (times >= self.t_start()) & (times <= self.t_stop()) 
     
    186195    def is_equal(self, itv) : 
    187196        return self.sub_intervals == itv.sub_intervals 
    188  
    189197 
    190198    def intersect(self, itv) : 
  • branches/interval/src/signals/spikes.py

    r430 r431  
    197197        Return a copy of the SpikeTrain object 
    198198        """ 
    199         return SpikeTrain(self.spike_times, interval=self.interval) 
     199        return copy.deepcopy(self) 
    200200 
    201201 
     
    869869        """ 
    870870 
    871         spklist = SpikeList([], [], None, None, self.interval, self.dimensions) 
    872         for id in self.id_list(): 
    873             spklist.append(float(id), self.spiketrains[float(id)]) 
    874         return spklist 
     871        copy.deepcopy(self) 
    875872 
    876873    def create_intervals(self, t_start=None, t_stop=None, interval=None):