在 GitHub 上编辑
Hugging Face Transformers
DVCLive 可以为您的 Hugging Face Transformers 项目添加实验跟踪功能。
如果您正在使用 Hugging Face Accelerate,请查看 DVCLive - Hugging Face Accelerate 页面。
用法
如果您已安装 dvclive,对于 transformers>=4.36.0,将自动使用 DVCLiveCallback 来跟踪实验并记录指标、参数和图表。
要记录模型,请在您的环境中设置 HF_DVCLIVE_LOG_MODEL=true。
os.environ["HF_DVCLIVE_LOG_MODEL"] = "true"
from transformers import TrainingArguments, Trainer
# optional, `report_to` defaults to "all"
args = TrainingArguments(..., report_to="dvclive")
trainer = Trainer(..., args=args)如需自定义跟踪,请将 DVCLiveCallback 添加到传递给您的 Trainer 的回调列表中,并配合一个包含额外参数的 Live 实例:
from dvclive import Live
from transformers.integrations import DVCLiveCallback
...
trainer = Trainer(...)
trainer.add_callback(DVCLiveCallback(Live(dir="custom_dir")))
trainer.train()对于 transformers<4.36.0,请从 dvclive 而非 transformers 导入该回调:
from dvclive.huggingface import DVCLiveCallback
...
trainer = Trainer(...)
trainer.add_callback(DVCLiveCallback())
trainer.train()dvclive.huggingface.DVCLiveCallback 将在 DVCLive 4.0 中被弃用,推荐改用 transformers.integrations.DVCLiveCallback。
示例
记录模型检查点
使用 HF_DVCLIVE_LOG_MODEL=true 或 log_model=True 保存检查点(内部将使用 Live.log_artifact() 进行保存)。
若为真,DVCLive 会将最后一个检查点的副本保存至 dvclive/artifacts 目录,并用名称 last 或 best(若启用了 args.load_best_model_at_end)进行标注。
这有助于在 模型注册表 或自动化场景中使用。
- 在训练结束时保存
last检查点:
os.environ["HF_DVCLIVE_LOG_MODEL"] = "true"
from transformers import TrainingArguments, Trainer
args = TrainingArguments(..., report_to="dvclive")
trainer = Trainer(..., args=args)- 在训练结束时保存
best检查点:
os.environ["HF_DVCLIVE_LOG_MODEL"] = "true"
from transformers import TrainingArguments, Trainer
args = TrainingArguments(..., report_to="dvclive")
trainer = Trainer(..., args=args)
trainer.args.load_best_model_at_end = True- 每当保存新检查点时都保存检查点目录的更新:
os.environ["HF_DVCLIVE_LOG_MODEL"] = "all"
from transformers import TrainingArguments, Trainer
args = TrainingArguments(..., report_to="dvclive")
trainer = Trainer(..., args=args)传递额外的 DVCLive 参数
使用 live 传入一个已有的 Live 实例。
from dvclive import Live
from transformers.integrations import DVCLiveCallback
with Live("custom_dir") as live:
trainer = Trainer(...)
trainer.add_callback(DVCLiveCallback(live=live))
# Log additional metrics after training
live.log_metric("summary_metric", 1.0, plot=False)输出格式
每个指标将被记录到:
{Live.plots_dir}/metrics/{split}/{metric}.tsv其中:
{Live.plots_dir}在Live中定义。{split}可以是train或eval。{metric}是框架提供的名称。