PyTorch Lightning
DVCLive 可以为您的 PyTorch Lightning 项目添加实验跟踪功能。
如果您正在使用 Lightning Fabric,请查看 DVCLive - Lightning Fabric 页面。
用法
如果您将 DVCLiveLogger
传递给您的 Trainer
,DVCLive 将自动记录您在 LightningModule
中跟踪的 指标 和 参数。
import lightning.pytorch as pl
from dvclive.lightning import DVCLiveLogger
...
class LitModule(pl.LightningModule):
def __init__(self, layer_1_dim=128, learning_rate=1e-2):
super().__init__()
# layer_1_dim and learning_rate will be logged by DVCLive
self.save_hyperparameters()
def training_step(self, batch, batch_idx):
metric = ...
# See Output Format bellow
self.log("train_metric", metric, on_step=False, on_epoch=True)
dvclive_logger = DVCLiveLogger()
model = LitModule()
trainer = pl.Trainer(logger=dvclive_logger)
trainer.fit(model)
默认情况下,PyTorch Lightning 会使用日志器的名称(DVCLiveLogger
)创建一个目录来存储检查点。您可以按照 PyTorch Lightning 文档 中的说明更改检查点路径或完全禁用检查点功能。
参数
-
run_name
-(默认为None
)- 运行名称,用于 PyTorch Lightning 获取版本号。 -
prefix
-(默认为None
)- 添加到每个指标名称前的字符串。 -
log_model
-(默认为False
)- 使用live.log_artifact()
记录由ModelCheckpoint
创建的检查点。参见 记录模型检查点。-
如果
log_model == False
(默认值),则不会记录任何检查点。 -
如果
log_model == True
,则在训练结束时记录检查点,除非save_top_k == -1
,此时会在训练过程中记录所有检查点。 -
如果
log_model == 'all'
,则在训练过程中记录所有检查点。
-
-
experiment
-(默认为None
)- 使用的Live
对象,而非初始化一个新的对象。 -
**kwargs
- 任何其他参数都将用于实例化一个新的Live
实例。如果使用了experiment
,这些参数将被忽略。
示例
记录模型检查点
使用 log_model
保存检查点(内部将使用 Live.log_artifact()
进行保存)。训练结束时,DVCLive 会将 best_model_path
复制到 dvclive/artifacts
目录,并将其标注为名称 best
(例如,供 DVC Studio 模型注册表 或自动化场景使用)。
- 在训练结束时保存检查点目录的更新:
from dvclive.lightning import DVCLiveLogger
logger = DVCLiveLogger(log_model=True)
trainer = Trainer(logger=logger)
trainer.fit(model)
- 每当保存新检查点时都保存检查点目录的更新:
from dvclive.lightning import DVCLiveLogger
logger = DVCLiveLogger(log_model="all")
trainer = Trainer(logger=logger)
trainer.fit(model)
- 使用自定义的
ModelCheckpoint
:
from dvclive.lightning import DVCLiveLogger
logger = DVCLiveLogger(log_model=True),
checkpoint_callback = ModelCheckpoint(
dirpath="model",
monitor="val_acc",
mode="max",
)
trainer = Trainer(logger=logger, callbacks=[checkpoint_callback])
trainer.fit(model)
传递额外的 DVCLive 参数
- 使用
experiment
传入已有的Live
实例。
from dvclive import Live
from dvclive.lightning import DVCLiveLogger
with Live("custom_dir") as live:
trainer = Trainer(
logger=DVCLiveLogger(experiment=live))
trainer.fit(model)
# Log additional metrics after training
live.log_metric("summary_metric", 1.0, plot=False)
- 使用
**kwargs
自定义Live
。
from dvclive.lightning import DVCLiveLogger
trainer = Trainer(
logger=DVCLiveLogger(dir='my_logs_dir'))
trainer.fit(model)
输出格式
每个指标将被记录到:
{Live.plots_dir}/metrics/{split_prefix}/{iter_type}/{metric_name}.tsv
其中:
{Live.plots_dir}
在Live
中定义。{iter_type}
可以是epoch
或step
,由log
调用中使用的on_step
和on_epoch
参数推断得出。{split_prefix}_{metric_name}
是传递给log
调用的完整字符串。split_prefix
可以是train
、val
或test
。
在上述示例中,记录的指标为:
self.log("train_metric", metric, on_step=False, on_epoch=True)
将被存储在:
dvclive/metrics/train/epoch/metric.tsv