Commit cf0e7bfe authored by xdrazkov's avatar xdrazkov
Browse files

feat: add metrics

parent ab84bbbc
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -7,6 +7,7 @@ requires-python = ">=3.12"
dependencies = [
    "hydra-core>=1.3.2",
    "omegaconf>=2.3.0",
    "torch>=2.9.1",
]

[dependency-groups]
+74 −0
Original line number Diff line number Diff line
# Standard Imports
from abc import ABC, abstractmethod
from typing import Any


MetricValueDict = dict[str, Any]


class Metric(ABC):
    _name: str

    def __init__(self, name: str) -> None:
        self._name = name
        self.reset()

    @abstractmethod
    def update(self, *args: Any, **kwargs: Any) -> None: ...

    @abstractmethod
    def compute(self) -> MetricValueDict:
        """Returns the computed metric value. Does not change the state of the metric."""
        ...

    @abstractmethod
    def reset(self) -> None: ...

    @abstractmethod
    def __repr__(self) -> str: ...

    def __str__(self) -> str:
        return self.__repr__()

    @property
    def name(self) -> str:
        return self._name

    def __call__(self, *args: Any, **kwargs: Any) -> MetricValueDict:
        self.update(*args, **kwargs)
        return self.compute()


class MetricsDict:
    def __init__(self, **kwargs):
        self._metrics = {name: cls(name) for name, cls in kwargs.items()}

    def __getitem__(self, key):
        return self._metrics[key]

    def __getattr__(self, key):
        try:
            return self._metrics[key]
        except KeyError:
            raise AttributeError(f"'MetricsDict' has no attribute '{key}'") from None

    def __setattr__(self, key, value):
        if key == "_metrics":
            super().__setattr__(key, value)
        else:
            raise AttributeError("Cannot modify attributes of MetricsDict")

    def __setitem__(self, key, value):
        raise AttributeError("Cannot modify items of MetricsDict")

    def __delattr__(self, key):
        raise AttributeError("Cannot delete attributes of MetricsDict")

    def __delitem__(self, key):
        raise AttributeError("Cannot delete items of MetricsDict")

    def __iter__(self):
        return iter(self._metrics)

    def items(self):
        return self._metrics.items()
+155 −0
Original line number Diff line number Diff line
# Third-party Imports
import math

from torch import Tensor

# Local Imports
from metrics.abstract import Metric


def input_to_float(value: float | Tensor) -> float:
    return float(value.item()) if isinstance(value, Tensor) else value


class MaxMetric(Metric):
    _max: float
    _max_at: int
    _count: int
    _count_max: int

    def update(self, value: float | Tensor) -> None:
        value = input_to_float(value)
        self._count += 1
        if value > self._max:
            self._max = value
            self._max_at = self._count
            self._count_max = 1
        elif value == self._max:
            self._count_max += 1

    def compute(self) -> dict[str, float]:
        return {
            f"max_{self._name}": self._max,
            f"max_{self._name}_at": self._max_at,
            f"max_{self._name}_count": self._count_max,
        }

    def reset(self) -> None:
        self._max = float("-inf")
        self._max_at = 0
        self._count = 0
        self._count_max = 0

    def __repr__(self) -> str:
        return (
            f"max_{self._name}={self._max:.2f} at {self._max_at} ({self._count_max}x)"
        )


class MinMetric(Metric):
    _min: float
    _min_at: int
    _count: int
    _count_min: int

    def update(self, value: float | Tensor) -> None:
        value = input_to_float(value)
        self._count += 1
        if value < self._min:
            self._min = value
            self._min_at = self._count
            self._count_min = 1
        elif value == self._min:
            self._count_min += 1

    def compute(self) -> dict[str, float]:
        return {
            f"min_{self._name}": self._min,
            f"min_{self._name}_at": self._min_at,
            f"min_{self._name}_count": self._count_min,
        }

    def reset(self) -> None:
        self._min = float("inf")
        self._min_at = 0
        self._count = 0
        self._count_min = 0

    def __repr__(self) -> str:
        return (
            f"min_{self._name}={self._min:.2f} at {self._min_at} ({self._count_min}x)"
        )


class MeanMetric(Metric):
    _sum: float
    _count: int

    def update(self, value: float | Tensor) -> None:
        value = input_to_float(value)
        self._count += 1
        self._sum += value

    def compute(self) -> dict[str, float]:
        return {f"avg_{self._name}": self._get_mean()}

    def reset(self) -> None:
        self._sum = 0
        self._count = 0

    def _get_mean(self) -> float:
        return float("nan") if self._count == 0 else self._sum / self._count

    def __repr__(self) -> str:
        return f"avg_{self._name}={self._get_mean():.2f}"


class MeanStdMetric(Metric):
    _sum: float
    _sum_sq: float
    _count: int

    def update(self, value: float | Tensor) -> None:
        value = input_to_float(value)
        self._count += 1
        self._sum += value
        self._sum_sq += value**2

    def compute(self) -> dict[str, float]:
        return {
            f"avg_{self._name}": self._get_mean(),
            f"std_{self._name}": self._get_std(),
        }

    def reset(self) -> None:
        self._sum = 0
        self._sum_sq = 0
        self._count = 0

    def _get_mean(self) -> float:
        return float("nan") if self._count == 0 else self._sum / self._count

    def _get_std(self) -> float:
        if self._count == 0:
            return float("nan")
        return math.sqrt(self._sum_sq / self._count - (self._sum / self._count) ** 2)

    def __repr__(self) -> str:
        return f"avg_{self._name}={self._get_mean():.2f} \u00b1 {self._get_std():.2f}"


class LastMetric(Metric):
    _value: float | None

    def update(self, value: float | Tensor) -> None:
        self._value = input_to_float(value)

    def compute(self) -> dict[str, float]:
        return {f"{self._name}": self._value}

    def reset(self) -> None:
        self._value = None

    def __repr__(self) -> str:
        val_str = "None" if self._value is None else f"{self._value:.2f}"
        return f"{self._name}={val_str}"
+93 −0
Original line number Diff line number Diff line
# Standard Imports
from abc import ABC
from typing import Generic, TypeVar

# Third-party Imports
from torch import Tensor

# Local Imports
from metrics.abstract import Metric
from metrics.aggregation import input_to_float


T = TypeVar("T")


class AttributeStore(Metric, ABC, Generic[T]):
    """<T> is the strategy type."""


class MaxAttributeStore(AttributeStore[T]):
    _max: float
    _max_at: int
    _count: int
    _store: T | None

    def __init__(
        self,
        name: str,
    ) -> None:
        super().__init__(name)
        self._store = None

    def __call__(self, attribute: T, value: float | Tensor) -> T | None:
        self.update(attribute, value)
        return self.compute()

    def update(self, attribute: T, value: float | Tensor) -> None:
        value = input_to_float(value)
        self._count += 1
        if value > self._max:
            self._max = value
            self._max_at = self._count
            self._store = attribute

    def compute(self) -> T | None:
        return self._store

    def reset(self) -> None:
        self._max = float("-inf")
        self._max_at = 0
        self._count = 0
        self._store = None

    def __repr__(self) -> str:
        return f"{self._name} stored at {self._max_at}"


class MinAttributeStore(AttributeStore[T]):
    _min: float
    _min_at: int
    _count: int
    _store: T | None

    def __init__(
        self,
        name: str,
    ) -> None:
        super().__init__(name)
        self._store = None

    def __call__(self, attribute: T, value: float | Tensor) -> T | None:
        self.update(attribute, value)
        return self.compute()

    def update(self, attribute: T, value: float | Tensor) -> None:
        value = input_to_float(value)
        self._count += 1
        if value < self._min:
            self._min = value
            self._min_at = self._count
            self._store = attribute

    def compute(self) -> T | None:
        return self._store

    def reset(self) -> None:
        self._min = float("inf")
        self._min_at = 0
        self._count = 0
        self._store = None

    def __repr__(self) -> str:
        return f"{self._name} stored at {self._min_at}"
+332 −0

File changed.

Preview size limit exceeded, changes collapsed.