Source code for pyanno.database

# Copyright (c) 2011, Enthought, Ltd.
# Author: Pietro Berkes <pberkes@enthought.com>
# License: Modified BSD license (2-clause)

"""Defines a database object to store model results."""

import shelve
from traits.has_traits import HasStrictTraits
from traits.trait_types import Float, Instance
from pyanno.abstract_model import AbstractModel
from pyanno.annotations import AnnotationsContainer
from pyanno.util import PyannoValueError
import numpy as np

class PyannoResult(HasStrictTraits):
    """Single database entry: annotations, the model applied, and its score.

    One instance bundles together the three pieces of information that are
    stored for every model run.
    """

    #: The annotations the model was applied to, as an
    #: :class:`~pyanno.annotations.AnnotationsContainer` instance
    anno_container = Instance(AnnotationsContainer)

    #: The pyAnno model instance (a subclass of
    #: :class:`~pyanno.abstract_model.AbstractModel`)
    model = Instance(AbstractModel)

    #: Performance value of the model (usually the log likelihood)
    value = Float
class PyannoDatabase(object):
    """Database to store model results.

    The database is based on :mod:`shelve`. Keys are strings that uniquely
    identify data sets. Values are lists of :class:`PyannoResult` objects,
    which contain a copy of the annotations, the pyanno model that has been
    applied on them, and the value of the log likelihood of the annotations
    given the model.
    """

    def __init__(self, filename):
        """Open the shelve database stored at `filename`.

        Arguments
        ---------
        filename : string
            Path of the shelve file. It is opened with ``flag='c'`` and
            pickle protocol 2.
        """
        self.db_filename = filename

        #: `shelve` database storing the models
        self.database = shelve.open(filename, flag='c', protocol=2)

        #: True if the database is closed
        self.closed = False

    def store_result(self, data_id, anno_container, model, value):
        """Store a pyAnno result in the database.

        The `data_id` must be a **unique** identifier for an annotations set.

        Arguments
        ---------
        data_id : string
            Readable **unique** identifier for the annotations set (e.g.,
            the file name where the annotations are stored).

        anno_container : AnnotationsContainer
            An annotations container (see
            :class:`~pyanno.annotations.AnnotationsContainer`).

        model : object
            pyAnno model object instance (subclass of
            :class:`~pyanno.abstract_model.AbstractModel`)

        value : float
            Value of the objective function for the model-annotations pair,
            typically the log likelihood of the annotations given the model

        Raises
        ------
        PyannoValueError
            If entries already stored under `data_id` hold annotations that
            differ from `anno_container.annotations`.
        """
        entry = PyannoResult(anno_container=anno_container,
                             model=model,
                             value=value)

        self._check_consistency(data_id, anno_container)

        # NOTE shelves do not automatically handle changing mutable values,
        # we need to take care of it manually: read the list, modify it,
        # and assign it back so that the change is written to the shelf.
        # `in` is used instead of `has_key`, which only exists on Python 2.
        if data_id in self.database:
            temp = self.database[data_id]
        else:
            temp = []

        temp.append(entry)
        self.database[data_id] = temp

    def retrieve_id(self, data_id):
        """Return all entries with given data ID.

        Arguments
        ---------
        data_id : string
            Readable **unique** identifier for the annotations set

        Raises
        ------
        KeyError
            If `data_id` is not present in the database.
        """
        return self.database[data_id]

    def remove(self, data_id, idx):
        """Remove entry from database.

        Arguments
        ---------
        data_id : string
            Readable **unique** identifier for the annotations set

        idx : int
            Index in the list of entries with id `data_id`
        """
        # Same read-modify-write pattern as in `store_result`: the list must
        # be reassigned for the shelf to record the deletion.
        temp = self.database[data_id]
        del temp[idx]
        self.database[data_id] = temp

    def close(self):
        """Close database."""
        self.database.close()
        self.closed = True

    def get_available_id(self):
        """Return a data ID that is not present in the database.

        The returned IDs have the form "<name_N>", where N is an integer
        number; the smallest N (starting from 0) whose ID is unused is
        chosen.
        """
        n = 0
        while True:
            candidate = '<name_{}>'.format(n)
            if candidate not in self.database:
                return candidate
            n += 1

    def _check_consistency(self, data_id, anno_container):
        """Make sure that all entries with same ID have the same annotations.

        Only the first stored entry is compared against the new annotations.

        Raises
        ------
        PyannoValueError
            If the annotations of the first entry stored under `data_id`
            differ from `anno_container.annotations`.
        """
        if data_id not in self.database:
            return

        previous = self.database[data_id]
        if not previous:
            # All entries for this ID were removed; nothing to compare with.
            return

        # `np.array_equal` also returns False when the shapes differ, whereas
        # an elementwise `==` would broadcast (or fail) on mismatched shapes.
        same = np.array_equal(previous[0].anno_container.annotations,
                              anno_container.annotations)
        if not same:
            msg = ('Conflicting annotations with same ID. Please '
                   'rename the new entry.')
            raise PyannoValueError(msg)

Table Of Contents