Commit eac3053b authored by xdrazkov's avatar xdrazkov
Browse files

feat: add experiment and run

parent f9b966cc
Loading
Loading
Loading
Loading
+194 −0
Original line number Diff line number Diff line
# Standard Imports
import socket
from abc import ABC, abstractmethod

# Third-party Imports
import git
import hydra.utils
import omegaconf
import torch
from omegaconf import OmegaConf

# Local Imports
from regstar_library.logging.logging import get_logger
from regstar_library.launch.experiment.loggers.abstract import ExperimentLogger
from regstar_library.launch.run.run import Run, RunResults
from regstar_library.metrics.abstract import Metric
from regstar_library.metrics.aggregation import MeanMetric


# Module-level logger used for experiment progress/status reporting.
log = get_logger("reporting")

# Make "${hostname:}" resolvable in OmegaConf configs. use_cache=False so the
# hostname is re-evaluated on each resolution; replace=True so re-importing
# this module does not raise on duplicate registration.
OmegaConf.register_new_resolver(
    "hostname", lambda: socket.gethostname(), use_cache=False, replace=True
)


def set_seed(seed: int) -> None:
    """Seed Python, NumPy and PyTorch RNGs for a reproducible run.

    Args:
        seed: The seed applied to every random number generator.
    """
    # random/numpy are imported lazily to keep module import light; torch is
    # already imported at module level, so the redundant local import is gone.
    import random

    import numpy

    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every CUDA device, not only the current one (the original used
    # torch.cuda.manual_seed, which covers a single device). Safe without CUDA.
    torch.cuda.manual_seed_all(seed)


class AbstractExperiment(ABC):
    """Base class orchestrating an experiment made of multiple runs.

    Each "epoch" corresponds to one full run (see ``execute_run``). Subclasses
    choose the dispatch strategy by implementing :meth:`execute`. Per-run
    results are forwarded to the attached ``ExperimentLogger`` and aggregated
    locally in ``metrics`` and ``times``.
    """

    # Sink for all metric, timing and metadata output.
    _logger: ExperimentLogger
    # Number of runs to execute (one run per "epoch").
    epochs: int
    # Base RNG seed; each run derives its own seed from it (seed + epoch).
    seed: int
    # Aggregators keyed by metric name; a key may map to a single Metric or a
    # list of Metrics — both are accepted by __init__ (annotation widened to
    # match).
    metrics: dict[str, Metric | list[Metric]]
    # Aggregators for timing statistics, keyed by time-category name.
    times: dict[str, Metric]

    def __init__(
        self,
        logger: ExperimentLogger,
        run: omegaconf.DictConfig,
        epochs: int,
        seed: int,
        metrics: dict[str, Metric | list[Metric]] | None = None,
    ) -> None:
        """Store configuration and create the default time aggregators.

        Args:
            logger: Destination for metrics, times and metadata.
            run: Config of a single run; interpolation is resolved in place.
            epochs: How many runs to execute.
            seed: Base seed for all runs.
            metrics: Optional per-metric aggregators (single or list per key).
        """
        self._logger = logger
        # Resolve interpolations (e.g. ${hostname:}) once, up front.
        OmegaConf.resolve(run)
        self.run_cfg = run
        self.epochs = epochs
        self.seed = seed
        self.metrics = metrics or {}
        self.times = {
            "train": MeanMetric("train"),
            "eval": MeanMetric("eval"),
            "avg_forward": MeanMetric("forward"),
            "avg_backward": MeanMetric("backward"),
            "avg_optimizer": MeanMetric("optimizer"),
            "avg_train_step": MeanMetric("train_step"),
            "avg_eval_step": MeanMetric("eval_step"),
        }

    # Subclasses implement the dispatch strategy (sequential, Ray, ...).
    @abstractmethod
    def execute(self) -> None: ...

    def setup(self) -> None:
        """Hook called before execution; records git status as metadata."""
        self.log_git_status()

    def teardown(self) -> None:
        """Hook called after execution; closes the logger, then prints stats."""
        self._logger.close()
        self.print_stats()

    def _update_epoch(self, data: RunResults) -> None:
        """Forward one finished run's results to the logger and aggregators."""
        self._logger.log_run(data.get("run_id"))
        self._update_metrics(data.get("metrics", {}))
        self._update_times(data.get("times", {}))

    def _update_metrics(self, metrics_data: dict[str, float]) -> None:
        """Log raw per-run metric values and feed them into the aggregators."""
        # Raw values are always logged, whether or not they are aggregated.
        for name, value in metrics_data.items():
            self._logger.log_metric(name, value)

        for key, metrics in self.metrics.items():
            value = metrics_data.get(key)
            if value is None:
                # The run did not report this metric; nothing to aggregate.
                continue
            if isinstance(metrics, Metric):
                # Normalize to a list so single and multiple aggregators share
                # the same code path.
                metrics = [metrics]
            for metric in metrics:
                # metric(value) updates the aggregator and returns the stats
                # dict expected by log_metrics_stats.
                self._logger.log_metrics_stats(metric(value))

    def _update_times(self, times: dict[str, float]) -> None:
        """Feed reported run times into the time aggregators and log them."""
        for key, metric in self.times.items():
            value = times.get(key)
            if value is not None:
                self._logger.log_times_stats(metric(value))

    def print_stats(self) -> None:
        """Emit a human-readable summary of all aggregated metrics and times."""
        msg = "-------------"
        msg += "\nMetrics:"
        for metrics in self.metrics.values():
            if isinstance(metrics, Metric):
                metrics = [metrics]
            for metric in metrics:
                msg += "\t" + str(metric)

        msg += "\nTimes:"
        for time in self.times.values():
            msg += "\t" + str(time)

        log.info(msg)

    @staticmethod
    def _print_epoch_stats(data: RunResults) -> None:
        """Emit a one-line summary of a single run's metrics and key timings."""
        msg = f"Epoch={data['epoch']}: "
        for name, metric in data["metrics"].items():
            if metric is not None:
                msg += f"\t{name}={metric:.3f}"

        msg += "\tTime: "
        for name, time in data["times"].items():
            # Only headline timings are printed; the rest go to the logger.
            if (
                name in ["train", "avg_train_step", "avg_eval_step"]
                and time is not None
            ):
                msg += f"\t{name}={time:.3f}"

        log.info(msg)

    def log_config(self, config: omegaconf.DictConfig) -> None:
        """Persist the fully-resolved experiment configuration as metadata."""
        self._logger.log_meta("settings", OmegaConf.to_container(config, resolve=True))

    def log_git_status(self) -> None:
        """Best-effort logging of the git remote URL, branch and commit.

        Missing repo/remote/branch information is reported as a warning rather
        than raised, so experiments still run outside a git checkout.
        """
        path_to_git_repo = hydra.utils.get_original_cwd()
        try:
            repo = git.Repo(path_to_git_repo)
        except (git.exc.InvalidGitRepositoryError, git.exc.NoSuchPathError) as err:
            log.warning(f"Cannot get git repo: {err}")
            return

        try:
            if not repo.remotes:  # if empty
                raise StopIteration

            remote_url = next(repo.remotes[0].urls)
            self._logger.log_meta("git.remote_url", remote_url)
        except StopIteration:
            log.warning("Cannot get git remote url")

        if not repo.head.is_detached:
            branch = repo.active_branch.name
            self._logger.log_meta("git.branch", branch)
        else:
            log.warning("Cannot get git branch ('detached HEAD' state)")

        commit_hash = repo.head.commit.hexsha
        # A short (7-char) hash is enough to identify the commit.
        self._logger.log_meta("git.commit", commit_hash[:7])


def execute_run(
    run_cfg: omegaconf.DictConfig,
    experiment_id: str,
    epoch: int,
    seed: int,
) -> RunResults:
    """Instantiate a single run from config and execute it.

    Args:
        run_cfg: Hydra config describing the run (``Run`` target).
        experiment_id: Identifier of the owning experiment.
        epoch: Epoch index of this run.
        seed: Base seed; the run is seeded with ``seed + epoch``.

    Returns:
        The results returned by the run's ``execute``.
    """
    # All runs compute in double precision.
    torch.set_default_dtype(torch.float64)
    # Derive a distinct, reproducible seed for this epoch.
    set_seed(seed + epoch)

    instantiated_run: Run = hydra.utils.instantiate(
        config=run_cfg,
        seed=seed,
        epoch=epoch,
        experiment_id=experiment_id,
        _convert_="all",
    )
    return instantiated_run.execute()


class Experiment(AbstractExperiment):
    """Sequential experiment: executes its runs one epoch at a time."""

    def execute(self) -> None:
        """Run epochs 1..self.epochs in order, updating stats after each."""
        experiment_id = self._logger.document_id
        assert experiment_id is not None

        epoch = 0
        while epoch < self.epochs:
            epoch += 1
            results = execute_run(
                run_cfg=self.run_cfg,
                experiment_id=experiment_id,
                epoch=epoch,
                seed=self.seed,
            )
            self._print_epoch_stats(results)
            self._update_epoch(results)
+48 −0
Original line number Diff line number Diff line
# Standard Imports
from abc import ABC, abstractmethod
from typing import Any

# Local Imports
from regstar_library.metrics.abstract import MetricValueDict


class ExperimentLogger(ABC):
    """Interface for persisting experiment metadata, metrics and timings."""

    # Implementations may take arbitrary constructor arguments; an optional
    # initial ``metadata`` dict is part of the contract.
    @abstractmethod
    def __init__(
        self, *args: Any, metadata: dict[str, Any] | None = None, **kwargs: Any
    ) -> None: ...

    def log_meta_dict(self, metadata: dict[str, Any]) -> None:
        """Log every key/value pair of ``metadata`` via :meth:`log_meta`."""
        for name, value in metadata.items():
            self.log_meta(name, value)

    @abstractmethod
    def log_meta(self, name: str, value: Any) -> None:
        """Logs a metadata. Overwrites the previous data."""
        ...

    @abstractmethod
    def log_metric(self, name: str, value: float) -> None:
        """Appends a value to an array of values for a metric."""
        ...

    @abstractmethod
    def log_metrics_stats(self, value_dict: MetricValueDict) -> None:
        """Logs the statistics of a metric. Overwrites the previous statistics."""
        ...

    # Same contract as log_metrics_stats, but for timing statistics.
    @abstractmethod
    def log_times_stats(self, value_dict: MetricValueDict) -> None: ...

    # Records the id of a run that belongs to this experiment.
    @abstractmethod
    def log_run(self, run_id: str) -> None: ...

    # Opens the logging backend for writing.
    @abstractmethod
    def open(self) -> None: ...

    # Closes the logging backend.
    @abstractmethod
    def close(self) -> None: ...

    # Identifier of the backing document; None presumably means the logger
    # has not been opened yet — confirm against concrete implementations.
    @property
    @abstractmethod
    def document_id(self) -> str | None: ...
+62 −0
Original line number Diff line number Diff line
# Standard Imports
from typing import Any

# Third-party Imports
import omegaconf
import ray

# Local Imports
from regstar_library.launch.experiment.loggers.abstract import ExperimentLogger
from regstar_library.metrics.abstract import Metric
from regstar_library.launch.experiment.experiment import AbstractExperiment, execute_run

# Ray-remotable wrapper around execute_run so epochs can run in parallel.
remote_execute_run = ray.remote(execute_run)


class RayExperiment(AbstractExperiment):
    """Experiment that dispatches all epochs in parallel via Ray."""

    # Keyword arguments forwarded to ``ray.init`` when this experiment
    # has to start Ray itself.
    ray_kwargs: dict[str, Any]

    def __init__(
        self,
        logger: ExperimentLogger,
        run: omegaconf.DictConfig,
        epochs: int,
        seed: int,
        ray_kwargs: dict[str, Any] | None = None,
        metrics: dict[str, Metric | list[Metric]] | None = None,
    ) -> None:
        """Store Ray options and delegate the rest to the base class.

        Args:
            logger: Destination for metrics, times and metadata.
            run: Config of a single run.
            epochs: Number of runs to dispatch.
            seed: Base seed for all runs.
            ray_kwargs: Passed to ``ray.init`` if Ray is not yet initialized.
            metrics: Optional per-metric aggregators (single or list per key;
                annotation widened to match what the base class accepts).
        """
        super().__init__(
            logger=logger,
            run=run,
            epochs=epochs,
            seed=seed,
            metrics=metrics,
        )
        self.ray_kwargs = ray_kwargs or {}

    def execute(self) -> None:
        """Submit all epochs as Ray tasks and consume results as they finish."""
        # Guard against a logger without an open document — consistent with
        # the sequential Experiment.execute; previously a None experiment_id
        # would have been shipped to every remote run.
        assert self._logger.document_id is not None

        ray_shutdown = False
        if not ray.is_initialized():
            ray_shutdown = True
            ray.init(**self.ray_kwargs)

        try:
            kwargs = {
                "run_cfg": self.run_cfg,
                "experiment_id": self._logger.document_id,
                "seed": self.seed,
            }
            # Put shared arguments into the object store once, instead of
            # serializing them with every task submission.
            put_kwargs = {key: ray.put(arg) for key, arg in kwargs.items()}

            futures = [
                remote_execute_run.remote(epoch=epoch, **put_kwargs)  # type: ignore
                for epoch in range(1, self.epochs + 1)
            ]

            # Process results in completion order (not epoch order).
            while futures:
                finished, futures = ray.wait(futures, num_returns=1)
                data = ray.get(finished[0])
                self._print_epoch_stats(data)
                self._update_epoch(data)
        finally:
            # Only tear down Ray if we started it ourselves, and do so even
            # when a run raises, so cluster resources are not leaked.
            if ray_shutdown:
                ray.shutdown()
+58 −0
Original line number Diff line number Diff line
# Standard Imports
from abc import ABC, abstractmethod

# Third-party Imports
import hydra.utils
import omegaconf

# Local Imports
from regstar_library.modules.abstract import AbstractModule
from regstar_library.trainer.trainer import Results, StrategyTrainer


class RunResults(Results):
    """Trainer ``Results`` extended with the epoch that produced them."""

    # 1-based epoch index (experiments iterate epochs starting at 1).
    epoch: int


class AbstractRun(ABC):
    """Interface for a single executable run within an experiment."""

    @abstractmethod
    def __init__(
        self,
        epoch: int,
        experiment_id: str,
    ) -> None:
        """Must accept the epoch and experiment_id as arguments."""
        ...

    # Performs the run and returns its results.
    @abstractmethod
    def execute(self) -> RunResults: ...


class Run(AbstractRun):
    """A single training run: one module trained by one ``StrategyTrainer``."""

    # The module being trained.
    module: AbstractModule
    # Trainer built from config, bound to ``module``.
    trainer: StrategyTrainer[AbstractModule]
    # Epoch index of this run.
    epoch: int
    seed: int  # the seed that was set for this run before initialization

    def __init__(
        self,
        module: AbstractModule,
        trainer: omegaconf.DictConfig,
        seed: int,
        epoch: int,
        experiment_id: str,
    ) -> None:
        """Build the trainer from config and record run provenance.

        Args:
            module: Module to train.
            trainer: Trainer config; ``_target_`` is forced to
                ``StrategyTrainer``.
            seed: The seed that was set for this run before initialization.
            epoch: Epoch index of this run.
            experiment_id: Id of the owning experiment, stored as metadata.
        """
        self.epoch = epoch
        # Fix: the class declares ``seed`` but it was never assigned, so
        # accessing ``run.seed`` raised AttributeError.
        self.seed = seed
        self.module = module
        self.trainer = hydra.utils.instantiate(
            config=trainer,
            _target_=StrategyTrainer,
            module=self.module,
        )
        # Record provenance on the module's own metadata log.
        module.log_meta("seed", seed)
        module.log_meta("epoch", epoch)
        module.log_meta("experiment", experiment_id)

    def execute(self) -> RunResults:
        """Train the module and tag the trainer's results with this epoch."""
        results = self.trainer.train()
        return RunResults(epoch=self.epoch, **results)