Trainer

用于模型训练。

class mindnlp.engine.trainer.Trainer(network=None, train_dataset=None, eval_dataset=None, metrics=None, epochs=10, loss_fn=None, optimizer=None, callbacks=None, jit=False)[源代码]

基类:object

用于模型训练的训练器。

参数
  • network (Cell) – 用于训练的模型网络。

  • train_dataset (Dataset) – 训练数据集迭代器。如果定义了 loss_fn,数据和标签将分别传递给 networkloss_fn,所以应该从数据集中返回一个元组(数据,标签)。如果有多个数据或标签,将`loss_fn`设置为None,并在`network`中实现loss的计算,然后将一个元组(data1,data2,data3,…)从数据集返回的所有数据传递给` 网络`。

  • eval_dataset (Dataset) – 测试数据集迭代器。如果定义了 loss_fn,数据和标签将分别传递给 networkloss_fn,所以应该从数据集中返回一个元组(数据,标签)。如果有多个数据或标签,将`loss_fn`设置为None,并在`network`中实现loss的计算,然后将一个元组(data1,data2,data3,…)从数据集返回的所有数据传递给` 网络`。

  • metrics (Optional[list[Metrics], Metrics]) – 评估时应使用的评测指标对象列表。 默认值:无。

  • epochs (int) – 数据的总迭代次数。 默认值:10。

  • optimizer (Cell) – 用于更新权重的优化器。如果 optimizer 为 None,则 network 需要进行反向传播和更新权重。 默认值:无。

  • loss_fn (Cell) – 目标函数。 如果 loss_fn 为 None,则 network 应包含损失计算和并行计算(如果需要)。 默认值:无。

  • callbacks (Optional[list[Callback], Callback]) – 在模型训练过程中应当执行的回调函数方法列表。默认值:无。

  • jit (bool) – 是否使用Just-In-Time编译模式。

run(tgt_columns=None)[源代码]

训练过程入口。

参数

tgt_columns (Optional[list[str], str]) – 用于计算损失函数的目标标签列名称。