Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,13 @@ def _get_typestr(leaf_type: Any) -> str:

# register standard v1 leaf handlers to the v0 type handler registry.
handlers = []
for leaf_type, _, leaf_handler_type in leaf_handler_registry.get_all():
# We must reverse the order of the leaf handlers to ensure that the last
# registered handler is the first one used as V1 registry is ordered by
# priority of generic to specific, while V0 type handler registry is ordered
# by the reverse.
for leaf_type, _, leaf_handler_type in reversed(
leaf_handler_registry.get_all()
):
try:
leaf_handler = leaf_handler_type(context=context) # pytype: disable=wrong-keyword-args
except TypeError as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

"""Leaf Handler Registry."""

from typing import Any, Dict, Sequence, Tuple, Type
from collections.abc import Sequence
import dataclasses
from typing import Any

from absl import logging
import jax
Expand Down Expand Up @@ -59,6 +61,47 @@
}


@dataclasses.dataclass
class _Registration:
  """One row of the leaf handler registry.

  Instances are mutable: the two score fields are rewritten whenever the
  registry recomputes specificity.

  Attributes:
    leaf_type: Concrete PyTree leaf type this row was registered for.
    abstract_type: Abstract representation paired with `leaf_type`.
    handler_type: LeafHandler class serving the two types above.
    secondary_typestrs: Alternate identifiers for the handler, or None when
      none were given.
    leaf_specificity_score: Rank of `leaf_type` among registered leaf types. A
      larger value means the type is a subclass of more of the other
      registered leaf types, i.e. it is more specific. Used to decide which
      handler a save/load lookup resolves to.
    abstract_specificity_score: Rank of `abstract_type` among registered
      abstract types. A larger value means the type is a subclass or
      sub-protocol of more of the other registered abstract types. Used to
      decide which handler a save/load lookup resolves to.
  """

  # Field order is part of the constructor signature; keep it stable.
  leaf_type: type[Any]
  abstract_type: type[Any]
  handler_type: type[types.LeafHandler[Any, Any]]
  secondary_typestrs: Sequence[str] | None
  leaf_specificity_score: int
  abstract_specificity_score: int


def _is_abstract_subprotocol(
    type_a: type[Any], type_b: type[Any]
) -> bool:
  """Reports whether `type_a` is a subclass or sub-protocol of `type_b`.

  Args:
    type_a: Candidate subtype.
    type_b: Candidate supertype; may be a Protocol class.

  Returns:
    The result of the protocol check when `type_b` is a Protocol, otherwise
    the result of `issubclass`. False if any of those checks raises TypeError.
  """
  try:
    if not typing_extensions.is_protocol(type_b):  # pytype: disable=not-supported-yet
      # Ordinary class: nominal subclassing decides.
      return issubclass(type_a, type_b)
    # Protocol: delegate to the structural check.
    return protocol_utils.is_subclass_protocol(cls=type_a, protocol=type_b)
  except TypeError:
    return False


class BaseLeafHandlerRegistry:
"""Base Leaf Handler Registry implements the LeafHandlerRegistry Protocol.

Expand Down Expand Up @@ -87,71 +130,57 @@ class CustomArray(np.ndarray): pass
"""

def __init__(self):
self._leaf_type_registry: Dict[
Type[Any], Type[types.LeafHandler[Any, Any]]
] = {}
self._abstract_type_registry: Dict[
Type[Any], Type[types.LeafHandler[Any, Any]]
] = {}

# for easy look up for replacement
self._handler_to_types: Dict[
Type[types.LeafHandler[Any, Any]], Tuple[Type[Any], Type[Any]]
] = {}
self._secondary_typestrs: Dict[
Type[types.LeafHandler[Any, Any]], Sequence[str]
] = {}
# Sorted [Generic -> Specific] primarily by leaf_specificity_score.
self._entries: list[_Registration] = []

def _try_get(
self, leaf_type: Type[types.Leaf]
) -> Type[types.LeafHandler[types.Leaf, Any]] | None:
"""Returns the handler registered for a given type, if available."""
for registered_ty, handler_type in self._leaf_type_registry.items():
if issubclass(leaf_type, registered_ty):
return handler_type

# no handler found
self, leaf_type: type[types.Leaf]
) -> type[types.LeafHandler[types.Leaf, Any]] | None:
"""Returns the most specific handler for a given type, if available."""
# self._entries is sorted Generic -> Specific by leaf_specificity_score.
# Iterating reversed checks the most specific handlers first.
for entry in reversed(self._entries):
try:
if issubclass(leaf_type, entry.leaf_type):
return entry.handler_type
except TypeError:
pass
return None

def get(
self, leaf_type: Type[types.Leaf]
) -> Type[types.LeafHandler[types.Leaf, Any]]:
self, leaf_type: type[types.Leaf]
) -> type[types.LeafHandler[types.Leaf, Any]]:
if (handler_type := self._try_get(leaf_type)) is None:
raise ValueError(
f'Unknown Leaf type: "{leaf_type}". Must register it with'
f'Unknown Leaf type: "{leaf_type!r}". Must register it with'
' LeafHandlerRegistry.'
)

return handler_type

def _try_get_abstract(
self,
abstract_type: Type[types.AbstractLeaf],
) -> Type[types.LeafHandler[Any, types.AbstractLeaf]] | None:
"""Returns the handler registered for a given abstract type, if available."""
for (
registered_abstract_ty,
handler_type,
) in self._abstract_type_registry.items():
if typing_extensions.is_protocol(registered_abstract_ty): # pytype: disable=not-supported-yet
if protocol_utils.is_subclass_protocol(
cls=abstract_type, protocol=registered_abstract_ty
):
return handler_type
elif issubclass(abstract_type, registered_abstract_ty):
return handler_type

# no handler found
abstract_type: type[types.AbstractLeaf],
) -> type[types.LeafHandler[Any, types.AbstractLeaf]] | None:
"""Returns the most specific handler for a given abstract type."""
# Sort ascending by abstract_specificity_score (lowest to highest).
sorted_entries = sorted(
self._entries,
key=lambda e: e.abstract_specificity_score
)
# Iterating reversed checks the most specific handlers first.
for entry in reversed(sorted_entries):
if _is_abstract_subprotocol(abstract_type, entry.abstract_type):
return entry.handler_type
return None

def get_abstract(
self,
abstract_type: Type[types.AbstractLeaf],
) -> Type[types.LeafHandler[Any, types.AbstractLeaf]]:
abstract_type: type[types.AbstractLeaf],
) -> type[types.LeafHandler[Any, types.AbstractLeaf]]:
if (handler_type := self._try_get_abstract(abstract_type)) is None:
raise ValueError(
f'Unknown AbstractLeaf type: "{abstract_type}". Must register it with'
' LeafHandlerRegistry.'
f'Unknown AbstractLeaf type: "{abstract_type!r}". Must register it'
' with LeafHandlerRegistry.'
)

return handler_type
Expand All @@ -167,24 +196,32 @@ def get_all(
"""
return [
(
leaf_type,
abstract_type,
handler_type,
)
for (leaf_type, handler_type), abstract_type in zip(
self._leaf_type_registry.items(), self._abstract_type_registry
entry.leaf_type,
entry.abstract_type,
entry.handler_type,
)
for entry in self._entries
]

def add(
self,
leaf_type: Type[types.Leaf],
abstract_type: Type[types.AbstractLeaf],
handler_type: Type[types.LeafHandler[types.Leaf, types.AbstractLeaf]],
leaf_type: type[types.Leaf],
abstract_type: type[types.AbstractLeaf],
handler_type: type[types.LeafHandler[types.Leaf, types.AbstractLeaf]],
override: bool = False,
secondary_typestrs: Sequence[str] | None = None,
):
"""Adds a handler_type for a given leaf_type and abstract_type pair.
"""Registers a `handler_type` for a `leaf_type` and `abstract_type` pair.

The registry automatically maintains a [Generic -> Specific] hierarchy for
both leaf and abstract types using dynamic topological priorities to ensure
correct resolution. We maintain and recalculate these specificity scores to
ensure that the most specific handler is chosen during resolution.

A conflict occurs if the exact `leaf_type` is already registered, or if the
`abstract_type` is already mapped to a different handler. Set
`override=True` to automatically remove conflicting entries and force the
new registration.

Args:
leaf_type: The concrete PyTree leaf type to register.
Expand All @@ -196,56 +233,110 @@ def add(
secondary identifiers for the handler.

Raises:
ValueError: If the `leaf_type` or `abstract_type` is already registered
and `override` is False. Also raised if the `abstract_type` is already
registered with a fundamentally different handler type.
ValueError: If a duplicate `leaf_type` or conflicting `abstract_type`
mapping exists and `override` is False.
"""
current_handler_type = self._try_get(leaf_type)
current_abstract_handle_type = self._try_get_abstract(abstract_type)

if not override and (current_handler_type or current_abstract_handle_type):
raise ValueError(
f'Leaf_type[{leaf_type}] or abstract_type[{abstract_type}] has'
f' already registered, current_handler: {current_handler_type}, '
f'current_abstract_handle_type: {current_abstract_handle_type}'
)

logging.vlog(
1,
'add: leaf_type[%s], abstract_type[%s], handler_type[%s],'
' current_handler[%s], current_abstract_handle_type[%s]',
# Check for exact duplicate registration
for e in self._entries:
if (
e.leaf_type == leaf_type
and e.abstract_type == abstract_type
and e.handler_type == handler_type
):
logging.info(
'Registration already exists for leaf_type[%s], '
'abstract_type[%s], handler_type[%s]. Skipping.',
leaf_type,
abstract_type,
handler_type,
)
return

if override:
# Filter out conflicting entries if override is True.
new_entries = []
for e in self._entries:
is_conflict = (e.leaf_type == leaf_type) or (
e.abstract_type == abstract_type and e.handler_type != handler_type
)
if is_conflict:
logging.warning(
'clearing conflicting entry: leaf_type[%s], abstract_type[%s]'
' handler_type[%s] during override.',
e.leaf_type,
e.abstract_type,
e.handler_type,
)
else:
new_entries.append(e)
self._entries = new_entries
else:
for e in self._entries:
if e.leaf_type == leaf_type:
raise ValueError(
f'leaf_type [{leaf_type}] is already handled by '
f'{e.handler_type}. Use override=True to replace its entry. '
f'Registry: {self._entries}'
)
if e.abstract_type == abstract_type and e.handler_type != handler_type:
raise ValueError(
f'abstract_type[{abstract_type}] is already handled by '
f'{e.handler_type}. Use override=True to replace all '
f'conflicting entries. Registry: {self._entries}'
)

# Append the new entry with default priorities
new_reg = _Registration(
leaf_type,
abstract_type,
handler_type,
current_handler_type,
current_abstract_handle_type,
secondary_typestrs,
leaf_specificity_score=0,
abstract_specificity_score=0,
)
self._entries.append(new_reg)
# Recalculate specificity scores for all entries since new entry was added
# and may change the specificity scores of existing entries.
self._recalculate_specificity_scores()

# Sort the single source of truth [Generic -> Specific] based on leaf type
# primarily, and abstract type secondarily.
self._entries.sort(
key=lambda x: (
x.leaf_specificity_score,
x.abstract_specificity_score,
x.handler_type.__name__,
)
)

if current_handler_type and (
current_abstract_handle_type
and current_handler_type != current_abstract_handle_type
):
raise ValueError(
f'Abstract_type[{abstract_type}] has already registered with a'
' different type.'
)
elif current_handler_type and not current_abstract_handle_type:
# need to remove the previous abstract type
_, old_abstract_ty = self._handler_to_types.pop(current_handler_type)
self._abstract_type_registry.pop(old_abstract_ty)

# new type and abstract type pair
self._leaf_type_registry[leaf_type] = handler_type
self._abstract_type_registry[abstract_type] = handler_type
self._handler_to_types[handler_type] = (leaf_type, abstract_type)
# Allows for multiple handlers to be associated with the same leaf_type and
# abstract_type pair, typically for backward compatibility.
if secondary_typestrs is not None:
self._secondary_typestrs[handler_type] = (
secondary_typestrs
)
def _recalculate_specificity_scores(self) -> None:
"""Recalculates specificity scores and sorts the registry."""
for target_entry in self._entries:
leaf_count = 0
abstract_count = 0
for other_entry in self._entries:
# Count how many leaf types this target is a subclass of.
try:
if (
target_entry.leaf_type != other_entry.leaf_type and
issubclass(target_entry.leaf_type, other_entry.leaf_type)
):
leaf_count += 1
except TypeError:
pass
# Count how many abstract types this target is a subprotocol of.
if (
target_entry.abstract_type != other_entry.abstract_type and
_is_abstract_subprotocol(
target_entry.abstract_type, other_entry.abstract_type
)
):
abstract_count += 1
target_entry.leaf_specificity_score = leaf_count
target_entry.abstract_specificity_score = abstract_count

def is_handleable(self, leaf_type: Type[Any]) -> bool:
def is_handleable(self, leaf_type: type[Any]) -> bool:
"""Returns True if the type is handleable.

This checks if the provided concrete leaf type, or any of its base classes,
Expand All @@ -259,8 +350,8 @@ def is_handleable(self, leaf_type: Type[Any]) -> bool:
"""
return self._try_get(leaf_type) is not None

def is_abstract_handleable(self, abstract_type: Type[Any]) -> bool:
"""Returns True if the abstract type is handlable.
def is_abstract_handleable(self, abstract_type: type[Any]) -> bool:
"""Returns True if the abstract type is handleable.

This checks if the provided abstract leaf type, or any of its matching base
classes or protocols, has a registered handler in the registry.
Expand All @@ -274,9 +365,12 @@ def is_abstract_handleable(self, abstract_type: Type[Any]) -> bool:
return self._try_get_abstract(abstract_type) is not None

def get_secondary_typestrs(
self, handler_type: Type[types.LeafHandler[Any, Any]]
self, handler_type: type[types.LeafHandler[Any, Any]]
) -> Sequence[str]:
return self._secondary_typestrs.get(handler_type, [])
for entry in self._entries:
if entry.handler_type == handler_type:
return entry.secondary_typestrs or []
return []


class StandardLeafHandlerRegistry(BaseLeafHandlerRegistry):
Expand Down
Loading
Loading