Source code for hydraclick.run
from functools import partial
from pathlib import Path
import logging
import sys
from typing import Callable, Any
import hydra
from omegaconf import DictConfig
from unittest.mock import patch
import flogging
from hydraclick.display_config import display_config
_logger = logging.getLogger(__name__)
[docs]
def get_hydra_configs(
    config_file: Path | str,
    hydra_args: tuple[str, ...],
) -> list[DictConfig]:
    """Load the configuration(s) needed to run click commands from a config.yaml file.

    Args:
        config_file: Path to the target config.yaml file. It must point to an
            existing regular file; there is no default config to fall back to.
        hydra_args: Arguments passed to hydra for composing the project
            configuration. ``None`` is treated as "no extra arguments".

    Returns:
        configs: list of all the configurations that hydra handed to the
            decorated loader for the given command arguments.

    Raises:
        ValueError: If ``config_file`` is ``None`` or is not an existing file.
    """
    # Path(None) raises an opaque TypeError, so treat a missing path like any
    # other invalid path and report it through the same ValueError.
    config_path = Path(config_file) if config_file is not None else None
    if config_path is None or not config_path.is_file():
        _logger.error("Invalid config file path provided: %s", config_path)
        msg = f"Invalid config file: {config_path}"
        raise ValueError(msg)

    _logger.info("Loading config file from %s", config_path)
    config_dir = str(config_path.parent)
    cli_args = ["--config-dir", config_dir, *(hydra_args or ())]
    configs: list[DictConfig] = []

    @hydra.main(config_path=config_dir, config_name=config_path.stem, version_base=None)
    def load_config(loaded_config: DictConfig):
        flogging.setup(allow_trailing_dot=True)
        # Only mutates the enclosing list (no rebinding), so no `nonlocal` is needed.
        configs.append(loaded_config)

    # The hydra.main-decorated function takes its arguments from sys.argv, so
    # swap in the composed argument list only for the duration of the call.
    with patch("sys.argv", [sys.argv[0], *cli_args]):
        load_config()
    return configs
[docs]
def _run_sequential(
    function: Callable[[DictConfig], Any],
    configs: list[DictConfig],
    num_shards: int = 0,
) -> int:
    """Run the sweep sequentially, calling ``function`` for each configuration in turn.

    Args:
        function: Callable invoked with a configuration.
        configs: Configurations to run, processed in order.
        num_shards: Sharding mode. ``0`` calls ``function`` once per
            configuration and leaves the configuration untouched. A positive
            value is written to ``config["num_shards"]`` and ``function`` is
            called once per shard with ``config["shard_ix"]`` set to the shard
            index. A negative value takes the shard count from each
            configuration's own ``num_shards`` entry (1 when it is absent).

    Returns:
        Always 0.
    """
    for config in configs:
        if num_shards == 0:
            function(config)
            continue
        if num_shards > 0:
            config["num_shards"] = num_shards
        # Resolve the count into a local: rebinding `num_shards` here would leak
        # the first configuration's value into every following configuration.
        shard_count = config.get("num_shards", 1)
        for shard_ix in range(shard_count):
            config["shard_ix"] = shard_ix
            _logger.info("Running shard %d", shard_ix)
            function(config)
    return 0
[docs]
def _run_parallel(
    function: Callable[[DictConfig], Any],
    configs: list[DictConfig],
    num_shards: int = 0,
) -> int:
    """Run the sweep in parallel by submitting every run as a ray remote task.

    Args:
        function: Callable wrapped with ``ray.remote`` and invoked with a
            configuration.
        configs: Configurations to run. The ``ray`` entry of the first
            configuration (if any) supplies the keyword arguments for
            ``ray.init``.
        num_shards: Sharding mode. ``0`` submits one task per configuration. A
            positive value is written to ``config["num_shards"]`` and one task
            is submitted per shard with ``config["shard_ix"]`` set to the shard
            index. A negative value takes the shard count from each
            configuration's own ``num_shards`` entry (1 when it is absent).

    Returns:
        0 once all submitted tasks have been collected (or when there is
        nothing to run), 1 if ray is not installed.
    """
    try:
        import ray  # noqa: PLC0415
    except ImportError:
        _logger.error("Ray is not installed. Please install it with `pip install ray`")
        return 1
    if not configs:
        # Mirror the sequential runner, which simply returns 0 for an empty
        # sweep, instead of failing with an IndexError on configs[0].
        _logger.warning("No configurations to run")
        return 0
    # Ray options are taken from the first configuration only.
    ray_opts = configs[0].get("ray", {})
    ray_opts["ignore_reinit_error"] = True
    _logger.info("Launching ray with parameters: %s", ray_opts)
    ray.init(**ray_opts)
    run_remote = ray.remote(function)
    requests = []
    _logger.info("Ray launched successfully")
    for config in configs:
        if num_shards == 0:
            requests.append(run_remote.remote(config))
            continue
        if num_shards > 0:
            config["num_shards"] = num_shards
        # Resolve the count into a local: rebinding `num_shards` here would leak
        # the first configuration's value into every following configuration.
        shard_count = config.get("num_shards", 1)
        for shard_ix in range(shard_count):
            config["shard_ix"] = shard_ix
            req_id = run_remote.remote(config)
            _logger.info("Running shard %d", shard_ix)
            requests.append(req_id)
    # Block until every submitted task has finished.
    ray.get(requests)
    return 0
[docs]
def run_function(
    function: Callable[[DictConfig], Any],
    config_file: str | Path | None = None,
    hydra_args: tuple[str, ...] | None = None,
    multirun: bool = True,
    parallel: bool = False,
    num_shards: int = 0,
    only_config: bool = False,
) -> int:
    """Load the hydra configuration(s) and run ``function`` on each of them.

    Args:
        function: Callable invoked with each loaded configuration.
        config_file: Path to the config file, forwarded to ``get_hydra_configs``.
        hydra_args: Extra arguments forwarded to hydra. ``None`` means no
            extra arguments.
        multirun: If True, ``hydra.mode=MULTIRUN`` is prepended to the hydra
            arguments.
        parallel: If True, run through ``_run_parallel``; otherwise through
            ``_run_sequential``.
        num_shards: Shard count forwarded to the selected runner.
        only_config: If True, ``function`` is replaced by ``display_config``
            bound to this module's logger.

    Returns:
        The integer status returned by the selected runner.
    """
    # `hydra_args` defaults to None; normalise it first so that unpacking it
    # below (and inside get_hydra_configs) cannot raise a TypeError.
    overrides = tuple(hydra_args or ())
    if multirun:
        overrides = ("hydra.mode=MULTIRUN", *overrides)
    configs = get_hydra_configs(config_file, overrides)
    if only_config:
        function = partial(display_config, logger=_logger)
    if parallel:
        _logger.info("Running in parallel mode")
        return _run_parallel(function, configs, num_shards=num_shards)
    _logger.info("Running in sequential mode")
    return _run_sequential(function, configs, num_shards=num_shards)