# Copyright 2023 sepes authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Constructor code generation from type annotations."""
# this module contains functionality to turn type hints into a constructor
# similar to `dataclasses.dataclass`/`attrs`
# however, notable differences are:
# - Fields are not tied to the class decorator. i.e. `Field` can be used without `autoinit`.
# - Fields enable running function callbacks on the field values during setting/getting
# using descriptors, and emit a descriptive error message in case of an error.
# - Marking fields as positional only, keyword only, variable positional,...
# - Does not allow mutable defaults.
# - Registering additional types to be excluded from `autoinit`. e.g. raise an error.
# - Only constructor code generation is supported. other functionality like
# `__repr__`, `__eq__`, `__hash__` is not supported.
# one design choice is that `autoinit` and `Field` are not tightly coupled.
# `Field` can be used without `autoinit` as a descriptor to apply functions on
# the field values during initialization. Moreover, `TreeClass` is not coupled with
# `autoinit` or `Field` and can be used without them. this simplifies the code
# by separating the functionality.
from __future__ import annotations
import functools as ft
import sys
from collections import defaultdict
from collections.abc import Callable, MutableMapping, MutableSequence, MutableSet
from typing import Any, Literal, Sequence, TypeVar, get_args
from warnings import warn
from weakref import WeakSet
from typing_extensions import dataclass_transform
# generic type variable used in the annotations of this module
T = TypeVar("T")
PyTree = Any
EllipsisType = type(Ellipsis)
# the argument kinds a field can have in the generated constructor
KindType = Literal["POS_ONLY", "POS_OR_KW", "VAR_POS", "KW_ONLY", "VAR_KW", "CLASS_VAR"]
# the members of `KindType` as a tuple. `field` validates `kind` against it and
# `check_order_of_args` compares positions in it to detect out-of-order kinds
arg_kinds: tuple[str, ...] = get_args(KindType)
# annotated names for which `build_field_map` raises a `ValueError`
EXCLUDED_FIELD_NAMES: set[str] = {"self", "__post_init__", "__annotations__"}
# weakly-held set of classes already processed by `autoinit`, consulted to make
# the decorator idempotent
_autoinit_registry: WeakSet[type] = WeakSet()
class Null:
    """Falsy sentinel type whose instances print as ``NULL``.

    The module-level ``NULL`` instance is used as the "nothing provided"
    marker, so that ``None`` remains usable as a regular value.
    """

    __slots__ = []

    def __repr__(self) -> str:
        return "NULL"

    def __bool__(self) -> bool:
        return False


NULL = Null()
def generate_field_doc(field: Field) -> str:
    """Render a textual summary of ``field``.

    The summary always contains the field name, and additionally the default
    value (when it is not ``NULL``), the user description (when non-empty) and
    the ``on_setattr`` / ``on_getattr`` callbacks (when any are defined).

    Args:
        field: The field to describe.

    Returns:
        The summary lines joined by newlines.
    """
    lines: list[str] = ["Field Information:", f"\tName:\t\t``{field.name}``"]

    if field.default is not NULL:
        lines.append(f"\tDefault:\t``{field.default}``")

    if field.doc:
        lines.append(f"Description:\n\t{field.doc}")

    if field.on_setattr or field.on_getattr:
        lines.append("Callbacks:")

    # one titled section per non-empty callback sequence, setters first
    sections = (
        ("\t- On setting attribute:\n", field.on_setattr),
        ("\t- On getting attribute:\n", field.on_getattr),
    )
    for title, callbacks in sections:
        if callbacks:
            lines.append(title)
            lines.extend(f"\t\t- ``{func}``" for func in callbacks)

    return "\n".join(lines)
def slots(klass) -> tuple[str, ...]:
    """Return ``klass.__slots__``, or an empty tuple if it is not defined."""
    try:
        return klass.__slots__
    except AttributeError:
        return ()
def pipe(funcs: Sequence[Callable[[Any], Any]], name: str | None, value: Any):
    """Apply a sequence of unary functions on the field value.

    Each function receives the result of the previous one; the first function
    receives ``value``.

    Args:
        funcs: The functions to apply, in order.
        name: The field name, only used in the error message.
        value: The initial value.

    Returns:
        The result of the last function, or ``value`` itself if ``funcs`` is
        empty.

    Raises:
        Exception: If a function raises, an exception of the *same type* is
            raised with a message naming the function and the field, chained
            to the original error. If the exception type cannot be built from
            a single message, the original exception is re-raised instead.
    """
    for func in funcs:
        try:
            value = func(value)
        except Exception as e:
            # emit a *descriptive* error message with the name of the attribute
            # associated with the field and the name of the function that raised
            # the error.
            cname = getattr(func, "__name__", func)
            header = f"On applying {cname} for field=`{name}`"
            try:
                wrapped = type(e)(f"{header}:\n{e}")
            except Exception:
                # some exception types do not accept a single message argument
                # (e.g. `UnicodeDecodeError`). building them would raise an
                # unrelated `TypeError` and hide the real error type.
                wrapped = None
            if wrapped is None:
                if hasattr(e, "add_note"):
                    # python>=3.11: keep the field/function context as a note
                    e.add_note(header)
                raise
            # chain explicitly so the original traceback is shown as the cause
            raise wrapped from e
    return value
class Field:
    """Descriptor that stores the metadata of a single class attribute.

    As a data descriptor, a ``Field`` stores the attribute value in the
    instance ``__dict__`` under ``name`` and runs the ``on_setattr`` callbacks
    on assignment and the ``on_getattr`` callbacks on access.

    Args:
        name: The attribute name. Filled in by ``__set_name__`` when the field
            is assigned in a class body.
        type: The type hint of the attribute. Filled in by ``__set_name__``
            when the owner class annotates the attribute.
        default: The default value, ``NULL`` if there is none.
        init: Whether the field is included in the generated ``__init__``.
        repr: Whether the field is included in the object's ``__repr__``.
        kind: The argument kind of the field in the generated ``__init__``.
        metadata: User-defined data attached to the field.
        on_setattr: Unary functions applied, in order, on the value when set.
        on_getattr: Unary functions applied, in order, on the value when read.
        alias: The name of the field in the generated ``__init__``.
        doc: Extra user documentation for the field.
    """

    # NOTE: the `__doc__` property below replaces the class docstring at
    # runtime. the text above documents the source only.

    __slots__ = [
        "name",
        "type",
        "default",
        "init",
        "repr",
        "kind",
        "metadata",
        "on_setattr",
        "on_getattr",
        "alias",
        "doc",
    ]

    def __init__(
        self,
        *,
        name: str | Null = NULL,
        type: type | Null = NULL,
        default: Any = NULL,
        init: bool = True,
        repr: bool = True,
        kind: KindType = "POS_OR_KW",
        metadata: dict[str, Any] | None = None,
        on_setattr: Sequence[Callable[[Any], Any]] = (),
        on_getattr: Sequence[Callable[[Any], Any]] = (),
        alias: str | None = None,
        doc: str = "",
    ):
        self.name = name
        self.type = type
        self.default = default
        self.init = init
        self.repr = repr
        self.kind = kind
        self.metadata = metadata
        self.on_setattr = on_setattr
        self.on_getattr = on_getattr
        self.alias = alias
        self.doc = doc

    def replace(self, **kwargs) -> Field:
        """Return a new field with the given attributes replaced.

        Attributes that are not passed are copied from ``self``. Keys that are
        not ``Field`` attributes are ignored.
        """
        # define a `replace` method similar to `dataclasses.replace` or namedtuple
        # to allow the user to replace the field attributes.
        return type(self)(**{k: kwargs.get(k, getattr(self, k)) for k in slots(Field)})

    def pipe_on_setattr(self, value: Any) -> Any:
        """Apply the ``on_setattr`` functions, in order, on ``value``."""
        return pipe(self.on_setattr, self.name, value)

    def pipe_on_getattr(self, value: Any) -> Any:
        """Apply the ``on_getattr`` functions, in order, on ``value``."""
        return pipe(self.on_getattr, self.name, value)

    def __set_name__(self, owner, name: str) -> None:
        """Record the attribute name and, if annotated, its type hint."""
        self.name = name
        # in case the user uses `field` as a descriptor without annotating the
        # class, the type is left untouched
        if "__annotations__" in (variables := vars(owner)):
            # look the hint up in the annotations mapping of the owner class.
            # `variables[name]` would be the class attribute, i.e. this very
            # field, and not its type hint.
            self.type = variables["__annotations__"].get(name, NULL)

    @property
    def __doc__(self) -> str:
        """Return the field documentation."""
        return generate_field_doc(field=self)

    def __get__(self: T, instance, _) -> T | Any:
        """Return the field value after applying the ``on_getattr`` functions.

        Returns the field itself when accessed on the class. Raises
        ``KeyError`` if no value is stored on the instance.
        """
        if instance is None:
            return self
        return self.pipe_on_getattr(vars(instance)[self.name])

    def __set__(self: T, instance, value) -> None:
        """Store ``value`` on the instance after the ``on_setattr`` functions."""
        vars(instance)[self.name] = self.pipe_on_setattr(value)

    def __delete__(self: T, instance) -> None:
        """Delete the stored value. Raises ``KeyError`` if there is none."""
        del vars(instance)[self.name]
# (stray "[docs]" link text from the rendered HTML source view removed)
def field(
    *,
    default: Any = NULL,
    init: bool = True,
    repr: bool = True,
    kind: KindType = "POS_OR_KW",
    metadata: dict[str, Any] | None = None,  # type: ignore
    on_setattr: Sequence[Any] = (),
    on_getattr: Sequence[Any] = (),
    alias: str | None = None,
    doc: str = "",
) -> Field:
    """Field placeholder for type hinted attributes.

    Args:
        default: The default value of the field.
        init: Whether the field is included in the object's ``__init__`` function.
        repr: Whether the field is included in the object's ``__repr__`` function.
        kind: Argument kind used in the constructor synthesis with :func:`autoinit`,

            - ``POS_ONLY``: positional only argument (e.g. ``x`` in ``def f(x, /):``)
            - ``VAR_POS``: variable positional argument (e.g. ``*x`` in ``def f(*x):``)
            - ``POS_OR_KW``: positional or keyword argument (e.g. ``x`` in ``def f(x):``)
            - ``KW_ONLY``: keyword only argument (e.g. ``x`` in ``def f(*, x):``)
            - ``VAR_KW``: variable keyword argument (e.g. ``**x`` in ``def f(**x):``)
            - ``CLASS_VAR``: Non-constructor class variable (e.g. ``x`` in ``class C: x = 1``)

        metadata: A mapping of user-defined data for the field.
        on_setattr: A sequence of functions to be called, in order, on the
            value when the attribute is set.
        on_getattr: A sequence of functions to be called, in order, on the
            value when the attribute is read.
        alias: An alias for the field name in the constructor. e.g. ``name=x``,
            ``alias=y`` will allow ``obj = Class(y=1)`` to be equivalent to
            ``obj = Class(x=1)``.
        doc: Extra documentation for the :func:`field`. The complete documentation
            of the field includes the field name, the field doc, the
            default value, and the function callbacks applied on the field value.
            Mainly used for documenting the field callbacks.

            .. code-block:: python

                >>> import sepes as sp
                >>> @sp.autoinit
                ... class Tree:
                ...     leaf: int = sp.field(
                ...         default=1,
                ...         doc="Leaf node of the tree.",
                ...         on_setattr=[lambda x: x],
                ...     )
                >>> print(Tree.leaf.__doc__)  # doctest: +SKIP
                Field Information:
                        Name:           ``leaf``
                        Default:        ``1``
                Description:
                        Leaf node of the tree.
                Callbacks:
                        - On setting attribute:
                                - ``<function Tree.<lambda> at 0x11c53dc60>``

    Raises:
        TypeError: If ``alias`` is not a string or ``None``, ``metadata`` is not
            a dict or ``None``, ``on_setattr`` / ``on_getattr`` is not a
            sequence or contains a non-callable, or ``init`` is not a bool.
        ValueError: If ``kind`` is not one of the supported argument kinds.

    Example:
        Type and range validation using :attr:`on_setattr`:

        >>> import sepes as sp
        >>> @sp.autoinit
        ... class IsInstance(sp.TreeClass):
        ...    klass: type
        ...    def __call__(self, x):
        ...        assert isinstance(x, self.klass)
        ...        return x
        <BLANKLINE>
        >>> @sp.autoinit
        ... class Range(sp.TreeClass):
        ...    start: int|float = float("-inf")
        ...    stop: int|float = float("inf")
        ...    def __call__(self, x):
        ...        assert self.start <= x <= self.stop
        ...        return x
        <BLANKLINE>
        >>> @sp.autoinit
        ... class Employee(sp.TreeClass):
        ...    # assert employee ``name`` is str
        ...    name: str = sp.field(on_setattr=[IsInstance(str)])
        ...    # use callback composition to assert employee ``age`` is int and positive
        ...    age: int = sp.field(on_setattr=[IsInstance(int), Range(1)])
        >>> employee = Employee(name="Asem", age=10)
        >>> print(employee)
        Employee(name=Asem, age=10)

    Example:
        Private attribute using :attr:`alias`:

        >>> import sepes as sp
        >>> @sp.autoinit
        ... class Employee(sp.TreeClass):
        ...     # `alias` is the name used in the constructor
        ...    _name: str = sp.field(alias="name")
        >>> employee = Employee(name="Asem")  # use `name` in the constructor
        >>> print(employee)  # `_name` is the private attribute name
        Employee(_name=Asem)

    Example:
        Buffer creation using :attr:`on_getattr`:

        >>> import sepes as sp
        >>> import jax
        >>> import jax.numpy as jnp
        >>> @sp.autoinit
        ... class Tree(sp.TreeClass):
        ...     buffer: jax.Array = sp.field(on_getattr=[jax.lax.stop_gradient])
        >>> tree = Tree(buffer=jnp.array((1.0, 2.0)))
        >>> def sum_buffer(tree):
        ...     return tree.buffer.sum()
        >>> print(jax.grad(sum_buffer)(tree))  # no gradient on `buffer`
        Tree(buffer=[0. 0.])

    Example:
        Parameterization using :attr:`on_getattr`:

        >>> import sepes as sp
        >>> import jax
        >>> import jax.numpy as jnp
        >>> def symmetric(array: jax.Array) -> jax.Array:
        ...    triangle = jnp.triu(array)  # upper triangle
        ...    return triangle + triangle.transpose(-1, -2)
        >>> @sp.autoinit
        ... class Tree(sp.TreeClass):
        ...    symmetric_matrix: jax.Array = sp.field(on_getattr=[symmetric])
        >>> tree = Tree(symmetric_matrix=jnp.arange(9).reshape(3, 3))
        >>> print(tree.symmetric_matrix)
        [[ 0  1  2]
         [ 1  8  5]
         [ 2  5 16]]

    Note:
        - :func:`field` is commonly used to annotate the class attributes to be
          used by the :func:`autoinit` decorator to generate the ``__init__``
          method similar to ``dataclasses.dataclass``.

        - :func:`field` can be used without the :func:`autoinit` as a descriptor
          to apply functions on the field values during initialization using
          the ``on_setattr`` / ``on_getattr`` argument.

            >>> import sepes as sp
            >>> def print_and_return(x):
            ...    print(f"Setting {x}")
            ...    return x
            >>> class Tree:
            ...    # `a` must be defined as a class attribute for the descriptor to work
            ...    a: int = sp.field(on_setattr=[print_and_return])
            ...    def __init__(self, a):
            ...        self.a = a
            >>> tree = Tree(1)
            Setting 1
    """
    # the checks run in a fixed order, which decides the error raised when
    # several arguments are invalid at once
    if alias is not None and not isinstance(alias, str):
        raise TypeError(f"Non-string {alias=} argument provided to `field`")

    if metadata is not None and not isinstance(metadata, dict):
        raise TypeError(f"Non-dict {metadata=} argument provided to `field`")

    if kind not in arg_kinds:
        raise ValueError(f"{kind=} not in {arg_kinds}")

    callback_args = (("on_setattr", on_setattr), ("on_getattr", on_getattr))

    for label, funcs in callback_args:
        if not isinstance(funcs, Sequence):
            raise TypeError(f"Non-sequence {label}={funcs!r} argument provided to `field`")

    if not isinstance(init, bool):
        raise TypeError(f"Non-bool {init=} argument provided to `field`")

    for label, funcs in callback_args:
        for func in funcs:
            if not isinstance(func, Callable):  # type: ignore
                raise TypeError(f"Non-callable {func=} provided to `field` {label}")

    return Field(
        default=default,
        init=init,
        repr=repr,
        kind=kind,
        metadata=metadata,  # type: ignore
        on_setattr=on_setattr,
        on_getattr=on_getattr,
        alias=alias,
        doc=doc,
    )
def build_field_map(klass: type) -> dict[str, Field]:
    """Collect the fields of ``klass`` and its bases, keyed by attribute name.

    The classes in the method resolution order of ``klass`` are visited from
    the most basic one to ``klass`` itself. For each class, every annotated
    attribute whose class-level value is a ``Field`` is recorded, so a field
    declared in a more derived class overrides the one with the same name
    declared in a base while keeping the position of its first declaration.

    Each class contributes only the fields it declares itself. Merging the
    complete map of every base instead would let a later-merged base bring
    back an ancestor's field over a sibling's override in diamond inheritance.

    Args:
        klass: The class to collect the fields of.

    Returns:
        A new dict mapping the attribute name to a copy of its field with
        ``name`` and ``type`` set from the annotation.

    Raises:
        ValueError: If a visited class annotates a name in
            ``EXCLUDED_FIELD_NAMES``.
    """
    field_map: dict[str, Field] = {}

    for owner in reversed(klass.__mro__):
        namespace = vars(owner)

        if (hint_map := namespace.get("__annotations__", NULL)) is NULL:
            # not annotated (e.g. `object`), nothing to collect from this class
            continue

        if EXCLUDED_FIELD_NAMES.intersection(hint_map):
            raise ValueError(f"`Field` in {EXCLUDED_FIELD_NAMES=}")

        for key, hint in hint_map.items():
            if isinstance(value := namespace.get(key, NULL), Field):
                # case: `x: Any = field(default=1)`
                field_map[key] = value.replace(name=key, type=hint)

    return field_map
# (stray "[docs]" link text from the rendered HTML source view removed)
def fields(x: Any) -> tuple[Field, ...]:
    """Returns a tuple of ``Field`` objects for the given instance or class.

    The ``Field`` objects are collected from the annotated attributes of the
    class (and its bases) whose class-level value is a ``Field``, e.g.
    attributes declared with ``sepes.field`` or converted by ``autoinit``.

    Args:
        x: A class, or an instance whose type is inspected.

    Note:
        - If the class is not annotated, an empty tuple is returned.
    """
    klass = x if isinstance(x, type) else type(x)
    field_map = build_field_map(klass)
    return tuple(field_map.values())
def convert_hints_to_fields(klass: type[T]) -> type[T]:
    """Turn the annotated attributes of the **current** class into fields.

    For every name annotated directly on ``klass``:

    - a plain class-level value (or no value at all) is replaced by a
      ``Field`` that uses the value as its default (``NULL`` if missing);
    - a ``Field`` of kind ``CLASS_VAR`` is replaced by its default value;
    - any other ``Field`` is left untouched.

    Args:
        klass: The class to convert. It is modified in place.

    Returns:
        The same class.
    """
    namespace = vars(klass)
    hint_map = namespace.get("__annotations__", NULL)

    if hint_map is NULL:
        # in case no type hints are provided, return the class as is
        return klass

    for key, hint in hint_map.items():
        value = namespace.get(key, NULL)

        if not isinstance(value, Field):
            # the name and type are passed explicitly as the field is attached
            # with `setattr` after the class body was executed
            setattr(klass, key, Field(default=value, type=hint, name=key))
        elif value.kind == "CLASS_VAR":
            # class variables are exposed as plain class attributes
            setattr(klass, key, value.default)
        # otherwise: no need to convert `Field` annotated attributes again

    return klass
def check_excluded_types(field_map: dict[str, Field]) -> dict[str, Field]:
    """Run ``excluded_type_dispatcher`` on the default value of every field.

    Handlers registered for excluded types (e.g. mutable containers) raise,
    so an excluded default aborts the constructor synthesis.

    Args:
        field_map: The fields to check, as built by ``build_field_map``.

    Returns:
        ``field_map`` unchanged, to allow chaining the checks.
    """
    for field in field_map.values():
        excluded_type_dispatcher(field.default)
    return field_map
def check_duplicate_var_kind(field_map: dict[str, Field]) -> dict[str, Field]:
    """Ensure at most one ``VAR_POS`` and at most one ``VAR_KW`` field.

    Args:
        field_map: The fields to check, as built by ``build_field_map``.

    Returns:
        ``field_map`` unchanged, to allow chaining the checks.

    Raises:
        TypeError: On the second field of kind ``VAR_POS`` (or ``VAR_KW``).
    """
    variadic_kinds = ("VAR_POS", "VAR_KW")
    used_kinds: set[Literal["VAR_POS", "VAR_KW"]] = set()

    for field in field_map.values():
        if field.kind not in variadic_kinds:
            continue
        if field.kind in used_kinds:
            raise TypeError(f"Duplicate {field.kind=} for {field.name=}")
        used_kinds.add(field.kind)

    return field_map
def check_order_of_args(field_map: dict[str, Field]) -> dict[str, Field]:
    """Warn when the declared kinds do not follow the ``arg_kinds`` order.

    The fields are scanned in declaration order, skipping ``CLASS_VAR``
    fields. A warning is emitted each time a kind comes after a kind that is
    positioned later in ``arg_kinds``, to let the user acknowledge that the
    arguments are reordered in the generated constructor. For reference,
    ``dataclasses.dataclass`` does not warn the user before reordering the
    arguments.

    Args:
        field_map: The fields to check, as built by ``build_field_map``.

    Returns:
        ``field_map`` unchanged, to allow chaining the checks.
    """
    kinds_so_far: list[KindType] = []

    for field in field_map.values():
        kind = field.kind
        if kind == "CLASS_VAR":
            continue

        previous = kinds_so_far[-1] if kinds_so_far else None
        kinds_so_far.append(kind)

        if previous is None:
            # first constructor argument, nothing to compare against
            continue

        if arg_kinds.index(previous) > arg_kinds.index(kind):
            warn(f"Kind order {kinds_so_far} order != {arg_kinds} and will be reordered.")

    return field_map
def build_init_method(klass: type[T]) -> type[T]:
    """Generate the ``__init__`` method of ``klass`` from its fields.

    The fields are collected with ``build_field_map`` and validated. Then the
    source of an ``__init__`` function is assembled and executed, with the
    globals of the module defining ``klass``, and the result is attached to
    the class.

    - Fields of kind ``CLASS_VAR`` are skipped.
    - Fields with ``init=True`` become constructor arguments, named by their
      ``alias`` if defined, and are assigned to ``self.<name>``.
    - Fields with ``init=False`` and a default value are assigned their default.
    - ``self.__post_init__()`` is called last if ``klass`` defines it.

    Default values are not written in the generated source. They are looked up
    by reference from a mapping passed to the generated closure.

    Args:
        klass: The class to add the ``__init__`` method to.

    Returns:
        The same class, with ``__init__`` set.
    """
    field_map: dict[str, Field] = build_field_map(klass)
    field_map = check_excluded_types(field_map)
    field_map = check_duplicate_var_kind(field_map)
    field_map = check_order_of_args(field_map)

    hints: dict[str, Any] = {"return": None}  # annotations
    body: list[str] = []
    head: list[str] = ["self"]
    heads: dict[str, list[str]] = defaultdict(list)

    for field in field_map.values():
        if field.kind == "CLASS_VAR":
            # skip class variables from init synthesis
            # e.g. class A: x = field(default=1, kind="CLASS_VAR")
            continue

        if field.init:
            # how to name the field in the constructor
            alias = field.alias or field.name
            # annotations are keyed by the *parameter* name so that the hint
            # is attached to the argument as it appears in the signature
            hints[alias] = field.type
            body += [f"self.{field.name}={alias}"]

            if field.default is NULL:
                # e.g. def __init__(.., x)
                heads[field.kind] += [alias]
            else:
                # e.g def __init__(.., x=value) but
                # pass reference to the default value
                heads[field.kind] += [f"{alias}=refmap['{field.name}'].default"]
        else:
            if field.default is not NULL:
                # case for fields with `init=False` and a default value.
                # fields without a default are left unset and are usually
                # declared in `__post_init__`
                body += [f"self.{field.name}=refmap['{field.name}'].default"]

    # add pass in case all fields are not included in the constructor
    # i.e. `init=False` for all fields
    body += [f"self.{key}()"] if (key := "__post_init__") in vars(klass) else ["pass"]

    # organize the arguments order:
    # (POS_ONLY, POS_OR_KW, VAR_POS, KW_ONLY, VAR_KW)
    head += (heads["POS_ONLY"] + ["/"]) if heads["POS_ONLY"] else []
    head += heads["POS_OR_KW"]
    head += ["*" + "".join(heads["VAR_POS"])] if heads["VAR_POS"] else []
    # case for ...(*a, b) and ...(a, *, b)
    head += ["*"] if (heads["KW_ONLY"] and not heads["VAR_POS"]) else []
    head += heads["KW_ONLY"]
    head += ["**" + "".join(heads["VAR_KW"])] if heads["VAR_KW"] else []

    # generate the code for the method
    code = "def closure(refmap):\n"
    code += f"\tdef __init__({','.join(head)}):"
    code += f"\n\t\t{';'.join(body)}"
    code += f"\n\t__init__.__qualname__ = '{klass.__qualname__}.__init__'"
    code += "\n\t__init__.__annotations__ = refmap['__annotations__']"
    code += "\n\treturn __init__"

    # the mapping referenced by the generated code: the fields (for the default
    # values) and the annotations. built separately so the field map is not
    # modified
    refmap = {**field_map, "__annotations__": hints}

    # execute the code with the globals of the module defining the class. the
    # generated source only contains the field names and aliases
    exec(code, vars(sys.modules[klass.__module__]), namespace := dict())
    method = namespace["closure"](refmap)

    # add the method to the class
    setattr(klass, "__init__", method)
    return klass
# (stray "[docs]" link text from the rendered HTML source view removed)
@dataclass_transform(field_specifiers=(Field, field))
def autoinit(klass: type[T]) -> type[T]:
    """A class decorator that generates the ``__init__`` method from type hints.

    Using the ``autoinit`` decorator, the user can define the class attributes
    using type hints and the ``__init__`` method will be generated automatically

    >>> import sepes as sp
    >>> @sp.autoinit
    ... class Tree:
    ...     x: int
    ...     y: int

    Is equivalent to:

    >>> class Tree:
    ...     def __init__(self, x: int, y: int):
    ...         self.x = x
    ...         self.y = y

    Example:
        >>> import sepes as sp
        >>> import inspect
        >>> @sp.autoinit
        ... class Tree:
        ...     x: int
        ...     y: int
        >>> inspect.signature(Tree.__init__)
        <Signature (self, x: int, y: int) -> None>
        >>> tree = Tree(1, 2)
        >>> tree.x, tree.y
        (1, 2)

    Example:
        Define fields with different argument kinds

        >>> import sepes as sp
        >>> import inspect
        >>> @sp.autoinit
        ... class Tree:
        ...     kw_only_field: int = sp.field(default=1, kind="KW_ONLY")
        ...     pos_only_field: int = sp.field(default=2, kind="POS_ONLY")
        >>> inspect.signature(Tree.__init__)
        <Signature (self, pos_only_field: int = 2, /, *, kw_only_field: int = 1) -> None>

    Example:
        Define a converter to apply ``abs`` on the field value

        >>> @sp.autoinit
        ... class Tree:
        ...     a:int = sp.field(on_setattr=[abs])
        >>> Tree(a=-1).a
        1

    .. warning::
        The ``autoinit`` decorator will raise ``TypeError`` if the user defines
        ``__init__`` method in the decorated class.

    Note:
        - In case of inheritance, the ``__init__`` method is generated from the
          type hints of the current class and any base classes that
          are decorated with ``autoinit``.

        >>> import sepes as sp
        >>> import inspect
        >>> @sp.autoinit
        ... class Base:
        ...     x: int
        >>> @sp.autoinit
        ... class Derived(Base):
        ...     y: int
        >>> obj = Derived(x=1, y=2)
        >>> inspect.signature(obj.__init__)
        <Signature (x: int, y: int) -> None>

        - Base classes that are not decorated with ``autoinit`` are ignored during
          synthesis of the ``__init__`` method.

        >>> import sepes as sp
        >>> import inspect
        >>> class Base:
        ...     x: int
        >>> @sp.autoinit
        ... class Derived(Base):
        ...     y: int
        >>> obj = Derived(y=2)
        >>> inspect.signature(obj.__init__)
        <Signature (y: int) -> None>

    Note:
        Use ``autoinit`` instead of ``dataclasses.dataclass`` if you want to
        use ``jax.Array`` as a field default value. As ``dataclasses.dataclass``
        will incorrectly raise an error starting from python 3.11 complaining
        that ``jax.Array`` is not immutable.

    Note:
        By default ``autoinit`` will raise an error if the user uses mutable defaults.
        To register an additional type to be excluded from ``autoinit``, use
        :func:`autoinit.register_excluded_type`, with an optional ``reason``
        for excluding the type.

        >>> import sepes as sp
        >>> class T:
        ...     pass
        >>> sp.autoinit.register_excluded_type(T, reason="not allowed")
        >>> @sp.autoinit
        ... class Tree:
        ...     x: T = sp.field(default=T())  # doctest: +SKIP
        Traceback (most recent call last):
            ...
    """
    already_transformed = klass in _autoinit_registry
    if already_transformed:
        # autoinit(autoinit(klass)) == autoinit(klass)
        # idempotent decorator to avoid redefining the class
        return klass

    user_defined_init = "__init__" in vars(klass)
    if user_defined_init:
        # if the class already has a user-defined __init__ method
        # then raise an error to avoid confusing the user
        raise TypeError(f"autoinit({klass.__name__}) with defined `__init__`.")

    # first convert the current class hints to fields
    with_fields = convert_hints_to_fields(klass)
    # then build the __init__ method from the fields of the current class
    # and of its base classes
    transformed = build_init_method(with_fields)
    # add the class to the registry to avoid redefining the class
    _autoinit_registry.add(transformed)
    return transformed
# single-dispatch hook that `check_excluded_types` calls on every field default.
# the fallback accepts any value and returns `None`. the handlers added through
# `register_excluded_type` raise a `TypeError` for the excluded types
excluded_type_dispatcher = ft.singledispatch(lambda _: None)
def register_excluded_type(klass: type, reason: str | None = None) -> None:
    """Exclude a type from being used in the ``autoinit`` decorator.

    A handler that raises ``TypeError`` is registered for ``klass`` on
    ``excluded_type_dispatcher``, which is called on the field defaults.

    Args:
        klass: The type to be excluded.
        reason: The reason for excluding the type. Appended to the error
            message if provided.
    """
    suffix = "" if reason is None else f" reason={reason!r}"

    def reject(value) -> None:
        raise TypeError(f"{value=} is excluded from `autoinit`.{suffix}")

    excluded_type_dispatcher.register(klass, reject)
# expose the registration helper as an attribute of the decorator, i.e.
# `autoinit.register_excluded_type(...)`
autoinit.register_excluded_type = register_excluded_type
# mutable containers are excluded by default. they are registered through the
# `collections.abc` abstract base classes
autoinit.register_excluded_type(MutableMapping, reason="mutable type")
autoinit.register_excluded_type(MutableSequence, reason="mutable type")
autoinit.register_excluded_type(MutableSet, reason="mutable type")