Commit 3f358bb1 authored by Jindřich Sedláček's avatar Jindřich Sedláček
Browse files

Added support for custom __init__ and __repr__

parent 19e52877
Loading
Loading
Loading
Loading
+71 −31
Original line number Diff line number Diff line
@@ -9,6 +9,7 @@ TODO: allow custom init, even if a little bit clunky
from __future__ import annotations
from typing import TypeVar, List
from inspect import get_annotations
from types import FunctionType

T = TypeVar('T')

@@ -24,7 +25,7 @@ class SafeClass:
        pass


def safeclass(cls: type) -> type:
def _process_safeclass(cls, creator=False) -> type:
    var_names: List[str] = list()

    cls_annotations = get_annotations(cls)
@@ -33,13 +34,16 @@ def safeclass(cls: type) -> type:
            try:
                cls_annotations[name] = eval(annotation_val)
            except NameError:
                # we do not have enough information, we have
                # if we do not have enough information, we have
                # to be permissive
                cls_annotations[name] = object

    for name in cls.__annotations__:
        var_names.append(name)

    cls_dict = dict(cls.__dict__)

    if '__init__' not in cls_dict:
        def _init(self, *args, **kwargs):
            for i, value in enumerate(args):
                if i >= len(var_names):
@@ -61,20 +65,52 @@ def safeclass(cls: type) -> type:
                    raise WrongInit("Argument types do not match")

                super(self.__class__, self).__setattr__(name, value)
    else:
        orig_init = cls_dict['__init__']

        class _Obj:
            __slots__ = var_names

            def __init__(self, *args, **kwargs):
                orig_init(self, *args, **kwargs)

        def _init(self, *args, **kwargs):
            _tmp = _Obj(*args, **kwargs)
            for name in var_names:
                value = getattr(_tmp, name)
                super(self.__class__, self).__setattr__(name, value)

    cls_dict['__init__'] = _init

    def _setattr(self, name: str, val: 'T'):
        if name.startswith('_'):
            raise AttributeError("Cannot modify an immutable attribute")
        super(self.__class__, self).__setattr__(name, val)

    cls_dict = dict(cls.__dict__)
    if '__repr__' not in cls_dict:
        def _repr(self) -> str:
            attr_reps = list[str]()
            for i, name in enumerate(self.__class__.__annotations__):
                attr_reps.append(f"{name} = {getattr(self, name)}")
            return f"{self.__class__.__name__}({', '.join(attr_reps)})"

        cls_dict['__repr__'] = _repr

    cls_dict['__slots__'] = var_names
    cls_dict['__init__'] = _init
    cls_dict['__setattr__'] = _setattr

    return type(cls.__name__, cls.__bases__, cls_dict)


def safeclass(cls=None):
    def wrap(cls):
        return _process_safeclass(cls)

    if cls is None:
        return wrap
    return wrap(cls)


def test() -> None:
    @safeclass
    class Person(SafeClass):
@@ -83,22 +119,26 @@ def test() -> None:

    @safeclass
    class Point(SafeClass):
        x: int
        y: int
        _x: int
        _y: int

        def __repr__(self) -> str:
            return f"Point(x = {self._x}, y = {self._y})"

        def __str__(self) -> str:
            return f"Point(x = {self.x}, y = {self.y})"
        def __init__(self, x: int, y: int):
            self._x = x + 1
            self._y = y + 1

    @safeclass
    class Line(SafeClass):
        a: Point
        b: Point

        def __str__(self) -> str:
            return f"Line(a={self.a}, b={self.b})"
    line = Line(Point(0, 1), Point(2, 3))
    print(line)

    line = Line(Point(0, 0), Point(2, 3))
    # print(line)
    a = Point(0, 0)
    print(a)

    p = Person("Joe", 30)