# Copyright (c) 2011, Enthought, Ltd.
# Author: Pietro Berkes <pberkes@enthought.com>
# License: Modified BSD license (2-clause)
"""View for model and data pair."""
from traits.has_traits import HasTraits, on_trait_change
from traits.trait_types import (Any, File, Instance, Button, Enum, Str, Bool,
                                Float, Event, Int)
from traits.traits import Property
from traitsui.editors.range_editor import RangeEditor
from traitsui.group import HGroup, VGroup, Tabbed
from traitsui.handler import ModelView
from traitsui.item import Item, Label, Spring, UItem
from traitsui.menu import OKCancelButtons
from traitsui.view import View
from traitsui.message import error
from pyanno.modelA import ModelA
from pyanno.modelB import ModelB
from pyanno.modelBt import ModelBt
from pyanno.modelBt_loopdesign import ModelBtLoopDesign
from pyanno.annotations import AnnotationsContainer
from pyanno.plots.annotations_plot import PosteriorPlot
from pyanno.ui.annotation_stat_view import AnnotationsStatisticsView
from pyanno.ui.annotations_view import AnnotationsView, CreateNewAnnotationsDialog
from pyanno.ui.appbase.long_running_call import LongRunningCall
from pyanno.ui.appbase.wx_utils import is_display_small
from pyanno.ui.model_a_view import ModelAView
from pyanno.ui.model_bt_view import ModelBtView
from pyanno.ui.model_btloop_view import ModelBtLoopDesignView
from pyanno.ui.model_b_view import ModelBView
import numpy as np
# TODO remember last setting of parameters
from pyanno.ui.posterior_view import PosteriorView
from pyanno.util import PyannoValueError
from traitsui.message import message
import logging
logger = logging.getLogger(__name__)
class ModelDataView(ModelView):
    """Main view pairing an annotations model with a set of annotations.

    The model side of the window shows the current model through the
    model-specific view registered in `_model_class_to_view`. The data side
    shows the loaded annotations and their statistics. The buttons at the
    bottom run model/data computations (ML/MAP estimation, posterior
    sampling, label inference) as long-running calls, and give access to the
    database of the parent application, if one is set.
    """

    #### Information about available models

    # name of the model selected in the "Create new model" drop-down
    model_name = Enum(
        'Model B-with-theta',
        'Model B-with-theta (loop design)',
        'Model B',
        'Model A (loop design)',
    )

    # map from the names in `model_name` to the model classes
    _model_name_to_class = {
        'Model B-with-theta': ModelBt,
        'Model B-with-theta (loop design)': ModelBtLoopDesign,
        'Model B': ModelB,
        'Model A (loop design)': ModelA
    }

    # map from model classes to the Traits UI view class responsible for them
    _model_class_to_view = {
        ModelBt: ModelBtView,
        ModelBtLoopDesign: ModelBtLoopDesignView,
        ModelB: ModelBView,
        ModelA: ModelAView
    }

    #### Application-related traits

    # reference to pyanno application (None when running standalone)
    application = Any

    #### Model-related traits

    # the annotations model
    model = Any

    # Traits UI view of the model
    model_view = Instance(ModelView)

    # fired when the model is updated
    model_updated = Event

    # model updates are not propagated (`model_updated` is not fired) while
    # this trait is True, e.g. during a long-running computation
    model_update_suspended = Bool(False)

    #### Annotation-related traits

    # File trait to load a new annotations file
    annotations_file = File

    # True when annotations are loaded correctly
    annotations_are_defined = Bool(False)

    # fired when annotations are updated
    annotations_updated = Event

    # Traits UI view of the annotations
    annotations_view = Instance(AnnotationsView)

    # Traits UI view of the annotations' statistics
    annotations_stats_view = Instance(AnnotationsStatisticsView)

    # shortcut to the annotations array of the current annotations container
    annotations = Property

    def _get_annotations(self):
        return self.annotations_view.annotations_container.annotations

    # property that combines information from the model and the annotations
    # to give a consistent number of classes: the larger of the model's
    # number of classes and the largest annotation value plus one
    nclasses = Property

    def _get_nclasses(self):
        return max(self.model.nclasses, self.annotations.max() + 1)

    # info string -- currently not used
    info_string = Str

    # used to display the current log likelihood
    log_likelihood = Float

    def _make_annotations_view(self, annotations_container):
        """Return an AnnotationsView showing `annotations_container`.

        The view uses the current model's number of classes and the
        current application.
        """
        return AnnotationsView(
            annotations_container=annotations_container,
            nclasses=self.model.nclasses,
            application=self.application,
            model=HasTraits(),
        )

    def _annotations_view_default(self):
        # placeholder annotations, shown until real annotations are set
        anno = AnnotationsContainer.from_array([[0]], name='<undefined>')
        return self._make_annotations_view(anno)

    @on_trait_change('annotations_file')
    def _update_annotations_file(self):
        """Load annotations from the newly selected file and display them."""
        logger.info('Load file %s', self.annotations_file)
        anno = AnnotationsContainer.from_file(self.annotations_file)
        self.set_annotations(anno)

    @on_trait_change('annotations_updated,model_updated')
    def _update_log_likelihood(self):
        """Recompute the log likelihood of the annotations under the model.

        The log likelihood is set to NaN when the annotations are not
        compatible with the current model. Nothing happens if no
        annotations have been defined yet.
        """
        if not self.annotations_are_defined:
            return
        if self.model.are_annotations_compatible(self.annotations):
            self.log_likelihood = self.model.log_likelihood(self.annotations)
        else:
            self.log_likelihood = np.nan

    @on_trait_change('model.nclasses')
    def _update_nclasses(self):
        """Propagate the model's number of classes to the annotations view."""
        self.annotations_view.nclasses = self.model.nclasses
        self.annotations_view.annotations_updated = True

    @on_trait_change('model,model:theta,model:gamma')
    def _fire_model_updated(self):
        """Signal that the model changed, unless updates are suspended.

        Fires `model_updated` on this object and, if present, on the
        model view.
        """
        if self.model_update_suspended:
            return
        self.model_updated = True
        if self.model_view is not None:
            self.model_view.model_updated = True

    ### Control content #######################################################

    def set_model(self, model):
        """Update window with a new model.

        The model view class is looked up in `_model_class_to_view` by the
        exact class of `model`. A KeyError is raised for unsupported
        classes.
        """
        self.model = model
        model_view_class = self._model_class_to_view[model.__class__]
        self.model_view = model_view_class(model=model)
        self.model_updated = True

    def set_annotations(self, annotations_container):
        """Update window with a new set of annotations."""
        self.annotations_view = self._make_annotations_view(
            annotations_container)
        self.annotations_stats_view = AnnotationsStatisticsView(
            annotations=self.annotations,
            nclasses=self.nclasses,
        )
        self.annotations_are_defined = True
        self.annotations_updated = True

    def set_from_database_record(self, record):
        """Set main window model and annotations from a database record."""
        self.set_model(record.model)
        self.set_annotations(record.anno_container)

    ### Actions ##############################################################

    #### Model creation actions

    # create a new model
    new_model = Button(label='Create...')

    # show information about the selected model
    get_info_on_model = Button(label='Info...')

    #### Annotation creation actions

    # create new annotations
    new_annotations = Button(label='Create...')

    #### Model <-> data computations

    # execute Maximum Likelihood estimation of parameters
    ml_estimate = Button(label='ML estimate...',
                         desc=('Maximum Likelihood estimate of model '
                               'parameters'))

    # execute MAP estimation of parameters
    map_estimate = Button(label='MAP estimate...')

    # draw samples from the posterior over accuracy
    sample_posterior_over_accuracy = Button(label='Sample parameters...')

    # compute posterior over label classes
    estimate_labels = Button(label='Estimate labels...')

    #### Database actions

    # open database window
    open_database = Button(label="Open database")

    # add current results to database
    add_to_database = Button(label="Add to database")

    def _new_model_fired(self):
        """Create new model."""
        # delegate creation to the view class associated with the model
        model_class = self._model_name_to_class[self.model_name]
        responsible_view = self._model_class_to_view[model_class]
        # model is None if the user cancelled the action
        model = responsible_view.create_model_dialog(self.info.ui.control)
        if model is not None:
            self.set_model(model)

    def _new_annotations_fired(self):
        """Create an empty annotations set."""
        annotations = CreateNewAnnotationsDialog.create_annotations_dialog()
        if annotations is None:
            # presumably the user cancelled the dialog
            return
        if self.application is not None:
            name = self.application.database.get_available_id()
        else:
            # no application (e.g. standalone testing): there is no
            # database to provide a fresh id, so use the placeholder name
            name = '<undefined>'
        anno_cont = AnnotationsContainer.from_array(annotations, name=name)
        self.set_annotations(anno_cont)

    def _open_database_fired(self):
        """Open database window."""
        if self.application is not None:
            self.application.open_database_window()

    def _get_info_on_model_fired(self):
        """Open dialog with model description."""
        model_class = self._model_name_to_class[self.model_name]
        message(message=model_class.__doc__, title='Model info')

    def _add_to_database_fired(self):
        """Add current results to database."""
        if self.application is not None:
            self.application.add_current_state_to_database()

    def _action_finally(self):
        """Operations that need to be executed both in case of a success and
        a failure of the long-running action.
        """
        self.model_update_suspended = False

    def _action_success(self, result):
        """Default success callback: resume updates and signal the change."""
        self._action_finally()
        self._fire_model_updated()

    def _action_failure(self, err):
        """Default failure callback of long-running actions.

        Model updates are always resumed. A PyannoValueError whose message
        mentions 'Annotations' (raised when annotations are incompatible
        with the model) is reported in an error dialog. Any other error is
        logged and re-raised so that it is never silently lost.
        """
        self._action_finally()
        if isinstance(err, PyannoValueError) and err.args:
            errmsg = str(err.args[0])
            if 'Annotations' in errmsg:
                # raised when annotations are incompatible with the model
                error('Error: ' + errmsg)
                return
        logger.error('Unhandled error in long-running action: %r', err)
        # re-raise exception if it has not been handled
        raise err

    def _action_on_model(self, message, method, args=None, kwargs=None,
                         on_success=None, on_failure=None):
        """Call long running method on model.

        While the call is running, a window with a pulse progress bar is
        displayed and model updates are suspended. The success/failure
        callbacks are responsible for resuming them (see `_action_finally`).

        An error message is displayed if the call raises a PyannoValueError
        (raised when annotations are incompatible with the current model).

        Parameters
        ----------
        message : str
            Message shown in the progress window.
        method : callable
            Callable handed to LongRunningCall together with `args` and
            `kwargs`.
        args : list, optional
            Positional arguments for `method` (default: none).
        kwargs : dict, optional
            Keyword arguments for `method` (default: none).
        on_success, on_failure : callable, optional
            Callbacks; default to `_action_success` and `_action_failure`.
        """
        if args is None:
            args = []
        if kwargs is None:
            kwargs = {}
        if on_success is None:
            on_success = self._action_success
        if on_failure is None:
            on_failure = self._action_failure

        self.model_update_suspended = True
        call = LongRunningCall(
            parent=None,
            title='Calculating...',
            message=message,
            callable=method,
            args=args,
            kw=kwargs,
            on_success=on_success,
            on_failure=on_failure,
        )
        call()

    def _ml_estimate_fired(self):
        """Run ML estimation of parameters."""
        self._action_on_model('Computing maximum likelihood estimate',
                              self.model.mle, args=[self.annotations])

    def _map_estimate_fired(self):
        """Run MAP estimation of parameters."""
        self._action_on_model('Computing maximum a posteriori estimate',
                              self.model.map, args=[self.annotations])

    def _sample_posterior_success(self, samples):
        """Success callback of posterior sampling: plot the samples.

        Samples are plotted only if the model view supports it.
        """
        try:
            if (samples is not None
                    and hasattr(self.model_view, 'plot_theta_samples')):
                self.model_view.plot_theta_samples(samples)
        finally:
            # resume model updates even if plotting fails
            self._action_finally()

    def _sample_posterior_over_accuracy_fired(self):
        """Sample the posterior of the parameters `theta`."""
        # open dialog asking for number of samples
        params = _SamplingParamsDialog()
        dialog_ui = params.edit_traits(kind='modal')
        if not dialog_ui.result:
            # user pressed "Cancel"
            return
        self._action_on_model(
            'Sampling from the posterior over accuracy',
            self.model.sample_posterior_over_accuracy,
            args=[self.annotations, params.nsamples],
            kwargs={'burn_in_samples': params.burn_in_samples,
                    'thin_samples': params.thin_samples},
            on_success=self._sample_posterior_success,
        )

    def _estimate_labels_success(self, posterior):
        """Success callback of label inference: show the posterior."""
        try:
            if posterior is not None:
                post_plot = PosteriorPlot(posterior=posterior,
                                          title='Posterior over classes')
                post_view = PosteriorView(posterior_plot=post_plot,
                                          annotations=self.annotations)
                post_view.edit_traits()
        finally:
            # resume model updates even if displaying the posterior fails
            self._action_finally()

    def _estimate_labels_fired(self):
        """Compute the posterior over annotations and show it in a new window"""
        self._action_on_model(
            'Computing the posterior over classes',
            self.model.infer_labels,
            args=[self.annotations],
            on_success=self._estimate_labels_success,
        )

    ### Views ################################################################

    def traits_view(self):
        """Build the main window view.

        Widget sizes are chosen according to `is_display_small()`.
        """
        ## Model view

        # adjust sizes to display size
        if is_display_small():
            # full view size
            w_view, h_view = 1024, 768
            w_data_create_group = 350
            w_data_info_group = 500
            h_annotations_stats = 270
        else:
            w_view, h_view = 1300, 850
            w_data_create_group = 400
            w_data_info_group = 700
            h_annotations_stats = 330

        model_create_group = (
            VGroup(
                HGroup(
                    UItem(name='model_name', width=200),
                    UItem(name='new_model', width=100),
                    UItem(name='get_info_on_model', width=100, height=25),
                ),
                label='Create new model'
            )
        )

        model_group = (
            VGroup(
                model_create_group,
                VGroup(
                    Item(
                        'model_view',
                        style='custom',
                        show_label=False,
                        width=400
                    ),
                    label='Model view',
                ),
            ),
        )

        ## Data view

        data_create_group = VGroup(
            HGroup(
                Item('annotations_file', style='simple', label='Open file:',
                     width=w_data_create_group, height=25),
                UItem('new_annotations', height=25)
            ),
            label='Load/create annotations',
            show_border=False,
        )

        data_info_group = VGroup(
            Item('annotations_view',
                 style='custom',
                 show_label=False,
                 visible_when='annotations_are_defined',
                 width=w_data_info_group,
            ),
            Item('annotations_stats_view',
                 style='custom',
                 show_label=False,
                 visible_when='annotations_are_defined',
                 height=h_annotations_stats),
            label='Data view',
        )

        data_group = (
            VGroup(
                data_create_group,
                data_info_group,
            ),
        )

        ## (Model,Data) view

        model_data_group = (
            VGroup(
                Item('log_likelihood', label='Log likelihood', style='readonly'),
                HGroup(
                    Item('ml_estimate',
                         enabled_when='annotations_are_defined'),
                    Item('map_estimate',
                         enabled_when='annotations_are_defined'),
                    Item('sample_posterior_over_accuracy',
                         enabled_when='annotations_are_defined'),
                    Item('estimate_labels',
                         enabled_when='annotations_are_defined'),
                    Spring(),
                    Item('add_to_database',
                         enabled_when='annotations_are_defined'),
                    Item('open_database'),
                    show_labels=False,
                ),
                label='Model-data view',
            )
        )

        ## Full view

        full_view = View(
            VGroup(
                HGroup(
                    model_group,
                    data_group
                ),
                model_data_group,
            ),
            title='PyAnno - Models of data annotations by multiple curators',
            width=w_view,
            height=h_view,
            resizable=False
        )

        return full_view
  
class _SamplingParamsDialog(HasTraits):
    """Modal dialog asking for the parameters of posterior sampling.

    Used by `ModelDataView` before sampling the posterior over accuracy.
    The values are passed to the model's `sample_posterior_over_accuracy`.
    """

    # number of samples to draw
    nsamples = Int(200)

    # number of samples in the burn-in phase
    burn_in_samples = Int(100)

    # thinning factor: keep one sample every `thin_samples`
    thin_samples = Int(1)

    traits_view = View(
        VGroup(
            Item('nsamples',
                 label='Number of samples',
                 editor=RangeEditor(mode='spinner',
                                    low=100, high=50000,
                                    is_float=False),
                 width=100
            ),
            Item('burn_in_samples',
                 label='Number of samples in burn-in phase',
                 editor=RangeEditor(mode='spinner',
                                    low=1, high=50000,
                                    is_float=False),
                 width=100
            ),
            Item('thin_samples',
                 label='Thinning (keep 1 sample every N)',
                 editor=RangeEditor(mode='spinner',
                                    low=1, high=50000,
                                    is_float=False),
                 width=100
            ),
        ),
        buttons=OKCancelButtons
    )
#### Testing and debugging ####################################################
def main():
    """ Entry point for standalone testing/debugging.

    Opens a ModelDataView on a freshly initialized ModelBtLoopDesign model.
    No pyanno application is attached, so `application` stays unset.

    Returns
    -------
    model : ModelBtLoopDesign
        The model shown in the window.
    model_data_view : ModelDataView
        The view object, for inspection after the window is closed.
    """
    # NOTE(review): the class is imported through its package path rather
    # than used from this module's namespace, presumably so that the same
    # class object is used even when this file runs as ``__main__``
    from pyanno.ui.model_data_view import ModelDataView

    model = ModelBtLoopDesign.create_initial_state(5)
    model_data_view = ModelDataView()
    model_data_view.set_model(model)

    # open model_data_view
    model_data_view.configure_traits(view='traits_view')

    return model, model_data_view
 
if __name__ == '__main__':
    # run the standalone debugging window; the returned model and view are
    # bound to module globals, presumably for inspection in an interactive
    # session (e.g. ``python -i``)
    m, mdv = main()