# Source code for deepspeed.monitor.config

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Optional

from deepspeed.pydantic_v1 import root_validator
from deepspeed.runtime.config_utils import DeepSpeedConfigModel


def get_monitor_config(param_dict):
    """Build the monitor configuration from a larger configuration mapping.

    Only the ``"tensorboard"``, ``"wandb"`` and ``"csv_monitor"`` entries of
    ``param_dict`` are read; any entry that is absent is replaced by an empty
    dict before being handed to ``DeepSpeedMonitorConfig``.

    Args:
        param_dict: Mapping-like object exposing ``.get(key, default)``.

    Returns:
        A ``DeepSpeedMonitorConfig`` constructed from the three sections.
    """
    sections = {}
    for backend in ("tensorboard", "wandb", "csv_monitor"):
        # A fresh {} per missing backend, so sections never share a dict.
        sections[backend] = param_dict.get(backend, {})
    return DeepSpeedMonitorConfig(**sections)


class TensorBoardConfig(DeepSpeedConfigModel):
    """Sets parameters for TensorBoard monitor."""

    enabled: bool = False
    """
    Turns TensorBoard logging on or off. The `tensorboard` package must be
    installed for this to be used.
    """

    output_path: str = ""
    """
    Location the TensorBoard logs are written to. When left empty, the output
    path is set under the training script's launching path.
    """

    job_name: str = "DeepSpeedJobName"
    """
    Name of the current job; a directory with this name is created inside
    `output_path`.
    """
class WandbConfig(DeepSpeedConfigModel):
    """Sets parameters for WandB monitor."""

    enabled: bool = False
    """ Whether logging to WandB is enabled. Requires `wandb` package is installed. """

    # `group` and `team` default to None, so they are annotated Optional[str]
    # rather than plain `str`: the declared type now matches the default.
    # Pydantic v1 already treats a `None` default as optional, so accepted
    # inputs are unchanged.
    group: Optional[str] = None
    """ Name for the WandB group. This can be used to group together runs. """

    team: Optional[str] = None
    """ Name for the WandB team. """

    project: str = "deepspeed"
    """ Name for the WandB project. """
class CSVConfig(DeepSpeedConfigModel):
    """Sets parameters for CSV monitor."""

    enabled: bool = False
    """ Turns logging to local CSV files on or off. """

    output_path: str = ""
    """
    Location the csv files are written to. When left empty, the output path is
    set under the training script's launching path.
    """

    job_name: str = "DeepSpeedJobName"
    """
    Name of the current job; a directory with this name is created inside
    `output_path`.
    """
class DeepSpeedMonitorConfig(DeepSpeedConfigModel):
    """Sets parameters for various monitoring methods."""

    tensorboard: TensorBoardConfig = {}
    """ TensorBoard monitor, requires `tensorboard` package is installed. """

    wandb: WandbConfig = {}
    """ WandB monitor, requires `wandb` package is installed. """

    csv_monitor: CSVConfig = {}
    """ Local CSV output of monitoring data. """

    # skip_on_failure=True: by default a post root validator still runs when a
    # field failed validation, and the failed field is then absent from
    # `values`. Reading `.enabled` on the missing entry would raise an
    # AttributeError that hides the real validation error, so the validator is
    # skipped in that case and the field errors are reported instead.
    @root_validator(skip_on_failure=True)
    def check_enabled(cls, values):
        """Derive the aggregate ``enabled`` flag from the individual monitors.

        Args:
            values: Validated field values keyed by field name.

        Returns:
            ``values`` with an extra ``"enabled"`` entry that is true when at
            least one of the tensorboard, wandb or csv monitors is enabled.
        """
        monitors = (values["tensorboard"], values["wandb"], values["csv_monitor"])
        values["enabled"] = any(monitor.enabled for monitor in monitors)
        return values