Commit 9affd911 authored by xdrazkov's avatar xdrazkov
Browse files

feat: implement training script with Hydra configuration

parent eac3053b
Loading
Loading
Loading
Loading
+52 −0
Original line number Diff line number Diff line
# Standard Imports
import random

# Third-party imports
import hydra
from omegaconf import DictConfig, OmegaConf

# Local imports
from regstar_library.launch.experiment.experiment import Experiment, RayExperiment
from regstar_library.logging.logging import configure_logging


OmegaConf.register_new_resolver(
    "common_seed", lambda: random.randint(0, 2**31), use_cache=True, replace=True
)


@hydra.main(version_base=None, config_path="../conf", config_name="default")
def train(cfg: DictConfig) -> None:
    configure_logging(cfg.logging)

    experiment_logger = hydra.utils.instantiate(
        config=cfg.experiment.logger, _recursive_=True, _convert_="all"
    )
    metrics = hydra.utils.instantiate(
        config=cfg.experiment.metrics, _recursive_=True, _convert_="all"
    )
    if cfg.get("use_ray", False):
        experiment = RayExperiment(
            logger=experiment_logger,
            metrics=metrics,
            run=cfg.experiment.run,
            seed=cfg.experiment.seed,
            epochs=cfg.experiment.epochs,
            ray_kwargs=OmegaConf.to_container(cfg.ray_kwargs, resolve=True),
        )
    else:
        experiment = Experiment(
            logger=experiment_logger,
            metrics=metrics,
            run=cfg.experiment.run,
            epochs=cfg.experiment.epochs,
            seed=cfg.experiment.seed,
        )
    experiment.setup()
    experiment.log_config(cfg.experiment)
    experiment.execute()
    experiment.teardown()


if __name__ == "__main__":
    train()