Source code for smqtk_classifier.interfaces.classification_element

import abc
from typing import Any, Dict, Hashable, Mapping, Optional, Tuple, Type, TypeVar

from smqtk_core import Plugfigurable
from smqtk_classifier.exceptions import NoClassificationError


# Type variable bound to ``ClassificationElement`` so that class-method
# constructors can be annotated as returning the invoking (sub)class type.
C = TypeVar("C", bound="ClassificationElement")
# Labels are unconstrained at the type level; confidences are floats.
CLASSIFICATION_LABEL_T = Any
CLASSIFICATION_VALUE_T = float
# Read-only mapping flavor, intended for input parameters.
CLASSIFICATION_MAP_T = Mapping[CLASSIFICATION_LABEL_T, CLASSIFICATION_VALUE_T]
# Concrete dictionary flavor, intended for return values.
CLASSIFICATION_DICT_T = Dict[CLASSIFICATION_LABEL_T, CLASSIFICATION_VALUE_T]
# Negative infinity: compares lower than any finite confidence value.
NEG_INF = -float("inf")


class ClassificationElement(Plugfigurable):
    """
    Classification result encapsulation.

    Contains a mapping of arbitrary (but hashable) label values to confidence
    values (floating point in ``[0,1]`` range). If a classifier does not
    produce continuous confidence values, it may instead assign a value of
    ``1.0`` to a single label, and ``0.0`` to the rest.

    UUIDs must maintain uniqueness when transformed into a string.

    Element equality is based on classification labels and values, not the
    type or UUID.

    Since this base class defines ``__getstate__`` and ``__setstate__``
    methods, implementing classes must also extend these methods to support
    serialization. These methods have been marked as abstract to facilitate
    this requirement.

    :param type_name: Name of the type of classifier this classification was
        generated by.
    :param uuid: Unique ID reference of the classification
    """
    __slots__ = ('type_name', 'uuid')

    def __init__(self, type_name: str, uuid: Hashable):
        super(ClassificationElement, self).__init__()
        self.type_name = type_name
        self.uuid = uuid

    def __hash__(self) -> int:
        # NOTE: the hash is derived from ``(type_name, uuid)`` whereas
        # ``__eq__`` compares classification maps, so two elements that
        # compare equal may still hash differently.
        return hash((self.type_name, self.uuid))

    def __eq__(self, other: Any) -> bool:
        """
        Compare the classification maps of two elements.

        An element without a classification is treated as having a ``None``
        map, so two unset elements compare equal to each other.

        :param other: Object to compare against.
        :return: True if ``other`` is a ``ClassificationElement`` whose
            classification map (or lack thereof) equals ours.
        """
        if not isinstance(other, ClassificationElement):
            return False
        a: Optional[CLASSIFICATION_DICT_T]
        b: Optional[CLASSIFICATION_DICT_T]
        try:
            a = self.get_classification()
        except NoClassificationError:
            a = None
        try:
            b = other.get_classification()
        except NoClassificationError:
            b = None
        return a == b

    def __ne__(self, other: Any) -> bool:
        return not (self == other)

    def __repr__(self) -> str:
        return "%s{type_name: %s, uuid: %s}" \
            % (self.__class__.__name__, self.type_name, self.uuid)

    def __getitem__(self, label: Hashable) -> float:
        """
        Get the confidence value for a specific label

        **Note on type annotation:** We are using `Hashable` here instead of
        `CLASSIFICATION_LABEL_T` (`Any`) due to the detachment from the
        dictionary type. When attached, `Any` is effectively `Hashable` in
        context of dictionary keys. Here however we want to try to make sure
        the given value is applicable with a dictionary.

        :param label: Classification label to get the confidence value for.

        :raises KeyError: The given label is not present in this
            classification.
        :raises NoClassificationError: No classification labels/confidences
            yet set.

        :return: Confidence value for the given label.
        """
        return self.get_classification()[label]

    def __nonzero__(self) -> bool:
        """
        A ClassificationElement is considered non-zero if
        ``has_classifications`` returns True. See method documentation for
        details.

        :return: True if this instance is non-zero (see above), false
            otherwise.
        :rtype: bool
        """
        return self.has_classifications()

    # Python 3 truth-testing hook; shares the implementation above.
    __bool__ = __nonzero__

    @classmethod
    def get_default_config(cls) -> Dict[str, Any]:
        """
        Generate and return a default configuration dictionary for this
        class. This will be primarily used for generating what the
        configuration dictionary would look like for this class without
        instantiating it.

        By default, we observe what this class's constructor takes as
        arguments, turning those argument names into configuration dictionary
        keys. If any of those arguments have defaults, we will add those
        values into the configuration dictionary appropriately. The
        dictionary returned should only contain JSON compliant value types.

        It is not guaranteed that the configuration dictionary returned from
        this method is valid for construction of an instance of this class.

        :return: Default configuration dictionary for the class.
        :rtype: dict
        """
        # Similar to the parent implementation, except we remove the
        # ``type_name`` and ``uuid`` configuration parameters as they are to
        # be specified at runtime.
        dc = super(ClassificationElement, cls).get_default_config()
        # These parameters must be specified at construction time.
        del dc['type_name'], dc['uuid']
        return dc

    # noinspection PyMethodOverriding
    @classmethod
    def from_config(  # type: ignore
        cls: Type[C],
        config_dict: Dict,
        type_name: str,
        uuid: Hashable,
        merge_default: bool = True
    ) -> C:
        """
        Instantiate a new instance of this class given the configuration
        JSON-compliant dictionary encapsulating initialization arguments.

        This method should not be called via super unless an instance of the
        class is desired.

        :param config_dict: JSON compliant dictionary encapsulating
            a configuration.
        :param type_name: Name of the type of classifier this classification
            was generated by.
        :param uuid: Unique ID reference of the classification
        :param merge_default: Merge the given configuration on top of the
            default provided by ``get_default_config``.

        :return: Constructed instance from the provided config.
        """
        # Shallow-copy the config dict so the caller's input is not modified.
        config_dict = dict(config_dict)
        config_dict['type_name'] = type_name
        config_dict['uuid'] = uuid
        return super(ClassificationElement, cls).from_config(
            config_dict, merge_default=merge_default
        )

    def max_label(self) -> Hashable:
        """
        Get the label with the highest confidence.

        When several labels share the highest confidence, the first one
        encountered while iterating the classification map is returned. A
        label value of ``None`` is a valid result.

        **Note on type annotation:** We are using `Hashable` here instead of
        `CLASSIFICATION_LABEL_T` (`Any`) due to the detachment from the
        dictionary type. When attached, `Any` is effectively `Hashable` in
        context of dictionary keys. Here however we want to try to make sure
        the given value is applicable with a dictionary.

        :raises NoClassificationError: No classification set, or no
            confidence value compares greater than negative infinity.

        :return: The label with the highest confidence.
        """
        best_label: Optional[Hashable] = None
        best_conf = NEG_INF
        # Track "found" explicitly rather than testing the label against
        # ``None``: ``None`` is hashable and may itself be a real label.
        found = False
        for label, conf in self.get_classification().items():
            if conf > best_conf:
                best_label, best_conf = label, conf
                found = True
        if not found:
            raise NoClassificationError("No classifications set to pick the "
                                        "max of.")
        return best_label

    #
    # Abstract methods
    #

    @abc.abstractmethod
    def __getstate__(self) -> Any:
        # Base state; subclasses are expected to extend this.
        return self.type_name, self.uuid

    @abc.abstractmethod
    def __setstate__(self, state: Any) -> None:
        # Inverse of the base ``__getstate__``: a (type_name, uuid) pair.
        self.type_name, self.uuid = state

    @abc.abstractmethod
    def has_classifications(self) -> bool:
        """
        :return: If this element has classification information set.
        """

    @abc.abstractmethod
    def get_classification(self) -> CLASSIFICATION_DICT_T:
        """
        Get classification result map, returning a label-to-confidence dict.

        We do not place any guarantees on label value types as they may be
        represented in various forms (integers, strings, etc.).

        Confidence values are in the [0,1] range.

        :raises NoClassificationError: No classification labels/confidences
            yet set.

        :return: Label-to-confidence dictionary.
        """

    @abc.abstractmethod
    def set_classification(
        self,
        m: Optional[CLASSIFICATION_MAP_T] = None,
        **kwds: CLASSIFICATION_VALUE_T
    ) -> CLASSIFICATION_DICT_T:
        """
        Set the whole classification map for this element. This will strictly
        overwrite the entire label-confidence mapping (vs. updating it)

        Label/confidence values may either be provided via keyword arguments
        or by providing a dictionary mapping labels to confidence values.
        Non-string labels must be provided via an input dictionary (``m``
        parameter). Keyword arguments take precedence over entries of ``m``
        with the same key.

        NOTE TO IMPLEMENTORS: This abstract method aggregates the input into
        a single dictionary, checks that there is anything in it and returns
        it. Thus, a ``super`` call should be made, which will return a
        dictionary.

        :param m: New labels-to-confidence mapping to set.
        :param kwds: Additional label/confidence pairs given by keyword.

        :raises ValueError: The given label-confidence map was empty.

        :returns: Input/combined Hashable-to-float mapping as a new
            dictionary.
        """
        # TODO: Use template method pattern, create ``_set_classification``
        #       abstract method (removing abstract from this).
        # Shallow-copy the mapping so we may update without impacting input.
        combined: CLASSIFICATION_DICT_T = dict(m or {})
        combined.update(kwds)
        if not combined:
            raise ValueError("No classification labels/values given.")
        return combined