在 GitHub 上编辑
Lightning Fabric
DVCLive 可以为您的 Lightning Fabric 项目添加实验跟踪功能。
用法
您需要创建一个 DVCLiveLogger
,然后记录您希望跟踪的参数、指标及其他信息。
以下是一个示例代码片段:
from dvclive.fabric import DVCLiveLogger
from lightning.fabric import Fabric
from lightning.fabric.utilities.rank_zero import rank_zero_only
...
fabric = Fabric()
# Create the DVCLiveLogger
logger = DVCLiveLogger()
# Log dict of hyperparameters
logger.log_hyperparams({"batch_size": 64, "epochs": 5, "lr": 1.0, ...})
for epoch in range(epochs):
...
# Log dict of metrics
logger.log_metrics({"loss": loss})
fabric.save("mnist_cnn.pt", model.state_dict())
# Check that `rank_zero_only.rank == 0` to avoid logging in other processes.
if rank_zero_only.rank == 0:
# `logger.experiment` provides access to DVCLive's `Live` instance.
logger.experiment.log_artifact("mnist_cnn.pt")
# Call finalize to save final results as a DVC experiment
logger.finalize("success")
初始化
from dvclive.fabric import DVCLiveLogger
logger = DVCLiveLogger()
如需自定义跟踪行为,可向 Live
实例传递关键字参数,例如:
logger = DVCLiveLogger(dir="my_directory")
记录参数
要记录超参数,请将字典传递给 log_hyperparams()
:
logger.log_hyperparams({"batch_size": 64, "epochs": 5, "lr": 1.0, ...})
记录指标
要记录指标,请将字典传递给 log_metrics()
:
logger.log_metrics({"train_loss": loss})
...
logger.log_metrics({"test_loss": test_loss, "test_acc": test_acc})
可选地传递步数(若未传递,将自动递增):
logger.log_metrics({"train_loss": loss}, step=step)
额外记录
要记录模型、其他工件或图像,请从记录器中获取 Live
实例:
fabric.save("mnist_cnn.pt", model.state_dict())
if rank_zero_only.rank == 0:
logger.experiment.log_artifact("mnist_cnn.pt")
rank_zero_only.rank
确保该操作仅在零号进程上执行。上述其他方法中,DVCLiveLogger
已自动处理此逻辑,但在此示例中我们直接调用了 Live
。
logger.experiment
调用 Live
实例,以便您可以调用任意 Live
方法。
结束
最后,结束实验以触发 Live.end()
:
logger.finalize("success")
"success"
作为 status
参数传入,这是必填参数。