在 GitHub 上编辑

Hugging Face Transformers

DVCLive 可以为您的 Hugging Face Transformers 项目添加实验跟踪功能。

如果您正在使用 Hugging Face Accelerate,请查看 DVCLive - Hugging Face Accelerate 页面。

用法

如果您已安装 dvclive,对于 transformers>=4.36.0,将自动使用 DVCLiveCallback 来跟踪实验并记录指标、参数和图表。

要记录模型,请在您的环境中设置 HF_DVCLIVE_LOG_MODEL=true

os.environ["HF_DVCLIVE_LOG_MODEL"] = "true"

from transformers import TrainingArguments, Trainer

# optional, `report_to` defaults to "all"
args = TrainingArguments(..., report_to="dvclive")
trainer = Trainer(..., args=args)

如需自定义跟踪,请将 DVCLiveCallback 添加到传递给您的 Trainer 的回调列表中,并配合一个包含额外参数的 Live 实例:

from dvclive import Live
from transformers.integrations import DVCLiveCallback

...

trainer = Trainer(...)
trainer.add_callback(DVCLiveCallback(Live(dir="custom_dir")))
trainer.train()

对于 transformers<4.36.0,请从 dvclive 而非 transformers 导入该回调:

from dvclive.huggingface import DVCLiveCallback

...

trainer = Trainer(...)
trainer.add_callback(DVCLiveCallback())
trainer.train()

dvclive.huggingface.DVCLiveCallback 将在 DVCLive 4.0 中被弃用,推荐改用 transformers.integrations.DVCLiveCallback

示例

记录模型检查点

使用 HF_DVCLIVE_LOG_MODEL=truelog_model=True 保存检查点(内部将使用 Live.log_artifact() 进行保存)。

若为真,DVCLive 会将最后一个检查点的副本保存至 dvclive/artifacts 目录,并用名称 lastbest(若启用了 args.load_best_model_at_end)进行标注。

这有助于在 模型注册表 或自动化场景中使用。

  • 在训练结束时保存 last 检查点:
os.environ["HF_DVCLIVE_LOG_MODEL"] = "true"

from transformers import TrainingArguments, Trainer

args = TrainingArguments(..., report_to="dvclive")
trainer = Trainer(..., args=args)
  • 在训练结束时保存 best 检查点:
os.environ["HF_DVCLIVE_LOG_MODEL"] = "true"

from transformers import TrainingArguments, Trainer

args = TrainingArguments(..., report_to="dvclive")
trainer = Trainer(..., args=args)
trainer.args.load_best_model_at_end = True
  • 每当保存新检查点时都保存检查点目录的更新:
os.environ["HF_DVCLIVE_LOG_MODEL"] = "all"

from transformers import TrainingArguments, Trainer

args = TrainingArguments(..., report_to="dvclive")
trainer = Trainer(..., args=args)

传递额外的 DVCLive 参数

使用 live 传入一个已有的 Live 实例。

from dvclive import Live
from transformers.integrations import DVCLiveCallback

with Live("custom_dir") as live:
    trainer = Trainer(...)
    trainer.add_callback(DVCLiveCallback(live=live))

    # Log additional metrics after training
    live.log_metric("summary_metric", 1.0, plot=False)

输出格式

每个指标将被记录到:

{Live.plots_dir}/metrics/{split}/{metric}.tsv

其中:

  • {Live.plots_dir}Live 中定义。
  • {split} 可以是 traineval
  • {metric} 是框架提供的名称。
内容

🐛 发现问题?告诉我们!或者修复它:

在 GitHub 上编辑

有疑问?加入我们的聊天,我们会为您提供帮助:

Discord 聊天