在 GitHub 上编辑
Hugging Face Transformers
DVCLive 可以为您的 Hugging Face Transformers 项目添加实验跟踪功能。
如果您正在使用 Hugging Face Accelerate,请查看 DVCLive - Hugging Face Accelerate 页面。
用法
如果您已安装 dvclive
,对于 transformers>=4.36.0
,将自动使用 DVCLiveCallback
来跟踪实验并记录指标、参数和图表。
要记录模型,请在您的环境中设置 HF_DVCLIVE_LOG_MODEL=true
。
os.environ["HF_DVCLIVE_LOG_MODEL"] = "true"
from transformers import TrainingArguments, Trainer
# optional, `report_to` defaults to "all"
args = TrainingArguments(..., report_to="dvclive")
trainer = Trainer(..., args=args)
如需自定义跟踪,请将 DVCLiveCallback
添加到传递给您的 Trainer
的回调列表中,并配合一个包含额外参数的 Live
实例:
from dvclive import Live
from transformers.integrations import DVCLiveCallback
...
trainer = Trainer(...)
trainer.add_callback(DVCLiveCallback(Live(dir="custom_dir")))
trainer.train()
对于 transformers<4.36.0
,请从 dvclive
而非 transformers
导入该回调:
from dvclive.huggingface import DVCLiveCallback
...
trainer = Trainer(...)
trainer.add_callback(DVCLiveCallback())
trainer.train()
dvclive.huggingface.DVCLiveCallback
将在 DVCLive 4.0 中被弃用,推荐改用 transformers.integrations.DVCLiveCallback
。
示例
记录模型检查点
使用 HF_DVCLIVE_LOG_MODEL=true
或 log_model=True
保存检查点(内部将使用 Live.log_artifact()
进行保存)。
若为真,DVCLive 会将最后一个检查点的副本保存至 dvclive/artifacts
目录,并用名称 last
或 best
(若启用了 args.load_best_model_at_end)进行标注。
这有助于在 模型注册表 或自动化场景中使用。
- 在训练结束时保存
last
检查点:
os.environ["HF_DVCLIVE_LOG_MODEL"] = "true"
from transformers import TrainingArguments, Trainer
args = TrainingArguments(..., report_to="dvclive")
trainer = Trainer(..., args=args)
- 在训练结束时保存
best
检查点:
os.environ["HF_DVCLIVE_LOG_MODEL"] = "true"
from transformers import TrainingArguments, Trainer
args = TrainingArguments(..., report_to="dvclive")
trainer = Trainer(..., args=args)
trainer.args.load_best_model_at_end = True
- 每当保存新检查点时都保存检查点目录的更新:
os.environ["HF_DVCLIVE_LOG_MODEL"] = "all"
from transformers import TrainingArguments, Trainer
args = TrainingArguments(..., report_to="dvclive")
trainer = Trainer(..., args=args)
传递额外的 DVCLive 参数
使用 live
传入一个已有的 Live
实例。
from dvclive import Live
from transformers.integrations import DVCLiveCallback
with Live("custom_dir") as live:
trainer = Trainer(...)
trainer.add_callback(DVCLiveCallback(live=live))
# Log additional metrics after training
live.log_metric("summary_metric", 1.0, plot=False)
输出格式
每个指标将被记录到:
{Live.plots_dir}/metrics/{split}/{metric}.tsv
其中:
{Live.plots_dir}
在Live
中定义。{split}
可以是train
或eval
。{metric}
是框架提供的名称。