在 GitHub 上编辑

PyTorch

DVCLive 可以为您的 PyTorch 项目添加实验跟踪功能。

如果您正在使用 PyTorch Lightning,请查看 DVCLive - PyTorch Lightning 页面。

用法

您需要创建一个 Live 实例,并调用 记录数据更新步数 的方法。

以下代码片段用于上述链接的 Colab 笔记本中:

from dvclive import Live

...

with Live(report="notebook") as live:

    live.log_params(params)

    for _ in range(params["epochs"]):

        train_one_epoch(
            model, criterion, x_train, y_train, params["lr"], params["weight_decay"]
        )

        # Train Evaluation
        metrics_train, acual_train, predicted_train = evaluate(
            model, x_train, y_train)

        for k, v in metrics_train.items():
            live.log_metric(f"train/{k}", v)

        live.log_sklearn_plot(
            "confusion_matrix",
            acual_train, predicted_train,
            name="train/confusion_matrix"
        )

        # Test Evaluation
        metrics_test, actual, predicted = evaluate(
            model, x_test, y_test)

        for k, v in metrics_test.items():
            live.log_metric(f"test/{k}", v)

        live.log_sklearn_plot(
            "confusion_matrix", actual, predicted, name="test/confusion_matrix"
        )

        live.log_image(
            "misclassified.jpg",
            get_missclassified_image(actual, predicted, mnist_test)
        )

        # Save best model
        if metrics_test["acc"] > best_test_acc:
            torch.save(model.state_dict(), "model.pt")

        live.next_step()

    live.log_artifact("model.pt", type="model", name="pytorch-model")

DistributedDataParallel

如果您使用 DistributedDataParallel(DDP)在多个进程中并行训练,请仅在 rank 0 进程中调用 DVCLive。Lightning 回调 会自动处理这一点。您也可以自行编写代码,确保仅在 rank 0 进程中调用 DVCLive:

from dvclive import Live
from torch.distributed import get_rank

...

rank = torch.distributed.get_rank()

if rank == 0:
    # Train model and log with dvclive
    with Live() as live:
        train(...)
        live.log_metric(...)

else:
    # Train model without dvclive
    train(...)
内容

🐛 发现问题?告诉我们!或者修复它:

在 GitHub 上编辑

有疑问?加入我们的聊天,我们会为您提供帮助:

Discord 聊天