在 GitHub 上编辑

Lightning Fabric

DVCLive 可以为您的 Lightning Fabric 项目添加实验跟踪功能。

用法

您需要创建一个 DVCLiveLogger,然后记录您希望跟踪的参数、指标及其他信息。

以下是一个示例代码片段:

from dvclive.fabric import DVCLiveLogger
from lightning.fabric import Fabric
from lightning.fabric.utilities.rank_zero import rank_zero_only

...

fabric = Fabric()

# Create the DVCLiveLogger
logger = DVCLiveLogger()

# Log dict of hyperparameters
logger.log_hyperparams({"batch_size": 64, "epochs": 5, "lr": 1.0, ...})

for epoch in range(epochs):

    ...

    # Log dict of metrics
    logger.log_metrics({"loss": loss})

fabric.save("mnist_cnn.pt", model.state_dict())
# Check that `rank_zero_only.rank == 0` to avoid logging in other processes.
if rank_zero_only.rank == 0:
    # `logger.experiment` provides access to DVCLive's `Live` instance.
    logger.experiment.log_artifact("mnist_cnn.pt")

# Call finalize to save final results as a DVC experiment
logger.finalize("success")

初始化

from dvclive.fabric import DVCLiveLogger

logger = DVCLiveLogger()

如需自定义跟踪行为,可向 Live 实例传递关键字参数,例如:

logger = DVCLiveLogger(dir="my_directory")

记录参数

要记录超参数,请将字典传递给 log_hyperparams()

logger.log_hyperparams({"batch_size": 64, "epochs": 5, "lr": 1.0, ...})

记录指标

要记录指标,请将字典传递给 log_metrics()

logger.log_metrics({"train_loss": loss})
...
logger.log_metrics({"test_loss": test_loss, "test_acc": test_acc})

可选地传递步数(若未传递,将自动递增):

logger.log_metrics({"train_loss": loss}, step=step)

额外记录

要记录模型、其他工件或图像,请从记录器中获取 Live 实例:

fabric.save("mnist_cnn.pt", model.state_dict())
if rank_zero_only.rank == 0:
    logger.experiment.log_artifact("mnist_cnn.pt")

rank_zero_only.rank 确保该操作仅在零号进程上执行。上述其他方法中,DVCLiveLogger 已自动处理此逻辑,但在此示例中我们直接调用了 Live

logger.experiment 调用 Live 实例,以便您可以调用任意 Live 方法。

结束

最后,结束实验以触发 Live.end()

logger.finalize("success")

"success" 作为 status 参数传入,这是必填参数。

内容

🐛 发现问题?告诉我们!或者修复它:

在 GitHub 上编辑

有疑问?加入我们的聊天,我们会为您提供帮助:

Discord 聊天