在 GitHub 上编辑
Lightning Fabric
DVCLive 可以为您的 Lightning Fabric 项目添加实验跟踪功能。
用法
您需要创建一个 DVCLiveLogger,然后记录您希望跟踪的参数、指标及其他信息。
以下是一个示例代码片段:
from dvclive.fabric import DVCLiveLogger
from lightning.fabric import Fabric
from lightning.fabric.utilities.rank_zero import rank_zero_only
...
fabric = Fabric()
# Create the DVCLiveLogger
logger = DVCLiveLogger()
# Log dict of hyperparameters
logger.log_hyperparams({"batch_size": 64, "epochs": 5, "lr": 1.0, ...})
for epoch in range(epochs):
...
# Log dict of metrics
logger.log_metrics({"loss": loss})
fabric.save("mnist_cnn.pt", model.state_dict())
# Check that `rank_zero_only.rank == 0` to avoid logging in other processes.
if rank_zero_only.rank == 0:
# `logger.experiment` provides access to DVCLive's `Live` instance.
logger.experiment.log_artifact("mnist_cnn.pt")
# Call finalize to save final results as a DVC experiment
logger.finalize("success")初始化
from dvclive.fabric import DVCLiveLogger
logger = DVCLiveLogger()如需自定义跟踪行为,可向 Live 实例传递关键字参数,例如:
logger = DVCLiveLogger(dir="my_directory")记录参数
要记录超参数,请将字典传递给 log_hyperparams():
logger.log_hyperparams({"batch_size": 64, "epochs": 5, "lr": 1.0, ...})记录指标
要记录指标,请将字典传递给 log_metrics():
logger.log_metrics({"train_loss": loss})
...
logger.log_metrics({"test_loss": test_loss, "test_acc": test_acc})可选地传递步数(若未传递,将自动递增):
logger.log_metrics({"train_loss": loss}, step=step)额外记录
要记录模型、其他工件或图像,请从记录器中获取 Live 实例:
fabric.save("mnist_cnn.pt", model.state_dict())
if rank_zero_only.rank == 0:
logger.experiment.log_artifact("mnist_cnn.pt")rank_zero_only.rank 确保该操作仅在零号进程上执行。上述其他方法中,DVCLiveLogger 已自动处理此逻辑,但在此示例中我们直接调用了 Live。
logger.experiment 调用 Live 实例,以便您可以调用任意 Live 方法。
结束
最后,结束实验以触发 Live.end():
logger.finalize("success")"success" 作为 status 参数传入,这是必填参数。