在 GitHub 上编辑

PyTorch Lightning

DVCLive 可以为您的 PyTorch Lightning 项目添加实验跟踪功能。

如果您正在使用 Lightning Fabric,请查看 DVCLive - Lightning Fabric 页面。

用法

如果您将 DVCLiveLogger 传递给您的 Trainer,DVCLive 将自动记录您在 LightningModule 中跟踪的 指标参数

import lightning.pytorch as pl
from dvclive.lightning import DVCLiveLogger

...
class LitModule(pl.LightningModule):
    def __init__(self, layer_1_dim=128, learning_rate=1e-2):
        super().__init__()
        # layer_1_dim and learning_rate will be logged by DVCLive
        self.save_hyperparameters()

    def training_step(self, batch, batch_idx):
        metric = ...
        # See Output Format bellow
        self.log("train_metric", metric, on_step=False, on_epoch=True)

dvclive_logger = DVCLiveLogger()

model = LitModule()
trainer = pl.Trainer(logger=dvclive_logger)
trainer.fit(model)

默认情况下,PyTorch Lightning 会使用日志器的名称(DVCLiveLogger)创建一个目录来存储检查点。您可以按照 PyTorch Lightning 文档 中的说明更改检查点路径或完全禁用检查点功能。

参数

  • run_name -(默认为 None)- 运行名称,用于 PyTorch Lightning 获取版本号。

  • prefix -(默认为 None)- 添加到每个指标名称前的字符串。

  • log_model -(默认为 False)- 使用 live.log_artifact() 记录由 ModelCheckpoint 创建的检查点。参见 记录模型检查点

    • 如果 log_model == False(默认值),则不会记录任何检查点。

    • 如果 log_model == True,则在训练结束时记录检查点,除非 save_top_k == -1,此时会在训练过程中记录所有检查点。

    • 如果 log_model == 'all',则在训练过程中记录所有检查点。

  • experiment -(默认为 None)- 使用的 Live 对象,而非初始化一个新的对象。

  • **kwargs - 任何其他参数都将用于实例化一个新的 Live 实例。如果使用了 experiment,这些参数将被忽略。

示例

记录模型检查点

使用 log_model 保存检查点(内部将使用 Live.log_artifact() 进行保存)。训练结束时,DVCLive 会将 best_model_path 复制到 dvclive/artifacts 目录,并将其标注为名称 best(例如,供 DVC Studio 模型注册表 或自动化场景使用)。

  • 在训练结束时保存检查点目录的更新:
from dvclive.lightning import DVCLiveLogger

logger = DVCLiveLogger(log_model=True)
trainer = Trainer(logger=logger)
trainer.fit(model)
  • 每当保存新检查点时都保存检查点目录的更新:
from dvclive.lightning import DVCLiveLogger

logger = DVCLiveLogger(log_model="all")
trainer = Trainer(logger=logger)
trainer.fit(model)
  • 使用自定义的 ModelCheckpoint
from dvclive.lightning import DVCLiveLogger

logger = DVCLiveLogger(log_model=True),
checkpoint_callback = ModelCheckpoint(
        dirpath="model",
        monitor="val_acc",
        mode="max",
)
trainer = Trainer(logger=logger, callbacks=[checkpoint_callback])
trainer.fit(model)

传递额外的 DVCLive 参数

  • 使用 experiment 传入已有的 Live 实例。
from dvclive import Live
from dvclive.lightning import DVCLiveLogger

with Live("custom_dir") as live:
    trainer = Trainer(
        logger=DVCLiveLogger(experiment=live))
    trainer.fit(model)
    # Log additional metrics after training
    live.log_metric("summary_metric", 1.0, plot=False)
  • 使用 **kwargs 自定义 Live
from dvclive.lightning import DVCLiveLogger

trainer = Trainer(
    logger=DVCLiveLogger(dir='my_logs_dir'))
trainer.fit(model)

输出格式

每个指标将被记录到:

{Live.plots_dir}/metrics/{split_prefix}/{iter_type}/{metric_name}.tsv

其中:

  • {Live.plots_dir}Live 中定义。
  • {iter_type} 可以是 epochstep,由 log 调用中使用的 on_stepon_epoch 参数推断得出。
  • {split_prefix}_{metric_name} 是传递给 log 调用的完整字符串。split_prefix 可以是 trainvaltest

在上述示例中,记录的指标为:

self.log("train_metric", metric, on_step=False, on_epoch=True)

将被存储在:

dvclive/metrics/train/epoch/metric.tsv
内容

🐛 发现问题?告诉我们!或者修复它:

在 GitHub 上编辑

有疑问?加入我们的聊天,我们会为您提供帮助:

Discord 聊天