# Copyright (c) 2011, Enthought, Ltd.
# Author: Pietro Berkes <pberkes@enthought.com>
# License: Modified BSD license (2-clause)
from chaco.array_plot_data import ArrayPlotData
from chaco.color_bar import ColorBar
from chaco.data_range_1d import DataRange1D
from chaco.default_colormaps import Reds
from chaco.linear_mapper import LinearMapper
from chaco.plot import Plot
from chaco.plot_containers import HPlotContainer
from chaco.tools.pan_tool import PanTool
from enable.component_editor import ComponentEditor
from traits.trait_numeric import Array
from traits.trait_types import Str, Instance, Float
from traitsui.group import VGroup, HGroup
from traitsui.include import Include
from traitsui.item import Item, Spring
from traitsui.view import View
from pyanno.plots.plots_superclass import PyannoPlotContainer
import numpy as np
from pyanno.ui.appbase.wx_utils import is_display_small
class PosteriorPlot(PyannoPlotContainer):
# data to be displayed
posterior = Array
### plot-related traits
plot_width = Float
def _plot_width_default(self):
return 450 if is_display_small() else 500
plot_height = Float
def _plot_height_default(self):
return 600 if is_display_small() else 750
colormap_low = Float(0.0)
colormap_high = Float(1.0)
origin = Str('top left')
plot_container = Instance(HPlotContainer)
plot_posterior = Instance(Plot)
def _create_colormap(self):
if self.colormap_low is None:
self.colormap_low = self.posterior.min()
if self.colormap_high is None:
self.colormap_high = self.posterior.max()
colormap = Reds(DataRange1D(low=self.colormap_low,
high=self.colormap_high))
return colormap
def _plot_container_default(self):
data = self.posterior
nannotations, nclasses = data.shape
# create a plot data object
plot_data = ArrayPlotData()
plot_data.set_data("values", data)
# create the plot
plot = Plot(plot_data, origin=self.origin)
img_plot = plot.img_plot("values",
interpolation='nearest',
xbounds=(0, nclasses),
ybounds=(0, nannotations),
colormap=self._create_colormap())[0]
ndisp = 55
img_plot.y_mapper.range.high = ndisp
img_plot.y_mapper.domain_limits=((0, nannotations))
self._set_title(plot)
plot.padding_top = 80
# create x axis for labels
label_axis = self._create_increment_one_axis(plot, 0.5, nclasses, 'top')
label_axis.title = 'classes'
self._add_index_axis(plot, label_axis)
plot.y_axis.title = 'items'
# tweak plot aspect
goal_aspect_ratio = 2.0
plot_width = (goal_aspect_ratio * self.plot_height
* nclasses / ndisp)
self.plot_width = min(max(plot_width, 200), 400)
plot.aspect_ratio = self.plot_width / self.plot_height
# add colorbar
colormap = img_plot.color_mapper
colorbar = ColorBar(index_mapper = LinearMapper(range=colormap.range),
color_mapper = colormap,
plot = img_plot,
orientation = 'v',
resizable = '',
width = 15,
height = 250)
colorbar.padding_top = plot.padding_top
colorbar.padding_bottom = int(self.plot_height - colorbar.height -
plot.padding_top)
colorbar.padding_left = 0
colorbar.padding_right = 30
# create a container to position the plot and the colorbar side-by-side
container = HPlotContainer(use_backbuffer=True)
container.add(plot)
container.add(colorbar)
container.bgcolor = 0xFFFFFF # light gray: 0xEEEEEE
# add pan tools
img_plot.tools.append(PanTool(img_plot, constrain=True,
constrain_direction="y", speed=7.))
self.decorate_plot(container, self.posterior)
self.plot_posterior = plot
return container
def add_markings(self, mark_classes, mark_name, marker_shape,
delta_x, delta_y, marker_size=5, line_width=1.,
marker_color='white'):
plot = self.plot_posterior
nannotations = plot.data.arrays['values'].shape[0]
y_name = mark_name + '_y'
x_name = mark_name + '_x'
y_values = np.arange(nannotations) + delta_y + 0.5
x_values = mark_classes.astype(float) + delta_x + 0.5
plot.data.set_data(y_name, y_values)
plot.data.set_data(x_name, x_values)
plot.plot((x_name, y_name), type='scatter', name=mark_name,
marker=marker_shape, marker_size=marker_size,
color='transparent',
outline_color=marker_color, line_width=line_width)
def remove_markings(self, mark_name):
self.plot_posterior.delplot(mark_name)
def _create_resizable_view(self):
# resizable_view factory, as I need to compute the height of the plot
# from the number of annotations, and I couldn't find any other way to
# do that
# "touch" posterior_plot to have it initialize
self.plot_container
if is_display_small():
height = 760
else:
height = 800
resizable_plot_item = (
Item(
'plot_container',
editor=ComponentEditor(),
resizable=True,
show_label=False,
width = self.plot_width,
height = self.plot_height,
)
)
resizable_view = View(
VGroup(
Include('instructions_group'),
resizable_plot_item,
),
width = 450,
height = height,
resizable = True
)
return resizable_view
def traits_view(self):
return self._create_resizable_view()
pan_instructions = Str
def _pan_instructions_default(self):
return 'Left-click and drag to navigate items'
instructions_group = VGroup(
HGroup(
Spring(),
Item('instructions', style='readonly', show_label=False),
Spring()
),
HGroup(
Spring(),
Item('pan_instructions', style='readonly', show_label=False),
Spring()
)
)
[docs]def plot_posterior(posterior, show_maximum=False, **kwargs):
"""Display a plot of the posterior distribution over classes.
This function is used together with the `infer_labels` method offered by
all models, e.g.: ::
>>> from pyanno.models import ModelB
>>> from pyanno.plots import plot_posterior
>>> # create a new model with 3 classes and 6 annotators
>>> model = ModelB.create_initial_state(3, 6)
>>> annotations = model.generate_annotations(100)
>>> # compute the posterior distribution over class labels
>>> posterior = model.infer_labels(annotations)
>>> # plot the distribution in a window
>>> plot_posterior(posterior)
Arguments
---------
posterior : ndarray, shape=(n_annotations, n_classes)
posterior[i,:] is the posterior distribution over classes for the
i-th annotation.
show_maximum : bool
if True, indicate the position of the maxima with white circles
kwargs : dictionary
Additional keyword arguments passed to the plot. The argument 'title'
sets the title of the plot.
"""
post_view = PosteriorPlot(posterior=posterior, **kwargs)
resizable_view = post_view._create_resizable_view()
post_view.edit_traits(view=resizable_view)
if show_maximum:
maximum = posterior.argmax(1)
post_view.add_markings(maximum, 'maximum',
'circle', 0., 0., marker_size=4,
marker_color='blue')
return post_view
#### Testing and debugging ####################################################
def main():
""" Entry point for standalone testing/debugging. """
import numpy as np
matrix = np.random.random(size=(20000, 5))
matrix = matrix / matrix.sum(1)[:,None]
matrix[0,0] = 1.
matrix_view = plot_posterior(matrix, show_maximum=True, title='TEST')
matrix_view.configure_traits()
return matrix_view
if __name__ == '__main__':
mv = main()