Source code for pyanno.plots.theta_plot

# Copyright (c) 2011, Enthought, Ltd.
# Author: Pietro Berkes <pberkes@enthought.com>
# License: Modified BSD license (2-clause)

"""TraitsUI view of the Theta parameters, and their samples."""

from chaco.array_plot_data import ArrayPlotData
from chaco.data_range_2d import DataRange2D
from chaco.label_axis import LabelAxis
from chaco.legend import Legend
from chaco.plot import Plot
from chaco.plot_containers import VPlotContainer
from chaco.scales.scales import FixedScale
from chaco.scales_tick_generator import ScalesTickGenerator
from chaco.tools.legend_tool import LegendTool
from enable.component_editor import ComponentEditor

from traits.has_traits import on_trait_change
from traits.trait_numeric import Array
from traits.trait_types import Instance, Bool, Event, Str, DictStrAny, Any
from traitsui.handler import ModelView
from traitsui.item import Item

import numpy as np
from pyanno.plots.plot_tools import get_annotator_color
from pyanno.plots.plots_superclass import PyannoPlotContainer
from pyanno.ui.appbase.wx_utils import is_display_small


def _w_idx(str_, idx):
    """Append number to string. Used to generate PlotData labels"""
    return str_ + str(idx)


class ThetaScatterPlot(ModelView, PyannoPlotContainer):
    """Defines a view of the annotator accuracy parameters, theta.

    The view consists in a Chaco plot that displays the theta parameter for
    each annotator, and samples from the posterior distribution over theta
    with a combination of a scatter plot and a candle plot.

    For the annotator with 0-based index ``idx``, the current theta value and
    the jittered raw samples are drawn around position ``idx + 1.2`` on the
    index axis, while the candle summarizing the samples is drawn at
    ``idx + 0.8``.
    """

    #### Traits definition ####################################################

    # True when `theta_samples` holds samples that should be displayed
    theta_samples_valid = Bool(False)
    # samples of theta; this class reads them as theta_samples[:, annotator]
    theta_samples = Array(dtype=float, shape=(None, None))

    # return value for "Copy" action on plot
    data = DictStrAny

    def _data_default(self):
        """Initial content of `data`: the model's theta and no samples."""
        return {'theta': self.model.theta, 'theta_samples': None}

    @on_trait_change('redraw,theta_samples,theta_samples_valid')
    def _update_data(self):
        """Keep `data` in sync with the model and with the samples traits."""
        if self.theta_samples_valid:
            theta_samples = self.theta_samples
        else:
            theta_samples = None

        self.data['theta'] = self.model.theta
        self.data['theta_samples'] = theta_samples

    #### plot-related traits
    title = Str('Accuracy (theta)')

    theta_plot_data = Instance(ArrayPlotData)
    # container built by `_theta_plot_default`; this is what the view shows
    theta_plot = Any
    # the Plot wrapped inside `theta_plot`, i.e. the object whose `range2d`
    # is configured in `_theta_plot_default`; None until the plot is built
    _theta_inner_plot = Any

    # fire to refresh the copied data, the plot data, and the axis bounds
    redraw = Event


    ### Plot definition #######################################################

    def _compute_range2d(self):
        """Return the DataRange2D enclosing all annotators and theta values.

        The first coordinate spans from 0 to the number of annotators plus
        one. The second coordinate ends at 1 and starts at 0.6 or lower,
        leaving a margin of 0.05 below the smallest theta value (and below
        the smallest sample, when the samples are valid).
        """
        low = min(0.6, self.model.theta.min()-0.05)
        if self.theta_samples_valid:
            low = min(low, self.theta_samples.min()-0.05)
        range2d = DataRange2D(low=(0., low),
                              high=(self.model.theta.shape[0]+1, 1.))
        return range2d


    @on_trait_change('redraw', post_init=True)
    def _update_range2d(self):
        """Recompute the axis bounds when a redraw is requested.

        `theta_plot` is the container returned by `_theta_plot_default`,
        whereas the axis bounds are set there on the Plot it wraps. The new
        bounds are therefore applied to that inner Plot.
        """
        # accessing the trait builds the plot on first use, which also
        # records the inner Plot
        container = self.theta_plot
        inner_plot = self._theta_inner_plot

        # if `theta_plot` was assigned from outside, no inner Plot is known;
        # fall back to setting the range on the assigned object
        target = container if inner_plot is None else inner_plot
        target.range2d = self._compute_range2d()


    def _theta_plot_default(self):
        """Create plot of theta parameters.

        Returns the VPlotContainer holding the plot; the Plot itself is
        stored in `_theta_inner_plot`.
        """

        # We plot both the thetas and the samples from the posterior; if the
        # latter are not defined, the corresponding ArrayPlotData names
        # should be set to an empty list, so that they are not displayed
        theta = self.model.theta
        theta_len = theta.shape[0]

        # create the plot data
        if not self.theta_plot_data:
            self.theta_plot_data = ArrayPlotData()
            self._update_plot_data()

        # create the plot
        theta_plot = Plot(self.theta_plot_data)

        for idx in range(theta_len):
            # candle plot summarizing samples over the posterior
            theta_plot.candle_plot((_w_idx('index', idx),
                                    _w_idx('min', idx),
                                    _w_idx('barmin', idx),
                                    _w_idx('avg', idx),
                                    _w_idx('barmax', idx),
                                    _w_idx('max', idx)),
                                   color=get_annotator_color(idx),
                                   bar_line_color="black",
                                   stem_color="blue",
                                   center_color="red",
                                   center_width=2)

            # plot of raw samples
            theta_plot.plot((_w_idx('ysamples', idx),
                             _w_idx('xsamples', idx)),
                            type='scatter',
                            color='black',
                            marker='dot',
                            line_width=0.5,
                            marker_size=1)

            # plot current parameters
            theta_plot.plot((_w_idx('y', idx), _w_idx('x', idx)),
                            type='scatter',
                            color=get_annotator_color(idx),
                            marker='plus',
                            marker_size=8,
                            line_width=2)

        # adjust axis bounds
        theta_plot.range2d = self._compute_range2d()

        # remove horizontal grid and axis
        theta_plot.underlays = [theta_plot.x_grid, theta_plot.y_axis]

        # create new horizontal axis, with one tick per annotator; positions
        # and labels are materialized as lists of equal length
        tick_positions = list(range(1, theta_len+1))
        label_list = [str(i) for i in tick_positions]

        label_axis = LabelAxis(
            theta_plot,
            orientation='bottom',
            positions=tick_positions,
            labels=label_list,
            label_rotation=0
        )
        # use a FixedScale tick generator with a resolution of 1
        label_axis.tick_generator = ScalesTickGenerator(scale=FixedScale(1.))

        theta_plot.index_axis = label_axis
        theta_plot.underlays.append(label_axis)
        theta_plot.padding = 25
        theta_plot.padding_left = 40
        theta_plot.aspect_ratio = 1.0

        container = VPlotContainer()
        container.add(theta_plot)
        container.bgcolor = 0xFFFFFF

        self.decorate_plot(container, theta)
        self._set_title(theta_plot)

        # remember the Plot so that `_update_range2d` can reach it later
        self._theta_inner_plot = theta_plot

        return container


    ### Handle plot data ######################################################

    def _samples_names_and_values(self, idx):
        """Return a list of names and values for the samples PlotData.

        Each element is a pair ``(name, value)`` where `name` is one of
        'xsamples', 'ysamples', 'min', 'max', 'barmin', 'barmax', 'avg',
        'index' followed by `idx`. If the samples are not valid, every value
        is an empty list, so that nothing is displayed.
        """
        if not self.theta_samples_valid:
            names = ('xsamples', 'ysamples', 'min', 'max',
                     'barmin', 'barmax', 'avg', 'index')
            return [(_w_idx(name, idx), []) for name in names]

        samples = np.sort(self.theta_samples[:, idx])
        nsamples = samples.shape[0]
        # compute the summary statistics only once
        mean = samples.mean()
        std = samples.std()
        # random offsets in [-0.05, 0.05) that spread the raw samples around
        # the position of the annotator
        jitter = np.random.random(size=(nsamples,))*0.1-0.05

        data_dict = {
            'xsamples': samples,
            'ysamples': jitter + idx + 1.2,
            # the candle stems end at the sorted samples found at 5% and
            # 95% of the sample count
            'min': [samples[int(nsamples*0.05)]],
            'max': [samples[int(nsamples*0.95)]],
            # the candle bar covers one standard deviation around the mean
            'barmin': [mean - std],
            'barmax': [mean + std],
            'avg': [mean],
            'index': [idx + 0.8],
        }

        return [(_w_idx(name, idx), value)
                for name, value in data_dict.items()]

    @on_trait_change('theta_plot_data,theta_samples_valid,redraw')
    def _update_plot_data(self):
        """Updates PlotData on changes."""
        theta = self.model.theta

        plot_data = self.theta_plot_data
        if plot_data is None:
            return

        for idx, th in enumerate(theta):
            # current parameter value and its position on the index axis
            plot_data.set_data(_w_idx('x', idx), [th])
            plot_data.set_data(_w_idx('y', idx), [idx+1.2])

            for name, value in self._samples_names_and_values(idx):
                plot_data.set_data(name, value)


    #### View definition #####################################################

    resizable_plot_item = Item(
        'theta_plot',
        editor=ComponentEditor(),
        resizable=True,
        show_label=False,
        width=600,
        height=400
        )

    traits_plot_item = Instance(Item)

    def _traits_plot_item_default(self):
        """Non-resizable plot item; its height depends on the display size."""
        height = -220 if is_display_small() else -280
        return Item('theta_plot',
                    editor=ComponentEditor(),
                    resizable=False,
                    show_label=False,
                    height=height
                    )


class ThetaDistrPlot(PyannoPlotContainer):
    """View of the annotator accuracy parameters, theta, as distributions.

    The Chaco plot shows the theta parameter of each annotator as a vertical
    dashed line and, when samples are available, the samples from the
    posterior distribution over theta as a discretized distribution.
    """

    # theta values; the first axis is indexed by annotator
    theta = Array

    # samples of theta, read as theta_samples[:, annotator]; None hides them
    theta_samples = Any

    # chaco container holding the plot
    theta_plot = Any

    def _theta_name(self, k):
        """Return the plot name of annotator `k`, e.g. 'theta[3]'.

        The index is padded to a width derived from the number of
        annotators.
        """
        width = int(np.ceil(np.log10(self.theta.shape[0])))
        template = 'theta[{{:{}d}}]'.format(width)
        return template.format(k)

    def _theta_plot_default(self):
        """Build the container with the theta lines and the distributions."""
        theta, samples = self.theta, self.theta_samples
        has_samples = samples is not None

        # plot data object, and the plot drawing from it
        plot_data = ArrayPlotData()
        plot = Plot(plot_data)

        # --- theta values, drawn as vertical dashed lines

        # vertical extremes shared by all the lines
        plot_data.set_data('line_extr', [0., 1.])

        names = [self._theta_name(k) for k in range(theta.shape[0])]

        for k, name in enumerate(names):
            plot_data.set_data(name, [theta[k], theta[k]])

        # renderers by name, used below to populate the legend
        plots = {}
        for k, name in enumerate(names):
            plots[name] = plot.plot(
                (name, 'line_extr'),
                line_width=2.,
                color=get_annotator_color(k),
                line_style='dash',
                name=name
            )

        # --- samples, drawn as distributions
        if has_samples:
            bins = np.linspace(0., 1., 100)
            distr_names = [name + '_distr_' for name in names]

            max_hist = 0.
            for k, prefix in enumerate(distr_names):
                counts, edges = np.histogram(samples[:, k], bins=bins)
                fractions = counts / float(counts.sum())
                max_hist = max(max_hist, fractions.max())

                # duplicate each point to draw the histogram as "bars"
                heights = np.concatenate(([0], np.repeat(fractions, 2), [0]))
                plot_data.set_data(prefix+'x', np.repeat(edges, 2))
                plot_data.set_data(prefix+'y', heights)

            for k, prefix in enumerate(distr_names):
                plot.plot((prefix+'x', prefix+'y'),
                          line_width=2.,
                          color=get_annotator_color(k))

        # --- plot appearance

        if is_display_small():
            plot.aspect_ratio = 1.6
        else:
            plot.aspect_ratio = 1.7
        plot.padding = [20, 0, 10, 40]

        # axis bounds: enclose theta, and the samples when present
        if has_samples:
            x_low = min(theta.min(), samples.min())
            x_high = max(theta.max(), samples.max())
            y_high = max_hist
        else:
            x_low = theta.min()
            x_high = theta.max()
            y_high = 1.

        lower_corner = (max(x_low-0.05, 0.), 0.)
        upper_corner = (min(x_high*1.1, 1.), min(y_high*1.1, 1.))
        plot.range2d = DataRange2D(low=lower_corner, high=upper_corner)

        # axes labels
        plot.value_axis.title = 'Probability'
        plot.index_axis.title = 'Theta'

        # legend, draggable with the left mouse button
        legend = Legend(component=plot, plots=plots, align="ul", padding=5)
        legend.tools.append(LegendTool(legend, drag_button="left"))
        plot.overlays.append(legend)

        container = VPlotContainer()
        container.add(plot)
        container.bgcolor = 0xFFFFFF

        self.decorate_plot(container, theta)
        self._set_title(plot)

        return container


    #### View definition #####################################################

    resizable_plot_item = Item(
        'theta_plot',
        editor=ComponentEditor(),
        resizable=True,
        show_label=False,
        width=600
        )

    traits_plot_item = Instance(Item)

    def _traits_plot_item_default(self):
        """Non-resizable plot item; its height depends on the display size."""
        if is_display_small():
            height = -220
        else:
            height = -280
        return Item('theta_plot',
                    editor=ComponentEditor(),
                    resizable=False,
                    show_label=False,
                    height=height)


def plot_theta_parameters(model, theta_samples=None, type='distr',
                          **kwargs):
    """Display a plot of the annotator accuracy parameters, theta.

    This function gives a graphical representation of the `theta` accuracy
    parameters for :class:`~pyanno.modelA.ModelA`,
    :class:`~pyanno.modelBt.ModelBt`, and
    :class:`~pyanno.modelBt_loopdesign.ModelBtLoopDesign`.

    The plot is shown by calling `configure_traits` on the view, using the
    view named 'resizable_view'.

    Arguments
    ---------
    model : instance
        an instance of :class:`~pyanno.modelA.ModelA`,
        :class:`~pyanno.modelBt.ModelBt`, or
        :class:`~pyanno.modelBt_loopdesign.ModelBtLoopDesign`

    theta_samples : ndarray, shape = (n_items, n_annotators)
        Samples from the posterior over theta, as returned by the method
        `model.sample_posterior_over_accuracy`. If given, they are displayed
        as a probability distribution superimposed to the values of
        `model.theta`.

    type : string
        Either 'scatter' or 'distr'. Parametrizes two different kind of plots.
        Any value other than 'distr' selects the scatter plot.

    kwargs : dictionary
        Additional keyword arguments passed to the plot, for both kinds of
        plot. The argument 'title' sets the title of the plot.

    Returns
    -------
    theta_view : instance
        The :class:`ThetaDistrPlot` or :class:`ThetaScatterPlot` object that
        was displayed.
    """

    if type == 'distr':
        # forward the extra arguments (e.g., 'title') to the plot; the
        # explicit `theta` and `theta_samples` take precedence over any
        # entry with the same name in `kwargs`
        traits = dict(kwargs)
        traits['theta'] = model.theta
        traits['theta_samples'] = theta_samples
        theta_view = ThetaDistrPlot(**traits)
    else:
        theta_view = ThetaScatterPlot(model=model, **kwargs)
        if theta_samples is not None:
            # set the samples first, then flag them as valid
            theta_view.theta_samples = theta_samples
            theta_view.theta_samples_valid = True

    theta_view.configure_traits(view='resizable_view')
    return theta_view


#### Testing and debugging ####################################################
def main():
    """Standalone entry point, used for testing and debugging.

    Creates a ModelBtLoopDesign model, generates annotations from it, draws
    samples from the posterior over the accuracy parameters, and displays
    them with `plot_theta_parameters`. Returns the model and the view.
    """
    from pyanno.modelBt_loopdesign import ModelBtLoopDesign

    debug_model = ModelBtLoopDesign.create_initial_state(5)
    debug_annotations = debug_model.generate_annotations(100)
    posterior_samples = debug_model.sample_posterior_over_accuracy(
        debug_annotations,
        100,
        step_optimization_nsamples=3
    )

    view = plot_theta_parameters(
        debug_model,
        posterior_samples,
        type='distr',
        title='Debug plot_theta_parameters'
    )
    return debug_model, view


if __name__ == '__main__':
    model, theta_view = main()

Table Of Contents