root/trunk/src/plotting.py

Revision 486, 15.3 KB (checked in by mpereira, 10 months ago)

major improvements of SpikeTrain.instantaneous_rate()

Line 
1"""
2NeuroTools.plotting
3===================
4
5This module contains a collection of tools for plotting and image processing that
6shall facilitate the generation and handling of NeuroTools data visualizations.
7It utilizes the Matplotlib and the Python Imaging Library (PIL) packages.
8
9
10Classes
11-------
12
13SimpleMultiplot     - object that creates and handles a figure consisting of multiple panels, all with the same datatype and the same x-range.
14
15
16Functions
17---------
18
19get_display         - returns a pylab object with a plot() function to draw the plots.
20progress_bar        - prints a progress bar to stdout, filled to the given ratio.
21pylab_params        - returns a dictionary with a set of parameters that help to nicely format figures by updating the pylab run command parameters dictionary 'pylab.rcParams'.
22set_axis_limits     - defines the axis limits in a plot.
23set_labels          - defines the axis labels of a plot.
24set_pylab_params    - updates a set of parameters within the pylab run command parameters dictionary 'pylab.rcParams' in order to achieve nicely formatted figures.
25save_2D_image       - saves a 2D numpy array of gray shades between 0 and 1 to a PNG file.
26save_2D_movie       - saves a list of 2D numpy arrays of gray shades between 0 and 1 to a zipped tree of PNG files.
27"""
28
29import sys, numpy
30from NeuroTools import check_dependency
31
32
33# Check availability of pylab (essential!)
34if check_dependency('pylab'):
35    import pylab
36if check_dependency('matplotlib'):
37    from matplotlib.figure import Figure
38    from matplotlib.lines import Line2D
39    from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
40
41# Check availability of PIL
42PILIMAGEUSE = check_dependency('PIL')
43if PILIMAGEUSE:
44    import PIL.Image as Image
45
46
47
48########################################################
49# UNIVERSAL FUNCTIONS AND CLASSES FOR NORMAL PYLAB USE #
50########################################################
51
52
53
def get_display(display):
    """
    Returns a pylab object with a plot() function to draw the plots.

    Inputs:
        display - True:  a new pylab figure is opened and the pylab module is returned.
                  False: nothing is to be drawn, None is returned.
                  any other value (e.g. a subplot object) is returned unchanged.
    """
    if display is True:
        pylab.figure()
        return pylab
    if display is False:
        return None
    # already a drawable object (or None): hand it back untouched
    return display
69
70
71
def progress_bar(progress, length=50):
    """
    Prints a progress bar to stdout.

    The bar ends with a carriage return and no newline, so successive calls
    overwrite each other on the same terminal line.

    Inputs:
        progress - a real number (Python int/float or numpy integer/floating
                   scalar) between 0. and 1.
        length   - number of characters between the two '|' delimiters

    Raises:
        AssertionError if progress is not a number between 0. and 1.

    Example:
        >> progress_bar(0.7)
            |===================================               |
    """
    progressConditionStr = "ERROR: The argument of function NeuroTools.plotting.progress_bar(...) must be a float between 0. and 1.!"
    # isinstance also accepts numpy scalars (type(numpy.float64(x)) is not float)
    # and the exact bounds 0 / 1 given as ints.
    is_number = isinstance(progress, (int, float, numpy.integer, numpy.floating))
    assert is_number and (0. <= progress <= 1.), progressConditionStr
    filled = int(round(length * progress))
    # sys.stdout.write works on Python 2 and 3, unlike the print statement.
    sys.stdout.write("|" + "=" * filled + " " * (length - filled) + "|\r")
    sys.stdout.flush()
89
90
91
def pylab_params(fig_width_pt=246.0,
                 ratio=(numpy.sqrt(5) - 1.0) / 2.0,  # Aesthetic golden mean ratio by default
                 text_fontsize=10, tick_labelsize=8, useTex=False):
    """
    Returns a dictionary with a set of parameters that help to nicely format figures.
    The return object can be used to update the pylab run command parameters dictionary 'pylab.rcParams'.

    Inputs:
        fig_width_pt   - figure width in points. If you want to use your figure inside LaTeX,
                         get this value from LaTeX using '\\showthe\\columnwidth'.
        ratio          - ratio between the height and the width of the figure.
        text_fontsize  - size of axes labels and of the default text font.
        tick_labelsize - size of tick label font.
        useTex         - enables or disables the use of LaTeX for all labels and texts
                         (for details on how to do that, see http://www.scipy.org/Cookbook/Matplotlib/UsingTex).

    Returns:
        a dict of rcParams entries; 'figure.figsize' is [width, height] in inches.
    """
    inches_per_pt = 1.0 / 72.27               # TeX points: 72.27 pt per inch
    fig_width = fig_width_pt * inches_per_pt  # width in inches
    fig_height = fig_width * ratio            # height in inches

    # 'font.size' is the canonical name of the former 'text.fontsize' alias;
    # newer matplotlib versions reject 'text.fontsize' in rcParams.update().
    return {
        'axes.labelsize': text_fontsize,
        'font.size': text_fontsize,
        'xtick.labelsize': tick_labelsize,
        'ytick.labelsize': tick_labelsize,
        'text.usetex': useTex,
        'figure.figsize': [fig_width, fig_height],
    }
122
123
124
def set_axis_limits(subplot, xmin, xmax, ymin, ymax):
    """
    Sets the x and y axis limits of a plot.

    Works with both the pylab-style interface (an object offering xlim/ylim,
    e.g. the pylab module) and the object-oriented one (set_xlim/set_ylim,
    e.g. a matplotlib Axes). The pylab-style interface is tried first.

    Inputs:
        subplot     - the targeted plot
        xmin, xmax  - the limits of the x axis
        ymin, ymax  - the limits of the y axis

    Raises:
        Exception if subplot offers neither interface.

    Example:
        >> x = range(10)
        >> y = []
        >> for i in x: y.append(i*i)
        >> pylab.plot(x,y)
        >> plotting.set_axis_limits(pylab, 0., 10., 0., 100.)
    """
    for x_setter, y_setter in (('xlim', 'ylim'), ('set_xlim', 'set_ylim')):
        if hasattr(subplot, x_setter):
            getattr(subplot, x_setter)(xmin, xmax)
            getattr(subplot, y_setter)(ymin, ymax)
            return
    raise Exception('ERROR: The plot passed to function NeuroTools.plotting.set_axis_limits(...) does not provide limit defining functions.')
149
150
151
def set_labels(subplot, xlabel, ylabel):
    """
    Defines the axis labels of a plot.

    Works with both the pylab-style interface (xlabel/ylabel, e.g. the pylab
    module) and the object-oriented one (set_xlabel/set_ylabel, e.g. a
    matplotlib Axes). The pylab-style interface is tried first.

    Inputs:
        subplot - the targeted plot
        xlabel  - a string for the x label
        ylabel  - a string for the y label

    Raises:
        Exception if subplot offers neither interface.

    Example:
        >> x = range(10)
        >> y = []
        >> for i in x: y.append(i*i)
        >> pylab.plot(x,y)
        >> plotting.set_labels(pylab, 'x', 'y=x^2')
    """
    if hasattr(subplot, 'xlabel'):
        subplot.xlabel(xlabel)
        subplot.ylabel(ylabel)
    elif hasattr(subplot, 'set_xlabel'):
        subplot.set_xlabel(xlabel)
        subplot.set_ylabel(ylabel)
    else:
        # message names this function (it previously said 'set_label')
        raise Exception('ERROR: The plot passed to function NeuroTools.plotting.set_labels(...) does not provide labelling functions.')
176
177
178
def set_pylab_params(fig_width_pt=246.0,
                     ratio=(numpy.sqrt(5) - 1.0) / 2.0,  # Aesthetic golden mean ratio by default
                     text_fontsize=10, tick_labelsize=8, useTex=False):
    """
    Updates a set of parameters within the pylab run command parameters dictionary 'pylab.rcParams'
    in order to achieve nicely formatted figures.

    The values are those returned by pylab_params(...) for the same arguments;
    pylab.rcParams is modified in place.

    Inputs:
        fig_width_pt   - figure width in points. If you want to use your figure inside LaTeX,
                         get this value from LaTeX using '\\showthe\\columnwidth'
        ratio          - ratio between the height and the width of the figure
        text_fontsize  - size of axes and in-pic text fonts
        tick_labelsize - size of tick label font
        useTex         - enables or disables the use of LaTeX for all labels and texts
                         (for details on how to do that, see http://www.scipy.org/Cookbook/Matplotlib/UsingTex)
    """
    params = pylab_params(fig_width_pt=fig_width_pt, ratio=ratio,
                          text_fontsize=text_fontsize,
                          tick_labelsize=tick_labelsize, useTex=useTex)
    pylab.rcParams.update(params)
197
198
199
200####################################################################
201# SPECIAL PLOTTING FUNCTIONS AND CLASSES FOR SPECIFIC REQUIREMENTS #
202####################################################################
203
204
205
def save_2D_image(mat, filename):
    """
    Saves a 2D numpy array of gray shades between 0 and 1 to a PNG file.

    Each value v becomes the 8-bit gray level floor(256*v), clamped to 255 so
    that v == 1.0 is white instead of falling outside the 0..255 range.
    Array rows become image lines and array columns become image columns.

    Inputs:
        mat      - a 2D numpy array of floats between 0 and 1
        filename - string specifying the filename where to save the data, has to end in '.png'

    Raises:
        AssertionError if PIL is not available or an argument is invalid.

    Example:
        >> import numpy
        >> a = numpy.random.random([100,100]) # creates a 2D numpy array with random values between 0. and 1.
        >> save_2D_image(a,'randomarray100x100.png')
    """
    assert PILIMAGEUSE, "ERROR: Since PIL has not been detected, the function NeuroTools.plotting.save_2D_image(...) is not supported!"
    matConditionStr = "ERROR: First argument of function NeuroTools.plotting.save_2D_image(...) must be a 2D numpy array of floats between 0. and 1.!"
    filenameConditionStr = "ERROR: Second argument of function NeuroTools.plotting.save_2D_image(...) must be a string ending on \".png\"!"
    assert isinstance(mat, numpy.ndarray) and (mat.ndim == 2) and (mat.min() >= 0.) and (mat.max() <= 1.), matConditionStr
    assert isinstance(filename, str) and (len(filename) > 4) and (filename[-4:].lower() == '.png'), filenameConditionStr
    n_lines, n_columns = mat.shape
    # PIL sizes are (width, height), i.e. the permuted (col, line) numpy shape
    pil_image = Image.new('L', (n_columns, n_lines))
    # asarray drops ndarray subclasses (e.g. numpy.matrix) so ravel() is 1D
    gray_levels = numpy.minimum(numpy.floor(numpy.asarray(mat).ravel() * 256.), 255)
    pil_image.putdata(gray_levels.astype(int).tolist())
    pil_image.save(filename)
230
231
232
def save_2D_movie(frame_list, filename, frame_duration):
    """
    Saves a list of 2D numpy arrays of gray shades between 0 and 1 to a zipped tree of PNG files.

    For filename '<name>.zip' the archive contains, inside the directory '<name>/':
        frameNN.png - one RGB PNG per frame, numbered from 0 and zero-padded to
                      the number of digits of the largest frame index
        parameters  - the text 'frame_duration = <frame_duration>'
        frames      - the names of the PNG files, one per line, in frame order

    Gray values v are mapped to 8-bit levels floor(256*v), clipped to 0..255
    (the same mapping as save_2D_image). Array rows become image lines.

    Inputs:
        frame_list     - a list of 2D numpy arrays of floats between 0 and 1
        filename       - string specifying the filename where to save the data, has to end in '.zip'
        frame_duration - specifier for the duration per frame, will be stored as additional meta-data

    Raises:
        AssertionError if PIL is not available or filename is invalid.

    Example:
        >> import numpy
        >> framelist = []
        >> for i in range(100): framelist.append(numpy.random.random([100,100])) # creates a list of 2D numpy arrays with random values between 0. and 1.
        >> save_2D_movie(framelist, 'randommovie100x100x100.zip', 0.1)
    """
    try:
        import zipfile
    except ImportError:
        raise ImportError("ERROR: Python module zipfile not found! Needed by NeuroTools.plotting.save_2D_movie(...)!")
    try:
        # PNG data is binary: BytesIO works on Python 2.6+ and 3 (StringIO is Python 2 only)
        from io import BytesIO
    except ImportError:
        raise ImportError("ERROR: Python module io not found! Needed by NeuroTools.plotting.save_2D_movie(...)!")
    assert PILIMAGEUSE, "ERROR: Since PIL has not been detected, the function NeuroTools.plotting.save_2D_movie(...) is not supported!"
    filenameConditionStr = "ERROR: Second argument of function NeuroTools.plotting.save_2D_movie(...) must be a string ending on \".zip\"!"
    assert isinstance(filename, str) and (len(filename) > 4) and (filename[-4:].lower() == '.zip'), filenameConditionStr

    n_frames = len(frame_list)
    container = filename[:-4]  # remove .zip
    # Pad to the width of the largest index (at least one digit). The same
    # names are used for the PNG entries and for the 'frames' index file.
    digits = len(str(max(n_frames - 1, 0)))
    frame_names = ["frame%s.png" % str(i).zfill(digits) for i in range(n_frames)]

    zf = zipfile.ZipFile(filename, 'w', zipfile.ZIP_DEFLATED)
    try:
        for frame_num, frame in enumerate(frame_list):
            levels = numpy.clip(numpy.floor(numpy.asarray(frame, dtype=float) * 256.), 0, 255).astype(int)
            n_lines, n_columns = levels.shape
            # PIL sizes are (width, height), i.e. the permuted numpy shape
            im = Image.new('RGB', (n_columns, n_lines), 'white')
            im.putdata([(v, v, v) for v in levels.ravel().tolist()])
            png_buffer = BytesIO()
            im.save(png_buffer, format='png')
            zf.writestr("%s/%s" % (container, frame_names[frame_num]), png_buffer.getvalue())
            progress_bar(float(frame_num + 1) / n_frames)

        # add 'parameters' and 'frames' files to the zip archive
        zf.writestr("%s/parameters" % container,
                    'frame_duration = %s' % frame_duration)
        zf.writestr("%s/frames" % container, '\n'.join(frame_names))
    finally:
        zf.close()
280
281
282
class SimpleMultiplot(object):
    """
    A figure consisting of multiple panels, all with the same datatype and
    the same x-range.

    Panels are created column by column, top to bottom within each column, so
    panel index k lies in column k // nrows and row k % nrows (row 0 at the
    top). The figure is a matplotlib Figure attached to an Agg canvas, which
    save() uses to write it to file.
    """

    # maps the `scaling` argument (x-scale, y-scale) to the Axes plotting method
    _PLOT_FUNCTIONS = {
        ('linear', 'linear'): 'plot',
        ('log', 'log'): 'loglog',
        ('log', 'linear'): 'semilogx',
        ('linear', 'log'): 'semilogy',
    }

    def __init__(self, nrows, ncolumns, title="", xlabel=None, ylabel=None,
                 scaling=('linear', 'linear')):
        """
        Inputs:
            nrows, ncolumns - number of panel rows and columns
            title           - figure title, drawn by finalise()
            xlabel          - x-axis label, set on the bottom panel of the first column
            ylabel          - y-axis label, drawn once, vertically, in the left margin
            scaling         - (x-scale, y-scale), each 'linear' or 'log'; selects the
                              method exposed as plot1 on panels returned by
                              next_panel() / panel()

        Raises:
            Exception if scaling is not one of the supported combinations.
        """
        # Validate scaling before building anything.
        try:
            self.plot_function = self._PLOT_FUNCTIONS[tuple(scaling)]
        except (KeyError, TypeError):
            raise Exception("Invalid value for scaling parameter")
        self.fig = Figure()
        self.canvas = FigureCanvas(self.fig)
        self.axes = []
        self.all_panels = self.axes  # alias of the same list
        self.nrows = nrows
        self.ncolumns = ncolumns
        self.n = nrows * ncolumns
        self._curr_panel = 0
        self.title = title
        # margins, separations and panel sizes are fractions of the figure size
        topmargin = 0.06
        rightmargin = 0.02
        bottommargin = 0.1
        leftmargin = 0.1
        v_panelsep = 0.1 * (1 - topmargin - bottommargin) / nrows
        h_panelsep = 0.1 * (1 - leftmargin - rightmargin) / ncolumns
        panelheight = (1 - topmargin - bottommargin - (nrows - 1) * v_panelsep) / nrows
        panelwidth = (1 - leftmargin - rightmargin - (ncolumns - 1) * h_panelsep) / ncolumns
        assert panelheight > 0
        assert panelwidth > 0

        bottomlist = [bottommargin + i * v_panelsep + i * panelheight for i in range(nrows)]
        leftlist = [leftmargin + j * h_panelsep + j * panelwidth for j in range(ncolumns)]
        bottomlist.reverse()  # row 0 is the top row
        for j in range(ncolumns):
            for i in range(nrows):
                ax = self.fig.add_axes([leftlist[j], bottomlist[i], panelwidth, panelheight])
                self.set_frame(ax, [True, True, False, False])  # left + bottom only
                ax.xaxis.tick_bottom()
                ax.yaxis.tick_left()
                self.axes.append(ax)
        if xlabel:
            self.axes[self.nrows - 1].set_xlabel(xlabel)
        if ylabel:
            self.fig.text(0.5 * leftmargin, 0.5, ylabel,
                          rotation='vertical',
                          horizontalalignment='center',
                          verticalalignment='center')

    def finalise(self):
        """
        Adjustments to be made after all panels have been plotted.

        Draws the figure title and removes the x tick labels of every panel
        except self.axes[nrows-1], the bottom panel of the first column.
        NOTE(review): each call adds another title text to the figure.
        """
        self.fig.text(0.5, 0.99, self.title, horizontalalignment='center',
                      verticalalignment='top')
        # Turn off tick labels for all x-axes except the bottom one of the first column
        for ax in self.axes[0:self.nrows - 1] + self.axes[self.nrows:]:
            ax.xaxis.set_ticklabels([])

    def save(self, filename):
        """Saves/prints the figure to file, after calling finalise().

        Inputs:
            filename - string specifying the filename where to save the data
        """
        self.finalise()
        self.canvas.print_figure(filename)

    def next_panel(self):
        """
        Returns the current panel and advances to the next one, wrapping back
        to panel 0 after the last. The returned axis gets a plot1 attribute
        bound to the plotting method selected by `scaling`.
        """
        ax = self.axes[self._curr_panel]
        self._curr_panel += 1
        if self._curr_panel >= self.n:
            self._curr_panel = 0
        ax.plot1 = getattr(ax, self.plot_function)
        return ax

    def panel(self, i):
        """Returns panel i, with plot1 bound as in next_panel()."""
        ax = self.axes[i]
        ax.plot1 = getattr(ax, self.plot_function)
        return ax

    def set_frame(self, ax, boollist, linewidth=2):
        """
        Defines frames for the chosen axis.

        Inputs:
            ax        - the targeted axis
            boollist  - a list (or numpy array) of four booleans selecting the
                        sides to draw, in the order [left, bottom, right, top]
            linewidth - the width of the frame lines

        If all four sides are requested, the axis' default frame is kept.
        Otherwise the default frame is switched off and only the requested
        sides are drawn as black lines.
        """
        assert type(boollist) in [list, numpy.ndarray]
        assert len(boollist) == 4
        # list(...) keeps the comparison scalar for numpy arrays; an element-wise
        # comparison result used in an `if` would raise ValueError.
        if list(boollist) == [True, True, True, True]:
            return

        def side(xdata, ydata):
            # a line in axes coordinates: (0, 0) lower-left, (1, 1) upper-right
            return Line2D(xdata, ydata, transform=ax.transAxes, linewidth=linewidth, color='k')

        sides = [side([0, 0], [0, 1]),   # left
                 side([0, 1], [0, 0]),   # bottom
                 side([1, 1], [0, 1]),   # right (vertical edge at x = 1)
                 side([0, 1], [1, 1])]   # top
        ax.set_frame_on(False)
        for line, draw in zip(sides, boollist):
            if draw:
                ax.add_line(line)
Note: See TracBrowser for help on using the browser.