# (C) Copyright 2005-2022 Enthought, Inc., Austin, TX
# All rights reserved.
#
# This software is provided without warranty under the terms of the BSD
# license included in LICENSE.txt and may be redistributed only under
# the conditions described in the aforementioned license. The license
# is also available online at http://www.enthought.com/licenses/BSD.txt
#
# Thanks for using Enthought open source!
import copy
import copyreg
from itertools import chain
from weakref import ref
from traits.observation.i_observable import IObservable
from traits.trait_base import _validate_everything
from traits.trait_errors import TraitError
class TraitSetEvent(object):
    """ Record of an in-place change made to a traits set.

    Parameters
    ----------
    removed : set, optional
        Values that were taken out of the set. Defaults to a new empty set.
    added : set, optional
        Values that were put into the set. Defaults to a new empty set.

    Attributes
    ----------
    removed : set
        Values that were taken out of the set.
    added : set
        Values that were put into the set.
    """

    def __init__(self, *, removed=None, added=None):
        # A fresh empty set per instance; never share a mutable default.
        self.removed = set() if removed is None else removed
        self.added = set() if added is None else added

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f"{cls_name}(removed={self.removed!r}, added={self.added!r})"
@IObservable.register
class TraitSet(set):
    """ A subclass of set that validates and notifies listeners of changes.

    Parameters
    ----------
    value : iterable, optional
        Iterable providing the items for the set.
    item_validator : callable, optional
        Called to validate and/or transform items added to the set. The
        callable should accept a single item and return the transformed
        item, raising TraitError for invalid items. If not given, no
        item validation is performed.
    notifiers : list of callable, optional
        A list of callables with the signature::

            notifier(trait_set, removed, added)

        where 'added' is a set containing new values that have been added
        and 'removed' is a set containing old values that have been removed.
        If this argument is not given, the list of notifiers is initially
        empty.

    Attributes
    ----------
    item_validator : callable
        Called to validate and/or transform items added to the set. The
        callable should accept a single item and return the transformed
        item, raising TraitError for invalid items.
    notifiers : list of callable
        A list of callables with the signature::

            notifier(trait_set, removed, added)

        where 'added' is a set containing new values that have been added
        and 'removed' is a set containing old values that have been removed.
    """

    def __new__(cls, *args, **kwargs):
        # Defaults are installed here (not only in __init__) so the
        # attributes exist even when __init__ is bypassed.
        self = super().__new__(cls)
        self.item_validator = _validate_everything
        self.notifiers = []
        return self

    def __init__(self, value=(), *, item_validator=None, notifiers=None):
        if item_validator is not None:
            self.item_validator = item_validator
        # Every initial item passes through the validator before storage.
        super().__init__(self.item_validator(item) for item in value)
        if notifiers is not None:
            self.notifiers = notifiers

    def notify(self, removed, added):
        """ Call all notifiers.

        This simply calls all notifiers provided by the class, if any.
        The notifiers are expected to have the signature::

            notifier(trait_set, removed, added)

        Any return values are ignored. Any exceptions raised are not
        handled. Notifiers are therefore expected not to raise any
        exceptions under normal use.

        Parameters
        ----------
        removed : set
            The items that have been removed.
        added : set
            The new items that have been added to the set.
        """
        for notifier in self.notifiers:
            notifier(self, removed, added)

    # -- set interface -------------------------------------------------------

    def __iand__(self, value):
        """ Return self &= value.

        Parameters
        ----------
        value : set or frozenset
            A value.

        Returns
        -------
        self : TraitSet
            The updated set.
        """
        old_set = self.copy()
        retval = super().__iand__(value)
        removed = old_set.difference(self)
        if removed:
            self.notify(removed, set())
        return retval

    def __ior__(self, value):
        """ Return self |= value.

        Parameters
        ----------
        value : set or frozenset
            A value.

        Returns
        -------
        self : TraitSet
            The updated set.
        """
        old_set = self.copy()

        # Validate each item in value, only if value is a set or frozenset.
        # We do not want to convert any other iterable type to a set
        # so that super().__ior__ raises the appropriate error message
        # for all other iterables.
        if isinstance(value, (set, frozenset)):
            value = {self.item_validator(item) for item in value}

        retval = super().__ior__(value)
        added = self.difference(old_set)
        if added:
            self.notify(set(), added)
        return retval

    def __isub__(self, value):
        """ Return self -= value.

        Parameters
        ----------
        value : set or frozenset
            A value.

        Returns
        -------
        self : TraitSet
            The updated set.
        """
        old_set = self.copy()
        retval = super().__isub__(value)
        removed = old_set.difference(self)
        if removed:
            self.notify(removed, set())
        return retval

    def __ixor__(self, value):
        """ Return self ^= value.

        Parameters
        ----------
        value : set or frozenset
            A value.

        Returns
        -------
        self : TraitSet
            The updated set.
        """
        removed = set()
        added = set()

        # Validate each item in value, only if value is a set or frozenset.
        # We do not want to convert any other iterable type to a set
        # so that super().__ixor__ raises the appropriate error message
        # for all other iterables.
        if isinstance(value, (set, frozenset)):
            values = set(value)
            removed = self.intersection(values)
            raw_added = values.difference(removed)
            validated_added = {
                self.item_validator(item) for item in raw_added
            }
            # A validated item may already be present; toggling it would
            # wrongly remove it, so only genuinely new items are added.
            added = validated_added.difference(self)
            value = added | removed

        retval = super().__ixor__(value)

        if removed or added:
            self.notify(removed, added)

        return retval

    def add(self, value):
        """ Add an element to a set.

        This has no effect if the element is already present.

        Parameters
        ----------
        value : any
            The value to add to the set.
        """
        value = self.item_validator(value)
        value_in_self = value in self
        super().add(value)
        if not value_in_self:
            self.notify(set(), {value})

    def clear(self):
        """ Remove all elements from this set. """
        removed = set(self)
        super().clear()
        if removed:
            self.notify(removed, set())

    def discard(self, value):
        """ Remove an element from the set if it is a member.

        If the element is not a member, do nothing.

        Parameters
        ----------
        value : any
            An item in the set
        """
        value_in_self = value in self
        super().discard(value)
        if value_in_self:
            self.notify({value}, set())

    def difference_update(self, *args):
        """ Remove all elements of another set from this set.

        Parameters
        ----------
        args : iterables
            The other iterables.
        """
        old_set = self.copy()
        super().difference_update(*args)
        removed = old_set.difference(self)
        if removed:
            self.notify(removed, set())

    def intersection_update(self, *args):
        """ Update the set with the intersection of itself and another set.

        Parameters
        ----------
        args : iterables
            The other iterables.
        """
        old_set = self.copy()
        super().intersection_update(*args)
        removed = old_set.difference(self)
        if removed:
            self.notify(removed, set())

    def pop(self):
        """ Remove and return an arbitrary set element.

        Returns
        -------
        item : any
            An element from the set.

        Raises
        ------
        KeyError
            If the set is empty.
        """
        removed = super().pop()
        self.notify({removed}, set())
        return removed

    def remove(self, value):
        """ Remove an element that is a member of the set.

        If the element is not a member, raise a KeyError.

        Parameters
        ----------
        value : any
            An element in the set

        Raises
        ------
        KeyError
            If the value is not found in the set.
        """
        super().remove(value)
        self.notify({value}, set())

    def symmetric_difference_update(self, value):
        """ Update the set with the symmetric difference of itself and another.

        Parameters
        ----------
        value : iterable
            The other iterable; only items not already in the set are
            validated before being added.
        """
        values = set(value)
        removed = self.intersection(values)
        raw_result = values.difference(removed)
        validated_result = {self.item_validator(item) for item in raw_result}
        # Skip validated items that already exist so they are not toggled
        # back out of the set.
        added = validated_result.difference(self)
        super().symmetric_difference_update(removed | added)
        if removed or added:
            self.notify(removed, added)

    def update(self, *args):
        """ Update the set with the union of itself and others.

        Parameters
        ----------
        args : iterables
            The other iterables.
        """
        validated_values = {
            self.item_validator(item) for item in chain.from_iterable(args)
        }
        added = validated_values.difference(self)
        super().update(added)
        if added:
            self.notify(set(), added)

    # -- pickle and copy support ----------------------------------------------

    def __deepcopy__(self, memo):
        """ Perform a deepcopy operation.

        Notifiers are transient and should not be copied. The elements and
        the item validator are deep-copied into a new TraitSet.
        """
        # notifiers are transient and should not be copied
        result = TraitSet(
            [copy.deepcopy(x, memo) for x in self],
            # The validator lives in ``item_validator``; there is no
            # ``validator`` attribute on this class.
            item_validator=copy.deepcopy(self.item_validator, memo),
            notifiers=[],
        )
        return result

    def __getstate__(self):
        """ Get the state of the object for serialization.

        Notifiers are transient and should not be serialized.
        """
        result = self.__dict__.copy()
        # notifiers are transient and should not be serialized
        del result["notifiers"]
        return result

    def __setstate__(self, state):
        """ Restore the state of the object after serialization.

        Notifiers are transient and are restored to the empty list.
        """
        state["notifiers"] = []
        self.__dict__.update(state)

    # -- Implement IObservable ------------------------------------------------

    def _notifiers(self, force_create):
        """ Return a list of callables where each callable is a notifier.

        The list is expected to be mutated for contributing or removing
        notifiers from the object.

        Parameters
        ----------
        force_create: boolean
            Not used here.
        """
        return self.notifiers
class TraitSetObject(TraitSet):
    """ A TraitSet pre-wired with a trait-based item validator and an
    items-event notifier, kept for compatibility with Traits versions
    before 6.0.

    Parameters
    ----------
    trait : CTrait
        The trait that the set has been assigned to.
    object : HasTraits
        The object this set belongs to. May be None when the set is not
        attached to any HasTraits object.
    name : str
        The name of the trait on the object.
    value : iterable
        The initial contents of the set.

    Attributes
    ----------
    trait : CTrait
        The trait that the set has been assigned to (None after unpickling).
    object : callable
        Zero-argument callable returning the owning HasTraits object, or
        None if there is no such object.
    name : str
        The name of the trait on the object.
    name_items : str or None
        ``name + "_items"`` when the trait has items, otherwise None; used
        as the name of the items event that is fired on changes.
    """

    def __init__(self, trait, object, name, value):
        self.trait = trait
        # Hold the owner weakly; a no-arg lambda stands in when detached.
        if object is None:
            self.object = lambda: None
        else:
            self.object = ref(object)
        self.name = name
        self.name_items = None
        if trait.has_items:
            self.name_items = name + "_items"

        super().__init__(
            value,
            item_validator=self._validator,
            notifiers=[self.notifier],
        )

    def _validator(self, value):
        """ Validate a single element using the inner (item) trait.

        Parameters
        ----------
        value : any
            The element to validate.

        Returns
        -------
        value : any
            The validated (possibly transformed) element. The input is
            returned unchanged when ``object`` or ``trait`` is missing or
            None (``trait`` is None after unpickling), or when the item
            trait's ``validate`` is None.

        Raises
        ------
        TraitError
            If the inner trait rejects the element; the error message is
            prefixed with "Each element of the".
        """
        owner_ref = getattr(self, 'object', None)
        ctrait = getattr(self, 'trait', None)
        if owner_ref is None or ctrait is None:
            return value

        owner = owner_ref()

        # NOTE(review): a bound ``validate`` method is presumably never
        # None here; confirm whether the handler's validate was intended.
        item_validate = ctrait.item_trait.validate
        if item_validate is None:
            return value

        try:
            return item_validate(owner, self.name, value)
        except TraitError as exc:
            exc.set_prefix("Each element of the")
            raise exc

    def notifier(self, trait_set, removed, added):
        """ Wrap a change in a TraitSetEvent and fire the items event.

        Nothing is fired when the trait has no items, when the owning
        object no longer exists, or when this set is not the current
        value of the owner's trait.

        Parameters
        ----------
        trait_set : set
            The complete set
        removed : set
            Set of values that were removed.
        added : set
            Set of values that were added.
        """
        if self.name_items is None:
            return

        owner = self.object()
        # Workaround having this set inside another container which
        # also uses the name_items trait for notification.
        # Similar to enthought/traits#25
        if owner is None or getattr(owner, self.name) is not self:
            return

        owner.trait_items_event(
            self.name_items,
            TraitSetEvent(removed=removed, added=added),
            self.trait.items_event(),
        )

    # -- pickle and copy support ----------------------------------------------

    def __deepcopy__(self, memo):
        """ Return a detached deep copy of this set.

        Notifiers are transient and are not copied; the copy has no owner.
        """
        copied_items = {copy.deepcopy(item, memo) for item in self}
        return TraitSetObject(self.trait, None, self.name, copied_items)

    def __getstate__(self):
        """ Return the picklable state, without transient references.

        The owner reference and the trait are dropped along with the
        notifiers removed by the base class.
        """
        state = super().__getstate__()
        for transient_key in ("object", "trait"):
            del state[transient_key]
        return state

    def __setstate__(self, state):
        """ Restore state after unpickling.

        The set comes back detached: no owner, no trait, and only its own
        notifier registered.
        """
        state.setdefault("name", "")
        state.update(
            notifiers=[self.notifier],
            object=lambda: None,
            trait=None,
        )
        self.__dict__.update(state)

    def __reduce_ex__(self, protocol=None):
        """ Overridden to make sure we call our custom __getstate__.
        """
        reconstruct_args = (type(self), set, list(self))
        return copyreg._reconstructor, reconstruct_args, self.__getstate__()