Source code for pyanno.plots.matrix_plot

# Copyright (c) 2011, Enthought, Ltd.
# Author: Pietro Berkes <pberkes@enthought.com>
# License: Modified BSD license (2-clause)

"""View of 2D matrix based on Chaco."""

from chaco.array_plot_data import ArrayPlotData
from chaco.color_bar import ColorBar
from chaco.data_range_1d import DataRange1D
from chaco.default_colormaps import jet, Reds, RdBu, reverse
from chaco.linear_mapper import LinearMapper
from chaco.plot import Plot
from chaco.plot_containers import HPlotContainer
from enable.component_editor import ComponentEditor

from traits.trait_numeric import Array
from traits.trait_types import Str, Instance, Float
from traitsui.item import Item

import numpy as np
from pyanno.plots.plots_superclass import PyannoPlotContainer


class MatrixPlot(PyannoPlotContainer):
    """Chaco view of a 2D matrix drawn as a color-coded image with a colorbar.

    The same extent (``matrix.shape[0]``) is used for both image axes, so the
    matrix is expected to be square.
    """

    # data to be displayed
    matrix = Array

    #### plot-related traits
    # Lower/upper bound of the color map. While left at None they are filled
    # in from the displayed data the first time the color map is created.
    colormap_low = Float(None)
    colormap_high = Float(None)
    # Forwarded to the Chaco `Plot` as its `origin` argument.
    origin = Str('top left')

    # Container holding the image plot and the colorbar side by side.
    matrix_plot_container = Instance(HPlotContainer)

    def _create_colormap(self):
        """Build the color mapper used by the image plot and the colorbar.

        Missing limits (`colormap_low` / `colormap_high` equal to None) are
        derived from the data and stored back on the instance. A `Reds` map is
        used when the lower limit is non-negative, a reversed `RdBu` map
        otherwise.
        """
        if self.colormap_low is None or self.colormap_high is None:
            # Derive the limits from the same NaN-free data that is actually
            # plotted: min()/max() on the raw matrix propagate NaN, which
            # would yield a NaN color range and the wrong color map choice.
            displayed = np.nan_to_num(self.matrix)

            if self.colormap_low is None:
                self.colormap_low = displayed.min()

            if self.colormap_high is None:
                self.colormap_high = displayed.max()

        if self.colormap_low >= 0.0:
            colormap_factory = Reds
        else:
            colormap_factory = reverse(RdBu)

        colormap = colormap_factory(
            DataRange1D(low=self.colormap_low, high=self.colormap_high)
        )

        return colormap

    def _matrix_plot_container_default(self):
        """Create the default container: image plot of the matrix + colorbar."""
        # NaN entries are replaced before plotting (np.nan_to_num).
        matrix = np.nan_to_num(self.matrix)
        width = matrix.shape[0]

        # create a plot data object and give it this data
        plot_data = ArrayPlotData()
        plot_data.set_data("values", matrix)

        # create the plot
        plot = Plot(plot_data, origin=self.origin)

        img_plot = plot.img_plot("values",
                                 interpolation='nearest',
                                 xbounds=(0, width),
                                 ybounds=(0, width),
                                 colormap=self._create_colormap())[0]

        #### fix axes
        # NOTE(review): the axis helpers below are not defined in this class;
        # presumably they are inherited from PyannoPlotContainer.
        self._remove_grid_and_axes(plot)

        axis = self._create_increment_one_axis(plot, 0.5, width, 'bottom')
        self._add_value_axis(plot, axis)

        axis = self._create_increment_one_axis(plot, 0.5, width, 'left')
        self._add_index_axis(plot, axis)

        #### tweak plot attributes
        self._set_title(plot)
        plot.aspect_ratio = 1.
        # padding [left, right, up, down]
        plot.padding = [0, 0, 25, 25]

        # create the colorbar, handing in the appropriate range and colormap
        colormap = img_plot.color_mapper
        colorbar = ColorBar(index_mapper=LinearMapper(range=colormap.range),
                            color_mapper=colormap,
                            plot=img_plot,
                            orientation='v',
                            resizable='v',
                            width=20,
                            padding=[0, 20, 0, 0])
        # align the colorbar vertically with the image plot
        colorbar.padding_top = plot.padding_top
        colorbar.padding_bottom = plot.padding_bottom

        # create a container to position the plot and the colorbar side-by-side
        container = HPlotContainer(use_backbuffer=True)
        container.add(plot)
        container.add(colorbar)
        container.bgcolor = 0xFFFFFF

        self.decorate_plot(container, self.matrix)
        return container

    # Resizable 600x400 item showing the plot container.
    # NOTE(review): presumably consumed by a 'resizable_view' defined in the
    # superclass -- confirm.
    resizable_plot_item = Item(
        'matrix_plot_container',
        editor=ComponentEditor(),
        resizable=True,
        show_label=False,
        width=600,
        height=400
        )

    # Non-resizable item showing the plot container, for embedding in views.
    traits_plot_item = Instance(Item)

    def _traits_plot_item_default(self):
        """Return the default non-resizable item for the plot container."""
        return Item(
                    'matrix_plot_container',
                    editor=ComponentEditor(),
                    resizable=False,
                    show_label=False,
                    height=-200,
                    width=-200
                    )


def plot_square_matrix(matrix, **kwargs):
    """Display a square matrix as a color-coded image.

    Arguments
    ---------
    matrix : ndarray, shape = (n, n)
        Square array of numerical values.

    kwargs : dictionary
        Additional keyword arguments forwarded to the plot. `title` sets the
        title of the plot; `colormap_low` and `colormap_high` set the lowest
        and highest value of the color map.

    Returns
    -------
    The `MatrixPlot` instance whose view was opened.
    """
    plot_view = MatrixPlot(matrix=matrix, **kwargs)
    plot_view.edit_traits(view='resizable_view')
    return plot_view


#### Testing and debugging ####################################################
def main():
    """Entry point for standalone testing/debugging.

    Plots a 5x5 identity matrix and returns the resulting plot object.
    """
    identity = np.eye(5)
    return plot_square_matrix(identity, title='Debug plot_matrix')


if __name__ == '__main__':
    mv = main()

Table Of Contents