Changeset 396 for trunk

Show
Ignore:
Timestamp:
06/09/09 09:57:02 (3 years ago)
Author:
pierre
Message:

Force all the numpy arrays of the SpikeList, containing only int and dt rounded spike times to be 32 bits float in order to save some memory. Add also a psth method, test should follows...

Location:
trunk/src
Files:
2 modified

Legend:

Unmodified
Added
Removed
  • trunk/src/io.py

    r387 r396  
    172172                    data.append(id) 
    173173        logging.debug("Loaded %d lines of data from %s" % (len(data), self)) 
    174         return data 
     174        return numpy.array(data, numpy.float32) 
    175175 
    176176    def write(self, object): 
  • trunk/src/signals/spikes.py

    r395 r396  
    8686        self.t_start     = t_start 
    8787        self.t_stop      = t_stop 
    88         self.spike_times = numpy.array(spike_times, float) 
     88        self.spike_times = numpy.array(spike_times, numpy.float32) 
    8989 
    9090        # If t_start is not None, we resize the spike_train keeping only 
     
    584584        """ 
    585585        N              = (self.t_stop-self.t_start)/dt 
    586         vec_1          = numpy.zeros(N, float) 
    587         vec_2          = numpy.zeros(N, float) 
     586        vec_1          = numpy.zeros(N, numpy.float32) 
     587        vec_2          = numpy.zeros(N, numpy.float32) 
    588588        result         = numpy.zeros(N, float) 
    589589        idx_spikes     = numpy.array(self.spike_times/dt,int) 
     
    740740         
    741741        if not hasattr(spikes, 'size'): # is an array: 
    742             spikes = numpy.array(spikes) 
     742            spikes = numpy.array(spikes, numpy.float32) 
    743743        N = len(spikes) 
    744744         
     
    768768                [0,1,2,3,....,9999] 
    769769        """ 
    770         return numpy.array(self.spiketrains.keys()) 
     770        return numpy.array(self.spiketrains.keys(), int) 
    771771 
    772772    def copy(self): 
     
    787787        if len(self) > 0: 
    788788            if self.t_start is None: 
    789                 start_times = numpy.array([self.spiketrains[idx].t_start for idx in self.id_list()]) 
     789                start_times = numpy.array([self.spiketrains[idx].t_start for idx in self.id_list()], numpy.float32) 
    790790                self.t_start = numpy.min(start_times) 
    791791                logging.debug("Warning, t_start is infered from the data : %f" %self.t_start) 
     
    793793                    self.spiketrains[id].t_start = self.t_start 
    794794            if self.t_stop is None: 
    795                 stop_times = numpy.array([self.spiketrains[idx].t_stop for idx in self.id_list()]) 
     795                stop_times = numpy.array([self.spiketrains[idx].t_stop for idx in self.id_list()], numpy.float32) 
    796796                self.t_stop  = numpy.max(stop_times) 
    797797                logging.debug("Warning, t_stop  is infered from the data : %f" %self.t_stop) 
     
    14621462        if newnum: 
    14631463            M -= 1 
    1464         spike_hist = numpy.zeros((N, M), float) 
     1464        spike_hist = numpy.zeros((N, M), numpy.float32) 
    14651465        subplot    = get_display(display) 
    14661466        for idx,id in enumerate(self.id_list()): 
     
    19611961            set_axis_limits(subplot, t_start-0.05*length, t_stop+0.05*length, min_id-2, max_id+2) 
    19621962            pylab.draw() 
     1963 
     1964 
     1965    def psth(self, events, average=True, time_bin=2, t_min=50, t_max=50, display = False, kwargs={}): 
     1966        """ 
     1967        Return the psth of the cells contained in the SpikeList according to selected events,  
     1968        on a time window t_spikes - tmin, t_spikes + tmax 
     1969        Can return either the averaged psth (average = True), or an array of all the 
     1970        psth triggered by all the spikes. 
     1971             
     1972        Inputs: 
     1973            events  - Can be a SpikeTrain object (and events will be the spikes) or just a list  
     1974                      of times 
     1975            average - If True, return a single vector of the averaged waveform. If False,  
     1976                      return an array of all the waveforms. 
     1977            t_min   - Time (>0) to average the signal before an event, in ms (default 0) 
     1978            t_max   - Time (>0) to average the signal after an event, in ms  (default 100) 
     1979            display - if True, a new figure is created. Could also be a subplot. 
     1980            kwargs  - dictionary contening extra parameters that will be sent to the plot  
     1981                      function 
     1982             
     1983        Examples: 
     1984            >> vm.psth(spktrain, average=False, t_min = 50, t_max = 150) 
     1985            >> vm.psth(spktrain, average=True) 
     1986            >> vm.psth(range(0,1000,10), average=False, display=True) 
     1987             
     1988        See also 
     1989            SpikeTrain.spike_histogram 
     1990        """ 
     1991         
     1992        if isinstance(events, SpikeTrain): 
     1993            events = events.spike_times 
     1994        assert (t_min >= 0) and (t_max >= 0), "t_min and t_max should be greater than 0" 
     1995        assert len(events) > 0, "events should not be empty and should contained at least one element" 
     1996 
     1997        spk_hist = self.spike_histogram(time_bin) 
     1998        subplot  = get_display(display) 
     1999        count    = 0 
     2000        result   = numpy.zeros((len(self), (t_max+t_min)/time_bin), numpy.float32) 
     2001        t_min_l  = numpy.floor(t_min/time_bin) 
     2002        t_max_l  = numpy.floor(t_max/time_bin) 
     2003        for ev in events: 
     2004           ev = int((ev-self.t_start)/time_bin) 
     2005           if (ev > self.t_start + t_min_l) and ev < self.t_stop-t_max_l: 
     2006               count  += 1 
     2007               result += spk_hist[:,(ev-t_min_l):ev+t_max_l] 
     2008        result /= count 
     2009        if average: 
     2010           result = numpy.mean(result, 0) 
     2011             
     2012        if not subplot or not HAVE_PYLAB: 
     2013            return result 
     2014        else: 
     2015            xlabel = "Time (ms)" 
     2016            ylabel = "PSTH" 
     2017            time   = numpy.linspace(-t_min, t_max, (t_min+t_max)/time_bin) 
     2018            set_labels(subplot, xlabel, ylabel) 
     2019            if average: 
     2020                subplot.plot(time, result, **kwargs) 
     2021            else: 
     2022                for idx in xrange(len(result)): 
     2023                    subplot.plot(time, result[idx,:], c='0.5', **kwargs) 
     2024                    subplot.hold(1) 
     2025                result = numpy.mean(result, 0) 
     2026                subplot.plot(time, result, c='k', **kwargs) 
     2027            xmin, xmax, ymin, ymax = subplot.axis() 
     2028            subplot.plot([0,0],[ymin, ymax], c='r') 
     2029            set_axis_limits(subplot, -t_min, t_max, ymin, ymax) 
     2030            pylab.draw() 
     2031        return result 
    19632032 
    19642033 
     
    21822251            convert() 
    21832252        """ 
    2184         data = numpy.array(self.convert("[times, ids]")) 
     2253        data = numpy.array(self.convert("[times, ids]"), numpy.float32) 
    21852254        data = numpy.transpose(data) 
    21862255        return data