# Source code for envisage.provider_extension_registry
# (C) Copyright 2007-2024 Enthought, Inc., Austin, TX
# All rights reserved.
#
# This software is provided without warranty under the terms of the BSD
# license included in LICENSE.txt and may be redistributed only under
# the conditions described in the aforementioned license. The license
# is also available online at http://www.enthought.com/licenses/BSD.txt
#
# Thanks for using Enthought open source!
""" An extension registry implementation with multiple providers. """
# Standard library imports.
import logging
# Enthought library imports.
from traits.api import List, on_trait_change, provides
# Local imports.
from .extension_registry import ExtensionRegistry
from .i_extension_provider import IExtensionProvider
from .i_provider_extension_registry import IProviderExtensionRegistry
# Logging: module-level logger named after this module, used for the debug
# and warning messages emitted by the registry below.
logger = logging.getLogger(__name__)
@provides(IProviderExtensionRegistry)
class ProviderExtensionRegistry(ExtensionRegistry):
    """An extension registry implementation with multiple providers.

    Extensions are not stored directly; they are pulled lazily from the
    registered providers the first time an extension point is accessed and
    then cached in ``self._extensions`` as a list of lists, one inner list
    per provider, in the same order as ``self._providers``.

    NOTE(review): ``_extensions``, ``_extension_points``,
    ``_get_listener_refs`` and ``_call_listeners`` are not defined in this
    class; they are presumably inherited from ``ExtensionRegistry``.
    """

    #### Protected 'ProviderExtensionRegistry' interface ######################

    # The extension providers that populate the registry.
    _providers = List(IExtensionProvider)

    ###########################################################################
    # 'IExtensionRegistry' interface.
    ###########################################################################

    def set_extensions(self, extension_point_id, extensions):
        """Set the extensions to an extension point.

        Always raises a 'TypeError': in this registry extensions come from
        the providers and cannot be set directly.
        """
        raise TypeError("extension points cannot be set")

    ###########################################################################
    # 'ProviderExtensionRegistry' interface.
    ###########################################################################

    def add_provider(self, provider):
        """Add an extension provider.

        Listeners of every already-accessed extension point that the
        provider contributes to are called with the added extensions.
        """
        events = self._add_provider(provider)

        for extension_point_id, (refs, added, index) in events.items():
            self._call_listeners(refs, extension_point_id, added, [], index)

    def get_providers(self):
        """Return all of the providers in the registry.

        The result is a copy, so mutating it does not affect the registry.
        """
        return self._providers[:]

    def remove_provider(self, provider):
        """Remove an extension provider.

        Raise a 'ValueError' if the provider is not in the registry.

        Listeners of every already-accessed extension point that the
        provider contributed to are called with the removed extensions.
        """
        events = self._remove_provider(provider)

        for extension_point_id, (refs, removed, index) in events.items():
            self._call_listeners(refs, extension_point_id, [], removed, index)

    ###########################################################################
    # Protected 'ExtensionRegistry' interface.
    ###########################################################################

    def _get_extensions(self, extension_point_id):
        """Return the extensions for the given extension point.

        An unknown extension point yields an empty list (and a warning is
        logged). Otherwise the per-provider contributions are fetched from
        the cache, or gathered from the providers and cached on first
        access, and returned flattened into a single new list.
        """
        # If we don't know about the extension point then it sure ain't got
        # any extensions!
        if extension_point_id not in self._extension_points:
            logger.warning(
                "getting extensions of unknown extension point <%s>",
                extension_point_id,
            )
            extensions = []

        # Has this extension point already been accessed?
        elif extension_point_id in self._extensions:
            extensions = self._extensions[extension_point_id]

        # If not, then ask each provider for its contributions to the
        # extension point.
        else:
            extensions = self._initialize_extensions(extension_point_id)
            self._extensions[extension_point_id] = extensions

        # We store the extensions as a list of lists, with each inner list
        # containing the contributions from a single provider. Here we just
        # concatenate them into a single list.
        return [
            extension
            for extensions_of_single_provider in extensions
            for extension in extensions_of_single_provider
        ]

    ###########################################################################
    # Protected 'ProviderExtensionRegistry' interface.
    ###########################################################################

    def _add_provider(self, provider):
        """Add a new provider.

        Returns a dictionary mapping extension point Ids to the
        ``(listener refs, added extensions, index)`` tuples needed to notify
        listeners.
        """
        # Add the provider's extension points.
        self._add_provider_extension_points(provider)

        # Add the provider's extensions.
        events = self._add_provider_extensions(provider)

        # And finally, tag it into the list of providers. Its contributions
        # were appended to each cached list of lists above, so the provider
        # index and the contribution index stay aligned.
        self._providers.append(provider)

        return events

    def _add_provider_extensions(self, provider):
        """Add a provider's extensions to the registry.

        Only extension points that have already been accessed (i.e. that
        have an entry in the cache) are updated. Returns the dictionary of
        events to fire, keyed by extension point Id.
        """
        # Each provider can contribute to multiple extension points, so we
        # build up a dictionary of the 'ExtensionPointChanged' events that we
        # need to fire.
        events = {}

        # Does the provider contribute any extensions to an extension point
        # that has already been accessed?
        for extension_point_id, extensions in self._extensions.items():
            new = provider.get_extensions(extension_point_id)

            # We only need fire an event for this extension point if the
            # provider contributes any extensions.
            if len(new) > 0:
                # The new contributions go at the end of the flattened list.
                index = sum(map(len, extensions))
                refs = self._get_listener_refs(extension_point_id)
                events[extension_point_id] = (refs, new[:], index)

            # Always append (even if empty) so that every provider has a
            # slot in the list of lists.
            extensions.append(new)

        return events

    def _add_provider_extension_points(self, provider):
        """Add a provider's extension points to the registry."""
        for extension_point in provider.get_extension_points():
            self._extension_points[extension_point.id] = extension_point

    def _remove_provider(self, provider):
        """Remove a provider.

        Raise a 'ValueError' if the provider is not in the registry.

        Returns a dictionary mapping extension point Ids to the
        ``(listener refs, removed extensions, index)`` tuples needed to
        notify listeners.
        """
        # Remove the provider's extensions.
        events = self._remove_provider_extensions(provider)

        # Remove the provider's extension points.
        self._remove_provider_extension_points(provider, events)

        # And finally take it out of the list of providers.
        self._providers.remove(provider)

        return events

    def _remove_provider_extensions(self, provider):
        """Remove a provider's extensions from the registry.

        Raise a 'ValueError' if the provider is not in the registry.
        Returns the dictionary of events to fire, keyed by extension point
        Id.
        """
        # Each provider can contribute to multiple extension points, so we
        # build up a dictionary of the 'ExtensionPointChanged' events that we
        # need to fire.
        events = {}

        # Find the index of the provider in the provider list. Its
        # contributions are at the same index in the extensions list of lists.
        index = self._providers.index(provider)

        # Does the provider contribute any extensions to an extension point
        # that has already been accessed?
        for extension_point_id, extensions in self._extensions.items():
            old = extensions[index]

            # We only need fire an event for this extension point if the
            # provider contributed any extensions.
            if len(old) > 0:
                # Position of the provider's first contribution in the
                # flattened list.
                offset = sum(map(len, extensions[:index]))
                refs = self._get_listener_refs(extension_point_id)
                events[extension_point_id] = (refs, old[:], offset)

            del extensions[index]

        return events

    def _remove_provider_extension_points(self, provider, events):
        """Remove a provider's extension points from the registry.

        The 'events' argument is currently unused; it is kept so that the
        signature stays compatible with existing callers and overrides.
        """
        for extension_point in provider.get_extension_points():
            # Remove the extension point.
            del self._extension_points[extension_point.id]

    ###########################################################################
    # Private interface.
    ###########################################################################

    #### Trait change handlers ################################################

    @on_trait_change("_providers:extension_point_changed")
    def _providers_extension_point_changed(self, obj, trait_name, old, event):
        """Dynamic trait change handler.

        Called when a provider ('obj') reports that its contributions to an
        extension point have changed. Refreshes the cached contributions of
        that provider and forwards the change to the registry's listeners,
        with the event index translated into the flattened list.
        """
        logger.debug("provider <%s> extension point changed", obj)

        extension_point_id = event.extension_point_id

        # If the extension point has not yet been accessed then we don't fire a
        # changed event.
        #
        # This is because we only access extension points lazily and so we
        # can't tell what has actually changed because we have nothing to
        # compare it to!
        if extension_point_id not in self._extensions:
            return

        # This is a list of lists where each inner list contains the
        # contributions made to the extension point by a single provider.
        # (The guard above guarantees that the entry exists.)
        extensions = self._extensions[extension_point_id]

        # Find the index of the provider in the provider list. Its
        # contributions are at the same index in the extensions list of lists.
        provider_index = self._providers.index(obj)

        # Get the updated list from the provider.
        extensions[provider_index] = obj.get_extensions(extension_point_id)

        # Find where the provider's contributions are in the whole 'list'.
        offset = sum(map(len, extensions[:provider_index]))

        # Translate the event index from one that refers to the list of
        # contributions from the provider, to the list of contributions from
        # all providers.
        index = self._translate_index(event.index, offset)

        # Find out who is listening.
        refs = self._get_listener_refs(extension_point_id)

        # Let any listeners know which extensions have been added and removed.
        self._call_listeners(
            refs, extension_point_id, event.added, event.removed, index
        )

    #### Methods ##############################################################

    def _initialize_extensions(self, extension_point_id):
        """Initialize the extensions to an extension point.

        Returns a list of lists holding a copy of each provider's
        contributions, in provider order.
        """
        # We store the extensions as a list of lists, with each inner list
        # containing the contributions from a single provider.
        extensions = [
            provider.get_extensions(extension_point_id)[:]
            for provider in self._providers
        ]

        logger.debug("extensions to <%s> <%s>", extension_point_id, extensions)

        return extensions

    def _translate_index(self, index, offset):
        """Translate an event index by the given offset.

        'index' is either an integer or a slice. For a slice, both the start
        and the stop are shifted and the step is kept.

        NOTE(review): a slice is assumed to have concrete (non-None) start
        and stop values; a 'None' bound would raise a 'TypeError' here.
        """
        if isinstance(index, slice):
            index = slice(
                index.start + offset, index.stop + offset, index.step
            )

        else:
            index = index + offset

        return index