| 1 | """ |
|---|
| 2 | NeuroTools.plotting |
|---|
| 3 | =================== |
|---|
| 4 | |
|---|
| 5 | This module contains a collection of tools for plotting and image processing that |
|---|
| 6 | shall facilitate the generation and handling of NeuroTools data visualizations. |
|---|
| 7 | It utilizes the Matplotlib and the Python Imaging Library (PIL) packages. |
|---|
| 8 | |
|---|
| 9 | |
|---|
| 10 | Classes |
|---|
| 11 | ------- |
|---|
| 12 | |
|---|
| 13 | SimpleMultiplot - object that creates and handles a figure consisting of multiple panels, all with the same datatype and the same x-range. |
|---|
| 14 | |
|---|
| 15 | |
|---|
| 16 | Functions |
|---|
| 17 | --------- |
|---|
| 18 | |
|---|
| 19 | get_display - returns a pylab object with a plot() function to draw the plots. |
|---|
| 20 | progress_bar - prints a progress bar to stdout, filled to the given ratio. |
|---|
| 21 | pylab_params - returns a dictionary with a set of parameters that help to nicely format figures by updating the pylab run command parameters dictionary 'pylab.rcParams'. |
|---|
| 22 | set_axis_limits - defines the axis limits in a plot. |
|---|
| 23 | set_labels - defines the axis labels of a plot. |
|---|
| 24 | set_pylab_params - updates a set of parameters within the pylab run command parameters dictionary 'pylab.rcParams' in order to achieve nicely formatted figures. |
|---|
| 25 | save_2D_image - saves a 2D numpy array of gray shades between 0 and 1 to a PNG file. |
|---|
| 26 | save_2D_movie - saves a list of 2D numpy arrays of gray shades between 0 and 1 to a zipped tree of PNG files. |
|---|
| 27 | """ |
|---|
| 28 | |
|---|
| 29 | import sys, numpy |
|---|
| 30 | from NeuroTools import check_dependency |
|---|
| 31 | |
|---|
| 32 | |
|---|
| 33 | # Check availability of pylab (essential!) |
|---|
| 34 | if check_dependency('pylab'): |
|---|
| 35 | import pylab |
|---|
| 36 | if check_dependency('matplotlib'): |
|---|
| 37 | from matplotlib.figure import Figure |
|---|
| 38 | from matplotlib.lines import Line2D |
|---|
| 39 | from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas |
|---|
| 40 | |
|---|
| 41 | # Check availability of PIL |
|---|
| 42 | PILIMAGEUSE = check_dependency('PIL') |
|---|
| 43 | if PILIMAGEUSE: |
|---|
| 44 | import PIL.Image as Image |
|---|
| 45 | |
|---|
| 46 | |
|---|
| 47 | |
|---|
| 48 | ######################################################## |
|---|
| 49 | # UNIVERSAL FUNCTIONS AND CLASSES FOR NORMAL PYLAB USE # |
|---|
| 50 | ######################################################## |
|---|
| 51 | |
|---|
| 52 | |
|---|
| 53 | |
|---|
def get_display(display):
    """
    Resolves the 'display' argument into an object that can be drawn on.

    Inputs:
        display - False: nothing is to be drawn, None is returned.
                  True: a new pylab figure is opened and the pylab module
                  itself is returned.
                  anything else (e.g. a subplot object): handed back
                  unchanged.
    """
    # Identity tests are deliberate: only the two bool singletons are treated
    # specially, every other object (even a falsy one) is returned as is.
    if display is True:
        pylab.figure()
        return pylab
    if display is False:
        return None
    return display
|---|
| 69 | |
|---|
| 70 | |
|---|
| 71 | |
|---|
def progress_bar(progress):
    """
    Prints a progress bar to stdout.

    The bar is 50 characters wide and is terminated by a carriage return
    instead of a newline, so that successive calls redraw the same line.

    Inputs:
        progress - a number between 0. and 1. (anything float() can convert,
                   e.g. plain floats, ints or numpy scalars)

    Raises:
        AssertionError - if progress cannot be read as a number in [0., 1.]

    Example:
        >> progress_bar(0.7)
        |===================================               |
    """
    progressConditionStr = "ERROR: The argument of function NeuroTools.plotting.progress_bar(...) must be a float between 0. and 1.!"
    # AssertionError is raised explicitly: callers see the same exception type
    # as before, but the check no longer vanishes when Python runs with -O.
    try:
        ratio = float(progress)
    except (TypeError, ValueError):
        raise AssertionError(progressConditionStr)
    if not (0. <= ratio <= 1.):  # also rejects NaN
        raise AssertionError(progressConditionStr)
    length = 50
    filled = int(round(length * ratio))
    # sys.stdout.write works on Python 2 and 3 alike, unlike the print statement.
    sys.stdout.write("|" + "=" * filled + " " * (length - filled) + "|\r")
    sys.stdout.flush()
|---|
| 89 | |
|---|
| 90 | |
|---|
| 91 | |
|---|
def pylab_params(fig_width_pt=246.0,
                 ratio=(numpy.sqrt(5)-1.0)/2.0,  # Aesthetic golden mean ratio by default
                 text_fontsize=10, tick_labelsize=8, useTex=False):
    """
    Returns a dictionary with a set of parameters that help to nicely format figures.
    The return object can be used to update the pylab run command parameters dictionary 'pylab.rcParams'.

    Inputs:
        fig_width_pt   - figure width in points. If you want to use your figure inside LaTeX,
                         get this value from LaTeX using '\\showthe\\columnwidth'.
        ratio          - ratio between the height and the width of the figure.
        text_fontsize  - size of axes and in-pic text fonts.
        tick_labelsize - size of tick label font.
        useTex         - enables or disables the use of LaTeX for all labels and texts
                         (for details on how to do that, see http://www.scipy.org/Cookbook/Matplotlib/UsingTex).

    Returns:
        a dict with the keys 'axes.labelsize', 'font.size', 'xtick.labelsize',
        'ytick.labelsize', 'text.usetex' and 'figure.figsize'
        ([width, height] in inches).
    """
    inches_per_pt = 1.0/72.27                 # Convert pt to inch
    fig_width = fig_width_pt*inches_per_pt    # width in inches
    fig_height = fig_width*ratio              # height in inches

    return {
        'axes.labelsize': text_fontsize,
        # 'font.size' replaces the former 'text.fontsize' key, which matplotlib
        # deprecated and later removed; updating rcParams with the old key
        # fails on current versions.
        'font.size': text_fontsize,
        'xtick.labelsize': tick_labelsize,
        'ytick.labelsize': tick_labelsize,
        'text.usetex': useTex,
        'figure.figsize': [fig_width, fig_height],
    }
|---|
| 122 | |
|---|
| 123 | |
|---|
| 124 | |
|---|
def set_axis_limits(subplot, xmin, xmax, ymin, ymax):
    """
    Applies x and y axis limits to a plot.

    Two interfaces are understood: objects offering xlim()/ylim() (such as
    the pylab module) and objects offering set_xlim()/set_ylim() (such as
    an axes/subplot object). The first one found is used.

    Inputs:
        subplot    - the targeted plot
        xmin, xmax - the limits of the x axis
        ymin, ymax - the limits of the y axis

    Example:
        >> x = range(10)
        >> y = []
        >> for i in x: y.append(i*i)
        >> pylab.plot(x,y)
        >> plotting.set_axis_limits(pylab, 0., 10., 0., 100.)
    """
    # The presence of the x setter decides which interface is used; the
    # matching y setter is then expected to exist as well.
    for x_setter, y_setter in (('xlim', 'ylim'), ('set_xlim', 'set_ylim')):
        if hasattr(subplot, x_setter):
            getattr(subplot, x_setter)(xmin, xmax)
            getattr(subplot, y_setter)(ymin, ymax)
            return
    raise Exception('ERROR: The plot passed to function NeuroTools.plotting.set_axis_limits(...) does not provide limit defining functions.')
|---|
| 149 | |
|---|
| 150 | |
|---|
| 151 | |
|---|
def set_labels(subplot, xlabel, ylabel):
    """
    Defines the axis labels of a plot.

    Two interfaces are understood: objects offering xlabel()/ylabel() (such
    as the pylab module) and objects offering set_xlabel()/set_ylabel()
    (such as an axes/subplot object). The first one found is used.

    Inputs:
        subplot - the targeted plot
        xlabel  - a string for the x label
        ylabel  - a string for the y label

    Raises:
        Exception - if subplot offers neither of the two interfaces

    Example:
        >> x = range(10)
        >> y = []
        >> for i in x: y.append(i*i)
        >> pylab.plot(x,y)
        >> plotting.set_labels(pylab, 'x', 'y=x^2')
    """
    if hasattr(subplot, 'xlabel'):
        subplot.xlabel(xlabel)
        subplot.ylabel(ylabel)
    elif hasattr(subplot, 'set_xlabel'):
        subplot.set_xlabel(xlabel)
        subplot.set_ylabel(ylabel)
    else:
        # The message names this function (it used to say 'set_label').
        raise Exception('ERROR: The plot passed to function NeuroTools.plotting.set_labels(...) does not provide labelling functions.')
|---|
| 176 | |
|---|
| 177 | |
|---|
| 178 | |
|---|
def set_pylab_params(fig_width_pt=246.0,
                     ratio=(numpy.sqrt(5)-1.0)/2.0,  # Aesthetic golden mean ratio by default
                     text_fontsize=10, tick_labelsize=8, useTex=False):
    """
    Updates a set of parameters within the pylab run command parameters dictionary 'pylab.rcParams'
    in order to achieve nicely formatted figures.

    All arguments are forwarded unchanged to pylab_params(); the dictionary it
    returns is merged into 'pylab.rcParams'.

    Inputs:
        fig_width_pt   - figure width in points. If you want to use your figure inside LaTeX,
                         get this value from LaTeX using '\\showthe\\columnwidth'
        ratio          - ratio between the height and the width of the figure
        text_fontsize  - size of axes and in-pic text fonts
        tick_labelsize - size of tick label font
        useTex         - enables or disables the use of LaTeX for all labels and texts
                         (for details on how to do that, see http://www.scipy.org/Cookbook/Matplotlib/UsingTex)
    """
    params = pylab_params(fig_width_pt=fig_width_pt,
                          ratio=ratio,
                          text_fontsize=text_fontsize,
                          tick_labelsize=tick_labelsize,
                          useTex=useTex)
    pylab.rcParams.update(params)
|---|
| 197 | |
|---|
| 198 | |
|---|
| 199 | |
|---|
| 200 | #################################################################### |
|---|
| 201 | # SPECIAL PLOTTING FUNCTIONS AND CLASSES FOR SPECIFIC REQUIREMENTS # |
|---|
| 202 | #################################################################### |
|---|
| 203 | |
|---|
| 204 | |
|---|
| 205 | |
|---|
def save_2D_image(mat, filename):
    """
    Saves a 2D numpy array of gray shades between 0 and 1 to a PNG file.

    The interval [0, 1] is split into 256 equally wide bins that map to the
    gray levels 0..255; the value 1.0 falls into the top bin (255).

    Inputs:
        mat      - a non-empty 2D numpy array of floats between 0 and 1
        filename - string specifying the filename where to save the data, has to end on '.png'

    Raises:
        AssertionError - if PIL is not available or an argument is invalid

    Example:
        >> import numpy
        >> a = numpy.random.random([100,100]) # creates a 2D numpy array with random values between 0. and 1.
        >> save_2D_image(a,'randomarray100x100.png')
    """
    # AssertionError is raised explicitly: callers see the same exception type
    # as before, but the checks no longer vanish when Python runs with -O.
    if not PILIMAGEUSE:
        raise AssertionError("ERROR: Since PIL has not been detected, the function NeuroTools.plotting.save_2D_image(...) is not supported!")
    matConditionStr = "ERROR: First argument of function NeuroTools.plotting.save_2D_image(...) must be a 2D numpy array of floats between 0. and 1.!"
    filenameConditionStr = "ERROR: Second argument of function NeuroTools.plotting.save_2D_image(...) must be a string ending on \".png\"!"
    mat_ok = (isinstance(mat, numpy.ndarray) and (mat.ndim == 2) and (mat.size > 0)
              and (mat.min() >= 0.) and (mat.max() <= 1.))
    if not mat_ok:
        raise AssertionError(matConditionStr)
    filename_ok = (isinstance(filename, str) and (len(filename) > 4)
                   and (filename[-4:].lower() == '.png'))
    if not filename_ok:
        raise AssertionError(filenameConditionStr)
    mode = 'L'  # 8-bit gray scale
    # PIL asks for a permuted (col,line) shape corresponding to the natural (x,y) space
    pilImage = Image.new(mode, (mat.shape[1], mat.shape[0]))
    # floor(x*256) yields 256 for x == 1.0, which does not fit into 8 bits,
    # hence the cap at 255. Plain Python ints are handed to PIL.
    levels = numpy.minimum(numpy.floor(mat * 256.), 255.).astype(int)
    pilImage.putdata(levels.ravel().tolist())
    pilImage.save(filename)
|---|
| 230 | |
|---|
| 231 | |
|---|
| 232 | |
|---|
def save_2D_movie(frame_list, filename, frame_duration):
    """
    Saves a list of 2D numpy arrays of gray shades between 0 and 1 to a zipped tree of PNG files.

    The archive contains, below a directory named like the zip file without
    its extension:
        frame<N>.png - one RGB PNG per frame, <N> being the zero-padded frame index
        parameters   - a text entry 'frame_duration = <frame_duration>'
        frames       - a text entry listing the PNG names, one per line

    Gray shades are converted to the levels 0..255 by splitting [0, 1] into
    256 equally wide bins (1.0 falls into the top bin).

    Inputs:
        frame_list     - a list of 2D numpy arrays of floats between 0 and 1
        filename       - string specifying the filename where to save the data, has to end on '.zip'
        frame_duration - specifier for the duration per frame, will be stored as additional meta-data

    Raises:
        AssertionError - if PIL is not available or filename is invalid

    Example:
        >> import numpy
        >> framelist = []
        >> for i in range(100): framelist.append(numpy.random.random([100,100])) # creates a list of 2D numpy arrays with random values between 0. and 1.
        >> save_2D_movie(framelist, 'randommovie100x100x100.zip', 0.1)
    """
    import io
    import zipfile

    # AssertionError is raised explicitly: callers see the same exception type
    # as before, but the checks no longer vanish when Python runs with -O.
    if not PILIMAGEUSE:
        raise AssertionError("ERROR: Since PIL has not been detected, the function NeuroTools.plotting.save_2D_movie(...) is not supported!")
    filenameConditionStr = "ERROR: Second argument of function NeuroTools.plotting.save_2D_movie(...) must be a string ending on \".zip\"!"
    filename_ok = (isinstance(filename, str) and (len(filename) > 4)
                   and (filename[-4:].lower() == '.zip'))
    if not filename_ok:
        raise AssertionError(filenameConditionStr)

    n_frames = len(frame_list)
    container = filename[:-4]  # remove .zip
    # Zero-pad the index to the number of digits of the largest index. This is
    # the width ceil(log10(n_frames)) gave before, but it is computed on
    # integers and is also defined for an empty frame list.
    n_digits = max(1, len(str(max(n_frames - 1, 0))))
    frame_name_format = "frame%%.%dd.png" % n_digits
    # One list of names feeds both the archive entries and the 'frames' index,
    # so the index always lists entries that really exist.
    frame_names = [frame_name_format % i for i in range(n_frames)]

    zf = zipfile.ZipFile(filename, 'w', zipfile.ZIP_DEFLATED)
    try:
        for frame_num, frame in enumerate(frame_list):
            frame = numpy.asarray(frame)
            # Scale the gray shades from [0, 1] to the 8-bit range 0..255.
            levels = numpy.minimum(numpy.floor(frame * 256.), 255.).astype(int)
            frame_data = [(p, p, p) for p in levels.ravel().tolist()]
            # PIL expects the size as (width, height), i.e. (columns, rows).
            im = Image.new('RGB', (frame.shape[1], frame.shape[0]), 'white')
            im.putdata(frame_data)
            png_buffer = io.BytesIO()  # PNG data is binary
            im.save(png_buffer, format='png')
            arcname = "%s/%s" % (container, frame_names[frame_num])
            zf.writestr(arcname, png_buffer.getvalue())
            progress_bar((frame_num + 1.0) / n_frames)

        # add 'parameters' and 'frames' files to the zip archive
        zf.writestr("%s/parameters" % container,
                    'frame_duration = %s' % frame_duration)
        zf.writestr("%s/frames" % container, '\n'.join(frame_names))
    finally:
        # Close the archive even if a frame could not be written.
        zf.close()
|---|
| 280 | |
|---|
| 281 | |
|---|
| 282 | |
|---|
class SimpleMultiplot(object):
    """
    A figure consisting of multiple panels, all with the same datatype and
    the same x-range.

    The panels are arranged on a grid of nrows x ncolumns and are stored
    column by column, top to bottom, in the list 'axes' (also available as
    'all_panels'). The figure is rendered through an Agg canvas and written
    to file with save().
    """

    # End points ([x0, x1], [y0, y1]) in axes coordinates of the four frame
    # lines, in the order used by set_frame(): left, bottom, right, top.
    _FRAME_SEGMENTS = (
        ([0, 0], [0, 1]),  # left
        ([0, 1], [0, 0]),  # bottom
        ([1, 1], [0, 1]),  # right
        ([0, 1], [1, 1]),  # top
    )

    def __init__(self, nrows, ncolumns, title="", xlabel=None, ylabel=None,
                 scaling=('linear','linear')):
        """
        Creates the figure and all of its (still empty) panels.

        Inputs:
            nrows    - number of panel rows
            ncolumns - number of panel columns
            title    - text written at the top of the figure by finalise()
            xlabel   - x label, attached to the bottom panel of the first column
            ylabel   - y label, written once, vertically centred at the left
                       edge of the figure
            scaling  - tuple (x scaling, y scaling), each either 'linear' or
                       'log'; selects the plot function exposed as 'plot1'
                       on the panels returned by next_panel() and panel()
        """
        self.fig = Figure()
        self.canvas = FigureCanvas(self.fig)
        self.axes = []
        self.all_panels = self.axes
        self.nrows = nrows
        self.ncolumns = ncolumns
        self.n = nrows*ncolumns
        self._curr_panel = 0
        self.title = title
        # All margins and separations are fractions of the figure size.
        topmargin = 0.06
        rightmargin = 0.02
        bottommargin = 0.1
        leftmargin = 0.1
        v_panelsep = 0.1*(1 - topmargin - bottommargin)/nrows
        h_panelsep = 0.1*(1 - leftmargin - rightmargin)/ncolumns
        panelheight = (1 - topmargin - bottommargin - (nrows-1)*v_panelsep)/nrows
        panelwidth = (1 - leftmargin - rightmargin - (ncolumns-1)*h_panelsep)/ncolumns
        assert panelheight > 0

        bottomlist = [bottommargin + i*v_panelsep + i*panelheight for i in range(nrows)]
        leftlist = [leftmargin + j*h_panelsep + j*panelwidth for j in range(ncolumns)]
        bottomlist.reverse()  # so that row index 0 is the topmost panel
        for j in range(ncolumns):
            for i in range(nrows):
                ax = self.fig.add_axes([leftlist[j],bottomlist[i],panelwidth,panelheight])
                # Only the left and the bottom frame lines are drawn.
                self.set_frame(ax,[True,True,False,False])
                ax.xaxis.tick_bottom()
                ax.yaxis.tick_left()
                self.axes.append(ax)
        if xlabel:
            # Index nrows-1 is the bottom panel of the first column.
            self.axes[self.nrows-1].set_xlabel(xlabel)
        if ylabel:
            self.fig.text(0.5*leftmargin,0.5,ylabel,
                          rotation='vertical',
                          horizontalalignment='center',
                          verticalalignment='center')
        if scaling == ("linear","linear"):
            self.plot_function = "plot"
        elif scaling == ("log", "log"):
            self.plot_function = "loglog"
        elif scaling == ("log", "linear"):
            self.plot_function = "semilogx"
        elif scaling == ("linear", "log"):
            self.plot_function = "semilogy"
        else:
            raise Exception("Invalid value for scaling parameter")

    def finalise(self):
        """Adjustments to be made after all panels have been plotted."""
        self.fig.text(0.5, 0.99, self.title, horizontalalignment='center',
                      verticalalignment='top')
        # Turn off the x tick labels of every panel except the one at index
        # nrows-1, i.e. the bottom panel of the first column.
        for ax in self.axes[0:self.nrows-1]+self.axes[self.nrows:]:
            ax.xaxis.set_ticklabels([])

    def save(self, filename):
        """Saves/prints the figure to file.

        finalise() is called first, then the canvas writes the figure.

        Inputs:
            filename - string specifying the filename where to save the data
        """
        self.finalise()
        self.canvas.print_figure(filename)

    def next_panel(self):
        """
        Returns the current panel and advances to the next one, wrapping
        around to the first panel after the last.

        The returned axes carries an attribute 'plot1' bound to the plot
        function selected by the 'scaling' argument of the constructor.
        """
        ax = self.axes[self._curr_panel]
        self._curr_panel += 1
        if self._curr_panel >= self.n:
            self._curr_panel = 0
        ax.plot1 = getattr(ax, self.plot_function)
        return ax

    def panel(self, i):
        """
        Returns panel i, with 'plot1' bound as in next_panel(). The current
        panel position used by next_panel() is not changed.
        """
        ax = self.axes[i]
        ax.plot1 = getattr(ax, self.plot_function)
        return ax

    def set_frame(self, ax, boollist, linewidth=2):
        """
        Defines frames for the chosen axis.

        If all four sides are requested the axis is left untouched.
        Otherwise its frame is switched off and the requested sides are
        drawn as black lines.

        Inputs:
            ax        - the targeted axis
            boollist  - a list or numpy array of four booleans telling which
                        sides to draw, in the order left, bottom, right, top
            linewidth - the width of the frame lines
        """
        assert type(boollist) in [list, numpy.ndarray]
        assert len(boollist) == 4
        # list() makes the comparison well defined for numpy arrays too, where
        # '!=' against a list would compare element-wise.
        if list(boollist) == [True, True, True, True]:
            return
        ax.set_frame_on(False)
        for (xdata, ydata), draw in zip(self._FRAME_SEGMENTS, boollist):
            if draw:
                ax.add_line(Line2D(xdata, ydata, transform=ax.transAxes,
                                   linewidth=linewidth, color='k'))
|---|