Source code for pyanno.plots.theta_tensor_plot

# Copyright (c) 2011, Enthought, Ltd.
# Author: Pietro Berkes <pberkes@enthought.com>
# License: Modified BSD license (2-clause)

"""TraitsUI view of the Theta tensor in Model B, and its samples."""

from chaco.array_plot_data import ArrayPlotData
from chaco.data_range_2d import DataRange2D
from chaco.legend import Legend
from chaco.plot import Plot
from chaco.plot_containers import HPlotContainer, VPlotContainer
from chaco.scales.scales import FixedScale
from chaco.scales_tick_generator import ScalesTickGenerator
from chaco.tools.legend_tool import LegendTool

from enable.component_editor import ComponentEditor

from traits.trait_numeric import Array
from traits.trait_types import Instance, Str, Range, Button, Int, Any
from traitsui.item import Item
from pyanno.plots.plot_tools import get_class_color

from pyanno.plots.plots_superclass import PyannoPlotContainer


import numpy as np
from pyanno.ui.appbase.wx_utils import is_display_small


def sigmoid(x):
    """Return the logistic function 1 / (1 + exp(-x)), computed with numpy.

    Accepts scalars or arrays; arrays are processed element-wise by numpy.
    """
    denominator = 1. + np.exp(-x)
    return 1. / denominator


class ThetaTensorPlot(PyannoPlotContainer):
    """Chaco plot of Model B's theta tensor for a single annotator.

    For every true class k, the row ``theta[k, :]`` is drawn as a solid line
    against the annotator's output classes. If ``theta_samples`` is set,
    each row is underlaid with a shaded band between the (approximate) 5th
    and 95th percentile of the samples, and a dashed line at their mean.
    """

    # reference to the theta tensor for one annotator; rows are indexed as
    # theta[k, :] with k the true class
    theta = Array

    # reference to an array of samples for theta for one annotator, or None;
    # indexed as theta_samples[:, k, :], i.e. samples along the first axis
    theta_samples = Any

    # index of the annotator (only used to build data/legend labels)
    annotator_idx = Int

    # chaco plot of the tensor; its default value is built by
    # _theta_plot_default
    theta_plot = Any


    def _label_name(self, k):
        """Return a name for the data with index `k`.

        The name has the form 'theta[<annotator_idx>,<k>,:]', with `k`
        right-aligned to the width of the largest class index so that
        legend entries line up.
        """
        nclasses = self.theta.shape[0]
        # digits of the largest class index (nclasses - 1); integer
        # arithmetic avoids relying on floating-point log10 rounding
        ndigits = len(str(max(nclasses - 1, 0)))
        format_str = 'theta[{{}},{{:{}d}},:]'.format(ndigits)

        return format_str.format(self.annotator_idx, k)


    def _plot_samples(self, plot, plot_data):
        """Draw the spread of the posterior samples of theta.

        For each true class k, this registers arrays in `plot_data` and adds
        to `plot`:

        - a filled polygon whose lower edge follows the (approximate) 5th
          percentile and whose upper edge follows the 95th percentile of
          the samples of theta[k, :];
        - a dashed line at the sample mean of theta[k, :].
        """
        # needed for the defensive copy of the class color below
        import copy

        nclasses = self.theta.shape[0]
        nsamples = self.theta_samples.shape[0]

        # rows of the sorted samples used as percentile estimates; int()
        # truncates, so both are valid indices whenever nsamples >= 1
        idx5 = int(nsamples * 0.05)
        idx95 = int(nsamples * 0.95)

        # x coordinates of the band polygon, identical for every k: walk the
        # classes left to right along the bottom edge, then right to left
        # along the top edge to close the shape
        classes = np.arange(nclasses, dtype=float)
        index_coord = np.concatenate((classes, classes[::-1]))

        for k in range(nclasses):
            samples = np.sort(self.theta_samples[:, k, :], axis=0)
            perc5 = samples[idx5, :]
            perc95 = samples[idx95, :]
            avg = samples.mean(0)

            label = self._label_name(k)

            # build polygon
            index_name = label + '_confint_index'
            value_name = label + '_confint_value'
            value_coord = np.concatenate((perc5, perc95[::-1])).astype(float)

            # each polygon gets its own x array so datasources share nothing
            plot_data.set_data(index_name, index_coord.copy())
            plot_data.set_data(value_name, value_coord)

            # Make a lighter, more transparent version of the class color.
            # Copy it first: get_class_color is not visible here, and if it
            # returns a shared palette entry, editing it in place would also
            # recolor the mean line below and the theta lines drawn later.
            color = copy.copy(get_class_color(k))
            for i in range(3):
                color[i] = min(1.0, sigmoid(color[i] * 5.))
            color[-1] = 0.3

            plot.plot(
                (index_name, value_name),
                type='polygon',
                face_color=color,
                edge_color='black',
                edge_width=0.5
            )

            # add average
            avg_name = label + '_avg_value'
            plot_data.set_data(avg_name, avg)
            plot.plot(
                ('classes', avg_name),
                color=get_class_color(k),
                line_style='dash'
            )


    def _plot_theta_values(self, plot, plot_data):
        """Draw one solid line per true class k showing theta[k, :].

        Returns
        -------
        plots : dict
            Maps each data label (see `_label_name`) to the value returned
            by `plot.plot` for that line; used to build the legend.
        """
        theta = self.theta
        nclasses = theta.shape[0]

        # register all rows first, then create the renderers
        for k in range(nclasses):
            plot_data.set_data(self._label_name(k), theta[k, :])

        plots = {}
        for k in range(nclasses):
            name = self._label_name(k)
            plots[name] = plot.plot(
                ['classes', name],
                line_width=2.,
                color=get_class_color(k),
                name=name
            )

        return plots


    def _theta_plot_default(self):
        """Build the chaco container used as default for `theta_plot`.

        Steps: register the x axis data, draw the sample bands (if
        available) and theta lines, set ranges and axes, add a draggable
        legend, and wrap the plot in a left-aligned VPlotContainer.
        """
        theta = self.theta
        nclasses = theta.shape[0]

        # create a plot data object and give it this data
        plot_data = ArrayPlotData()

        # chaco needs an array (or list) here; a Python 3 range object
        # cannot be turned into a datasource
        plot_data.set_data('classes', np.arange(nclasses))

        # create the plot
        plot = Plot(plot_data)

        # --- plot theta samples
        if self.theta_samples is not None:
            self._plot_samples(plot, plot_data)

        # --- plot values of theta
        plots = self._plot_theta_values(plot, plot_data)

        # --- adjust plot appearance

        plot.aspect_ratio = 1.6 if is_display_small() else 1.7

        # adjust axis bounds: leave a small margin around the class indices
        # and 10% headroom above the largest plotted value
        y_high = theta.max()
        if self.theta_samples is not None:
            y_high = max(y_high, self.theta_samples.max())

        plot.range2d = DataRange2D(
            low=(-0.2, 0.0),
            high=(nclasses - 1 + 0.2, y_high * 1.1)
        )

        # create new horizontal axis
        label_axis = self._create_increment_one_axis(
            plot, 0., nclasses, 'bottom')
        label_axis.title = 'True classes'
        self._add_index_axis(plot, label_axis)

        # label vertical axis
        plot.value_axis.title = 'Probability'

        # add legend (draggable with the left mouse button)
        legend = Legend(component=plot, plots=plots,
                        align="ur", border_padding=10)
        legend.tools.append(LegendTool(legend, drag_button="left"))
        legend.padding_right = -100
        plot.overlays.append(legend)

        # the container is 100 px wider than the plot to make room for the
        # legend, which is pushed to the right by padding_right above
        container = VPlotContainer(width=plot.width + 100, halign='left')
        plot.padding_bottom = 50
        plot.padding_top = 10
        plot.padding_left = 0
        container.add(plot)
        container.bgcolor = 0xFFFFFF

        # superclass hook (PyannoPlotContainer, not visible here) receiving
        # the container and the displayed data
        self.decorate_plot(container, theta)

        return container


    #### View definition #####################################################

    # TraitsUI item showing `theta_plot` in a resizable view
    resizable_plot_item = Item(
        'theta_plot',
        editor=ComponentEditor(),
        resizable=True,
        show_label=False,
        height=-300
        )

    # TraitsUI item showing `theta_plot` with a non-resizable size
    traits_plot_item = Instance(Item)

    def _traits_plot_item_default(self):
        """Return the non-resizable item; smaller on small displays."""
        height = -200 if is_display_small() else -250
        return Item(
                    'theta_plot',
                    editor=ComponentEditor(),
                    resizable=False,
                    show_label=False,
                    height=height,
                    )


def plot_theta_tensor(modelB, annotator_idx, theta_samples=None, **kwargs):
    """Display a plot of model B's accuracy tensor, theta.

    The tensor theta[annotator_idx,:,:] is shown for one annotator as a
    set of line plots, each depicting the distribution

    theta[annotator_idx,k,:] = P(annotator_idx outputs : | real class is k).

    Arguments
    ---------
    modelB : ModelB instance
        An instance of ModelB.

    annotator_idx : int
        Index of the annotator for which the parameters are displayed.

    theta_samples : ndarray, shape = (n_samples x n_annotators x n_classes x n_classes)
        Array of samples over the posterior of theta.

    kwargs : dictionary
        Additional keyword arguments passed to the plot. The argument 'title'
        sets the title of the plot.

    Returns
    -------
    theta_view : ThetaTensorPlot instance
        Reference to the plot.
    """

    # keep only the samples that belong to the selected annotator
    samples = None
    if theta_samples is not None:
        samples = theta_samples[:, annotator_idx, :, :]

    # forward kwargs (e.g. 'title') to the plot, as documented above;
    # previously they were silently discarded
    theta_view = ThetaTensorPlot(
        theta=modelB.theta[annotator_idx, :, :],
        annotator_idx=annotator_idx,
        theta_samples=samples,
        **kwargs
    )
    theta_view.configure_traits(view='resizable_view')
    return theta_view


#### Testing and debugging ####################################################


def main():
    """Standalone debugging entry point.

    Creates an initial ModelB state, generates annotations from it, draws
    posterior samples with `sample_posterior_over_accuracy`, and opens the
    theta tensor plot for annotator 2.

    Returns
    -------
    (model, theta_view) : tuple
        The ModelB instance and the plot returned by plot_theta_tensor.
    """
    from pyanno.models import ModelB

    debug_model = ModelB.create_initial_state(4, 5)
    annotations = debug_model.generate_annotations(100)
    posterior = debug_model.sample_posterior_over_accuracy(annotations, 10)

    view = plot_theta_tensor(
        debug_model, 2, posterior,
        title='Debug plot_theta_parameters',
    )
    return debug_model, view


if __name__ == '__main__':
    model, theta_view = main()

Table Of Contents