Source code for traits.trait_set_object

# (C) Copyright 2005-2022 Enthought, Inc., Austin, TX
# All rights reserved.
#
# This software is provided without warranty under the terms of the BSD
# license included in LICENSE.txt and may be redistributed only under
# the conditions described in the aforementioned license. The license
# is also available online at http://www.enthought.com/licenses/BSD.txt
#
# Thanks for using Enthought open source!

import copy
import copyreg
from itertools import chain
from weakref import ref

from traits.observation.i_observable import IObservable
from traits.trait_base import _validate_everything
from traits.trait_errors import TraitError


class TraitSetEvent(object):
    """ An object reporting in-place changes to a traits set.

    Parameters
    ----------
    removed : set, optional
        Old values that were removed from the set.
    added : set, optional
        New values added to the set.

    Attributes
    ----------
    removed : set
        Old values that were removed from the set.
    added : set
        New values added to the set.
    """

    def __init__(self, *, removed=None, added=None):
        # A fresh empty set is created per instance whenever an argument is
        # omitted, so two events never share the same default container.
        self.removed = set() if removed is None else removed
        self.added = set() if added is None else added

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f"{cls_name}(removed={self.removed!r}, added={self.added!r})"
@IObservable.register
class TraitSet(set):
    """ A subclass of set that validates and notifies listeners of changes.

    Parameters
    ----------
    value : iterable, optional
        Iterable providing the items for the set.
    item_validator : callable, optional
        Called to validate and/or transform items added to the set. The
        callable should accept a single item and return the transformed
        item, raising TraitError for invalid items. If not given, no item
        validation is performed.
    notifiers : list of callable, optional
        A list of callables with the signature::

            notifier(trait_set, removed, added)

        Where 'added' is a set containing new values that have been added.
        And 'removed' is a set containing old values that have been removed.

        If this argument is not given, the list of notifiers is initially
        empty.

    Attributes
    ----------
    item_validator : callable
        Called to validate and/or transform items added to the set. The
        callable should accept a single item and return the transformed
        item, raising TraitError for invalid items.
    notifiers : list of callable
        A list of callables with the signature::

            notifier(trait_set, removed, added)

        where 'added' is a set containing new values that have been added
        and 'removed' is a set containing old values that have been removed.
    """

    def __new__(cls, *args, **kwargs):
        # Install the defaults here so that both attributes already exist
        # by the time __init__ runs.
        self = super().__new__(cls)
        self.item_validator = _validate_everything
        self.notifiers = []
        return self

    def __init__(self, value=(), *, item_validator=None, notifiers=None):
        if item_validator is not None:
            self.item_validator = item_validator
        # Initial items are validated, but no notification is fired for them.
        super().__init__(self.item_validator(item) for item in value)
        if notifiers is not None:
            self.notifiers = notifiers

    def notify(self, removed, added):
        """ Call all notifiers.

        This simply calls all notifiers provided by the class, if any.
        The notifiers are expected to have the signature::

            notifier(trait_set, removed, added)

        Any return values are ignored. Any exceptions raised are not
        handled. Notifiers are therefore expected not to raise any
        exceptions under normal use.

        Parameters
        ----------
        removed : set
            The items that have been removed.
        added : set
            The new items that have been added to the set.
        """
        for notifier in self.notifiers:
            notifier(self, removed, added)

    # -- set interface -------------------------------------------------------

    def __iand__(self, value):
        """ Return self &= value.

        Parameters
        ----------
        value : set or frozenset
            A value.

        Returns
        -------
        self : TraitSet
            The updated set.
        """
        old_set = self.copy()
        retval = super().__iand__(value)
        removed = old_set.difference(self)
        if removed:
            self.notify(removed, set())
        return retval

    def __ior__(self, value):
        """ Return self |= value.

        Parameters
        ----------
        value : set or frozenset
            A value.

        Returns
        -------
        self : TraitSet
            The updated set.
        """
        old_set = self.copy()

        # Validate each item in value, only if value is a set or frozenset.
        # We do not want to convert any other iterable type to a set
        # so that super().__ior__ raises the appropriate error message
        # for all other iterables.
        if isinstance(value, (set, frozenset)):
            value = {self.item_validator(item) for item in value}

        retval = super().__ior__(value)
        added = self.difference(old_set)
        if added:
            self.notify(set(), added)
        return retval

    def __isub__(self, value):
        """ Return self -= value.

        Parameters
        ----------
        value : set or frozenset
            A value.

        Returns
        -------
        self : TraitSet
            The updated set.
        """
        old_set = self.copy()
        retval = super().__isub__(value)
        removed = old_set.difference(self)
        if removed:
            self.notify(removed, set())
        return retval

    def __ixor__(self, value):
        """ Return self ^= value.

        Parameters
        ----------
        value : set or frozenset
            A value.

        Returns
        -------
        self : TraitSet
            The updated set.
        """
        removed = set()
        added = set()

        # Validate each item in value, only if value is a set or frozenset.
        # We do not want to convert any other iterable type to a set
        # so that super().__ixor__ raises the appropriate error message
        # for all other iterables.
        if isinstance(value, (set, frozenset)):
            values = set(value)
            # Items already present are the ones that will be removed;
            # only the remaining (incoming) items need validation.
            removed = self.intersection(values)
            raw_added = values.difference(removed)
            validated_added = {self.item_validator(item) for item in raw_added}
            added = validated_added.difference(self)
            value = added | removed

        retval = super().__ixor__(value)
        if removed or added:
            self.notify(removed, added)
        return retval

    def add(self, value):
        """ Add an element to a set.

        This has no effect if the element is already present.

        Parameters
        ----------
        value : any
            The value to add to the set.
        """
        value = self.item_validator(value)
        # Membership is tested on the validated value, before insertion.
        value_in_self = value in self
        super().add(value)
        if not value_in_self:
            self.notify(set(), {value})

    def clear(self):
        """ Remove all elements from this set. """
        removed = set(self)
        super().clear()
        if removed:
            self.notify(removed, set())

    def discard(self, value):
        """ Remove an element from the set if it is a member.

        If the element is not a member, do nothing.

        Parameters
        ----------
        value : any
            An item in the set
        """
        value_in_self = value in self
        super().discard(value)
        if value_in_self:
            self.notify({value}, set())

    def difference_update(self, *args):
        """ Remove all elements of another set from this set.

        Parameters
        ----------
        args : iterables
            The other iterables.
        """
        old_set = self.copy()
        super().difference_update(*args)
        removed = old_set.difference(self)
        if removed:
            self.notify(removed, set())

    def intersection_update(self, *args):
        """ Update the set with the intersection of itself and another set.

        Parameters
        ----------
        args : iterables
            The other iterables.
        """
        old_set = self.copy()
        super().intersection_update(*args)
        removed = old_set.difference(self)
        if removed:
            self.notify(removed, set())

    def pop(self):
        """ Remove and return an arbitrary set element.

        Raises KeyError if the set is empty.

        Returns
        -------
        item : any
            An element from the set.

        Raises
        ------
        KeyError
            If the set is empty.
        """
        removed = super().pop()
        self.notify({removed}, set())
        return removed

    def remove(self, value):
        """ Remove an element that is a member of the set.

        If the element is not a member, raise a KeyError.

        Parameters
        ----------
        value : any
            An element in the set

        Raises
        ------
        KeyError
            If the value is not found in the set.
        """
        # If the value is absent, KeyError propagates and nothing is notified.
        super().remove(value)
        self.notify({value}, set())

    def symmetric_difference_update(self, value):
        """ Update the set with the symmetric difference of itself and another.

        Parameters
        ----------
        value : iterable
        """
        values = set(value)
        # Items already present will be removed; the rest are validated
        # before being added.
        removed = self.intersection(values)
        raw_result = values.difference(removed)
        validated_result = {self.item_validator(item) for item in raw_result}
        added = validated_result.difference(self)
        super().symmetric_difference_update(removed | added)
        if removed or added:
            self.notify(removed, added)

    def update(self, *args):
        """ Update the set with the union of itself and others.

        Parameters
        ----------
        args : iterables
            The other iterables.
        """
        validated_values = {
            self.item_validator(item) for item in chain.from_iterable(args)
        }
        added = validated_values.difference(self)
        super().update(added)
        if added:
            self.notify(set(), added)

    # -- pickle and copy support ----------------------------------------------

    def __deepcopy__(self, memo):
        """ Perform a deepcopy operation.

        Notifiers are transient and should not be copied.
        """
        # The validator lives in ``item_validator``; that is the only
        # validator attribute this class defines.
        # notifiers are transient and should not be copied
        return TraitSet(
            [copy.deepcopy(item, memo) for item in self],
            item_validator=copy.deepcopy(self.item_validator, memo),
            notifiers=[],
        )

    def __getstate__(self):
        """ Get the state of the object for serialization.

        Notifiers are transient and should not be serialized.
        """
        result = self.__dict__.copy()
        # notifiers are transient and should not be serialized
        del result["notifiers"]
        return result

    def __setstate__(self, state):
        """ Restore the state of the object after serialization.

        Notifiers are transient and are restored to the empty list.
        """
        state['notifiers'] = []
        self.__dict__.update(state)

    # -- Implement IObservable ------------------------------------------------

    def _notifiers(self, force_create):
        """ Return a list of callables where each callable is a notifier.

        The list is expected to be mutated for contributing or removing
        notifiers from the object.

        Parameters
        ----------
        force_create: boolean
            Not used here.
        """
        return self.notifiers
class TraitSetObject(TraitSet):
    """ A specialization of TraitSet with a default validator and notifier
    for compatibility with Traits versions before 6.0.

    Parameters
    ----------
    trait : CTrait
        The trait that the set has been assigned to.
    object : HasTraits
        The object this set belongs to. Can also be None in cases where the
        set has been disconnected from its HasTraits parent.
    name : str
        The name of the trait on the object.
    value : iterable
        The initial value of the set.

    Attributes
    ----------
    trait : CTrait
        The trait that the set has been assigned to.
    object : callable
        A callable that when called with no arguments returns the HasTraits
        object that this set belongs to, or None if there is no such object.
    name : str
        The name of the trait on the object.
    value : iterable
        The initial value of the set.
    """

    def __init__(self, trait, object, name, value):
        self.trait = trait
        # Hold the owner weakly; with no owner, use a callable that
        # returns None so that ``self.object()`` is always valid.
        self.object = (lambda: None) if object is None else ref(object)
        self.name = name
        self.name_items = None
        if trait.has_items:
            self.name_items = name + "_items"

        super().__init__(
            value,
            item_validator=self._validator,
            notifiers=[self.notifier],
        )

    def _validator(self, value):
        """ Validates the value by calling the inner trait's validate method.

        Parameters
        ----------
        value : any
            The value to be validated.

        Returns
        -------
        value : any
            The validated value.

        Raises
        ------
        TraitError
            On validation failure for the inner trait.
        """
        # Either attribute may be missing or None (for example after
        # __setstate__ sets ``trait`` to None); skip validation then.
        object_ref = getattr(self, 'object', None)
        trait = getattr(self, 'trait', None)

        if object_ref is None or trait is None:
            return value

        object = object_ref()

        # validate the new value(s)
        validate = trait.item_trait.validate

        if validate is None:
            return value

        try:
            return validate(object, self.name, value)
        except TraitError as excp:
            excp.set_prefix("Each element of the")
            raise excp

    def notifier(self, trait_set, removed, added):
        """ Converts and consolidates the parameters to a TraitSetEvent and
        then fires the event.

        Parameters
        ----------
        trait_set : set
            The complete set
        removed : set
            Set of values that were removed.
        added : set
            Set of values that were added.
        """
        if self.name_items is None:
            return

        object = self.object()
        if object is None:
            return

        if getattr(object, self.name) is not self:
            # Workaround having this set inside another container which
            # also uses the name_items trait for notification.
            # Similar to enthought/traits#25
            return

        event = TraitSetEvent(removed=removed, added=added)
        items_event = self.trait.items_event()
        object.trait_items_event(self.name_items, event, items_event)

    # -- pickle and copy support ----------------------------------------------

    def __deepcopy__(self, memo):
        """ Perform a deepcopy operation.

        The copy keeps the trait and name but is created with no owning
        object. Notifiers are transient and should not be copied.
        """
        return TraitSetObject(
            self.trait,
            None,
            self.name,
            {copy.deepcopy(item, memo) for item in self},
        )

    def __getstate__(self):
        """ Get the state of the object for serialization.

        Notifiers, the owning-object reference and the trait are not
        serialized.
        """
        result = super().__getstate__()
        # Use pop() with a default rather than ``del`` so that an instance
        # lacking either attribute can still be serialized; _validator
        # likewise treats both attributes as optional.
        result.pop("object", None)
        result.pop("trait", None)
        return result

    def __setstate__(self, state):
        """ Restore the state of the object after serialization.

        The notifiers list is reset to contain only this instance's own
        ``notifier``; the restored set has no owning object and no trait.
        """
        state.setdefault("name", "")
        state["notifiers"] = [self.notifier]
        state["object"] = lambda: None
        state["trait"] = None
        self.__dict__.update(state)

    def __reduce_ex__(self, protocol=None):
        """ Overridden to make sure we call our custom __getstate__. """
        return (
            copyreg._reconstructor,
            (type(self), set, list(self)),
            self.__getstate__(),
        )