使用Callback自定义训练过程

Callback是与训练器紧密相关的模块,在Trainer中使用回调函数能够实现计时、早停、保存checkpoint等在模型训练过程中所需要的额外操作。同时,MindNLP还支持自定义callabck。

在Engine中使用Callback

Callback需要在已经定义训练器Trainer或评测器Evaluator的前提下使用。MindNLP同时支持向Engine传入两种类型的callback参数:Callback类型和list[Callback]类型。Engine会自动执行所传入的callback对应的功能。

在Engine中使用Callback的代码如下所示:

import mindspore.dataset as ds

from mindspore import nn

from mindnlp.engine.trainer import Trainer
from mindnlp.engine.callbacks.earlystop_callback import EarlyStopCallback

class MyDataset:
"""Dataset"""
def __init__(self):
    self.data = np.random.randn(20, 3).astype(np.float32)
    self.label = list(np.random.choice([0, 1]).astype(np.float32) for i in range(20))
    self.length = list(np.random.choice([0, 1]).astype(np.float32) for i in range(20))
def __getitem__(self, index):
    return self.data[index], self.label[index], self.length[index]
def __len__(self):
    return len(self.data)

class MyModel(nn.Cell):
    """Model"""
    def __init__(self):
        super().__init__()
        self.fc = nn.Dense(3, 1)
    def construct(self, data):
        output = self.fc(data)
        return output

# Define Dataset
dataset_generator = MyDataset()
train_dataset = ds.GeneratorDataset(dataset_generator, ["data", "label", "length"], shuffle=False)
eval_dataset = ds.GeneratorDataset(dataset_generator, ["data", "label", "length"], shuffle=False)
train_dataset = train_dataset.batch(4)
eval_dataset = eval_dataset.batch(4)
# Define Model
net = MyModel()
net.update_parameters_name('net.')
# Define Loss function
loss_fn = nn.MSELoss()
# Define Optimizer
optimizer = nn.Adam(net.trainable_params(), learning_rate=0.01)
# Define Callback
timer_callback = TimerCallback(print_steps=2)
# Define Trainer
trainer = Trainer(network=net, train_dataset=train_dataset, eval_dataset=eval_dataset,
                  epochs=6, optimizer=optimizer, loss_fn=loss_fn, callbacks=timer_callback)
# Run Trainer
trainer.run(tgt_columns='label', jit=True)

MindNLP中的Callback

MindNLP提供几种常见的Callback,包括TimerCallback,EarlyStopCallback,BestModelCallback等。以上callback的详细内容,可以参考mindnlp.engine.callbacks。

from mindnlp.engine.callbacks import TimerCallback, EarlyStopCallback, BestModelCallback, CheckpointCallback

callbacks = [
    TimerCallback(print_steps=2),
    EarlyStopCallback(patience=2),
    BestModelCallback(save_path='save/callback/best_model', auto_load=True),
    CheckpointCallback(save_path='save/callback/ckpt_files', epochs=2,
                       keep_checkpoint_max=2)
]

自定义Callback

这里我们用一个简单的Callback作为例子,它的功能是在模型训练中每个Epoch结束时,打印出当前的loss均值。

创建回调函数

若要自定义Callback,我们需要在继承Callback基类的基础上实现一个类。这里我们定义这个类为MyCallback,继承自mindnlp.abc.callback。

指定Callback调用的阶段

Callback中所有的类方法都会在Trainer的训练中在特定的阶段调用。如train_begin()会在训练开始时被调用,epoch_end()会在每个epoch结束时调用。具体有哪些类方法,参见Callback文档。这里,MyCallback在每个epoch结束时调用epoch_end(),输出当前epoch结束时的loss均值。

访问Engine内部信息

Callback中所有的类方法都包含run_context参数用来访问Engine内部训练信息。如当前训练的step数目,当前训练的epoch数目,loss值等。这里,MyCallback需要获得Trainer当前的epoch数目和每个epoch结束后的loss均值。

from mindspore import logging
from mindnlp.abc import Callback

class MyCallBack(Callback):
    def __init__(self):
        self.epoch = run_context.cur_epoch_nums
        self.loss = 0

    def epoch_end(self, run_context):
        self.loss = run_context.loss
        logging.info('Avg loss at epoch %d, %.6f', self.epoch, avg_loss)

my_callback = MyCallBack()
trainer = Trainer(network=net, train_dataset=train_dataset, eval_dataset=eval_dataset,
                  epochs=6, optimizer=optimizer, loss_fn=loss_fn, callbacks=my_callback)
trainer.run(tgt_columns='label', jit=True)