Loading regstar_library/train.py 0 → 100644 +52 −0 Original line number Diff line number Diff line # Standard Imports import random # Third-party imports import hydra from omegaconf import DictConfig, OmegaConf # Local imports from regstar_library.launch.experiment.experiment import Experiment, RayExperiment from regstar_library.logging.logging import configure_logging OmegaConf.register_new_resolver( "common_seed", lambda: random.randint(0, 2**31), use_cache=True, replace=True ) @hydra.main(version_base=None, config_path="../conf", config_name="default") def train(cfg: DictConfig) -> None: configure_logging(cfg.logging) experiment_logger = hydra.utils.instantiate( config=cfg.experiment.logger, _recursive_=True, _convert_="all" ) metrics = hydra.utils.instantiate( config=cfg.experiment.metrics, _recursive_=True, _convert_="all" ) if cfg.get("use_ray", False): experiment = RayExperiment( logger=experiment_logger, metrics=metrics, run=cfg.experiment.run, seed=cfg.experiment.seed, epochs=cfg.experiment.epochs, ray_kwargs=OmegaConf.to_container(cfg.ray_kwargs, resolve=True), ) else: experiment = Experiment( logger=experiment_logger, metrics=metrics, run=cfg.experiment.run, epochs=cfg.experiment.epochs, seed=cfg.experiment.seed, ) experiment.setup() experiment.log_config(cfg.experiment) experiment.execute() experiment.teardown() if __name__ == "__main__": train() Loading
regstar_library/train.py 0 → 100644 +52 −0 Original line number Diff line number Diff line # Standard Imports import random # Third-party imports import hydra from omegaconf import DictConfig, OmegaConf # Local imports from regstar_library.launch.experiment.experiment import Experiment, RayExperiment from regstar_library.logging.logging import configure_logging OmegaConf.register_new_resolver( "common_seed", lambda: random.randint(0, 2**31), use_cache=True, replace=True ) @hydra.main(version_base=None, config_path="../conf", config_name="default") def train(cfg: DictConfig) -> None: configure_logging(cfg.logging) experiment_logger = hydra.utils.instantiate( config=cfg.experiment.logger, _recursive_=True, _convert_="all" ) metrics = hydra.utils.instantiate( config=cfg.experiment.metrics, _recursive_=True, _convert_="all" ) if cfg.get("use_ray", False): experiment = RayExperiment( logger=experiment_logger, metrics=metrics, run=cfg.experiment.run, seed=cfg.experiment.seed, epochs=cfg.experiment.epochs, ray_kwargs=OmegaConf.to_container(cfg.ray_kwargs, resolve=True), ) else: experiment = Experiment( logger=experiment_logger, metrics=metrics, run=cfg.experiment.run, epochs=cfg.experiment.epochs, seed=cfg.experiment.seed, ) experiment.setup() experiment.log_config(cfg.experiment) experiment.execute() experiment.teardown() if __name__ == "__main__": train()