在 GitHub 上编辑

TensorFlow

DVCLive 可以为您的 TensorFlow 项目添加实验跟踪功能。

用法

如果您更喜欢使用 Keras API,请查看 DVCLive - Keras 页面。

您需要在每个希望记录指标的地方添加 Live.log_metric() 调用,并在每个训练轮次结束时添加一次 Live.next_step() 调用。

让我们参考从 官方 TensorFlow 指南 中提取的以下示例:

from dvclive import Live

with Live() as live:

    for epoch in range(epochs):
        start_time = time.time()
        for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
            with tf.GradientTape() as tape:
                logits = model(x_batch_train, training=True)
                loss_value = loss_fn(y_batch_train, logits)
            grads = tape.gradient(loss_value, model.trainable_weights)
            optimizer.apply_gradients(zip(grads, model.trainable_weights))
            train_acc_metric.update_state(y_batch_train, logits)

        live.log_metric("train/accuracy", float(train_acc_metric.result())
        train_acc_metric.reset_states()

        for x_batch_val, y_batch_val in val_dataset:
            val_logits = model(x_batch_val, training=False)
            val_acc_metric.update_state(y_batch_val, val_logits)
        live.log_metric("val/accuracy", float(val_acc_metric.result())
        val_acc_metric.reset_states()

        live.next_step()
内容

🐛 发现问题?告诉我们!或者修复它:

在 GitHub 上编辑

有疑问?加入我们的聊天,我们会为您提供帮助:

Discord 聊天