机器翻译
机器翻译就是将一种语言(一句话或者一段话或者一篇文章)翻译成另外一种语言。下面是一个使用Multi30k数据集和Seq2Seq模型训练机器翻译的demo:
定义模型
机器翻译是一种典型的Seq2Seq模型,是从一个序列生成另外一个序列。它涉及两个过程:一个是理解前一个序列,另一个是用理解到的内容来生成新的序列。至于序列所采用的模型可以是RNN,LSTM,GRU,其它序列模型等。
from mindnlp.abc import Seq2seqModel
class MachineTranslation(Seq2seqModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
self.encoder = encoder
self.decoder = decoder
def construct(self, en, de):
encoder_out = self.encoder(en)
decoder_out = self.decoder(de, encoder_out=encoder_out)
output = decoder_out[0]
return output.swapaxes(1,2)
定义超参数
以下是模型训练过程中需要的一些超参数。
enc_emb_dim = 256
dec_emb_dim = 256
enc_hid_dim = 512
dec_hid_dim = 512
enc_dropout = 0.5
dec_dropout = 0.5
数据预处理
通过调用mindnlp中dataset的接口下载并预处理数据集。
加载数据集:
from mindnlp.dataset import load
multi30k_train, multi30k_valid, multi30k_test = load('multi30k')
初始化词表以进行预处理:
from mindnlp.dataset.transforms import BasicTokenizer
from mindspore.dataset import text
from mindnlp.dataset import process
tokenizer = BasicTokenizer(True) # Tokenizer
multi30k_train = multi30k_train.map([tokenizer], 'en')
multi30k_train = multi30k_train.map([tokenizer], 'de')
en_vocab = text.Vocab.from_dataset(multi30k_train, 'en', special_tokens=['<pad>', '<unk>'], special_first= True) # en
de_vocab = text.Vocab.from_dataset(multi30k_train, 'de', special_tokens=['<pad>', '<unk>'], special_first= True) # de
vocab = {'en':en_vocab, 'de':de_vocab}
multi30k_train = process('multi30k', multi30k_train, vocab=vocab, batch_size=64, max_len = 32, drop_remainder = False)
multi30k_valid = multi30k_valid.map([tokenizer], 'en')
multi30k_valid = multi30k_valid.map([tokenizer], 'de')
multi30k_valid = process('multi30k', multi30k_valid, vocab=vocab, batch_size=64, max_len = 32, drop_remainder = False)
实例化模型
from mindspore import nn
from mindnlp.modules import RNNEncoder, RNNDecoder
input_dim = len(en_vocab.vocab())
output_dim = len(de_vocab.vocab())
# encoder
en_embedding = nn.Embedding(input_dim, enc_emb_dim)
en_rnn = nn.RNN(enc_emb_dim, hidden_size=enc_hid_dim, num_layers=2, has_bias=True,
batch_first=True, dropout=enc_dropout, bidirectional=False)
rnn_encoder = RNNEncoder(en_embedding, en_rnn)
# decoder
de_embedding = nn.Embedding(output_dim, dec_emb_dim)
input_feed_size = 0 if enc_hid_dim == 0 else dec_hid_dim
rnns = [
nn.RNNCell(
input_size=dec_emb_dim + input_feed_size
if layer == 0
else dec_hid_dim,
hidden_size=dec_hid_dim
)
for layer in range(2)
]
rnn_decoder = RNNDecoder(de_embedding, rnns, dropout_in=enc_dropout, dropout_out = dec_dropout,attention=True, encoder_output_units=enc_hid_dim)
net = MachineTranslation(rnn_encoder, rnn_decoder)
net.update_parameters_name('net.')
定义优化器,损失函数,回调函数,指标:
from mindnlp.engine.callbacks.timer_callback import TimerCallback
from mindnlp.engine.callbacks.earlystop_callback import EarlyStopCallback
from mindnlp.engine.callbacks.best_model_callback import BestModelCallback
from mindnlp.engine.metrics import Accuracy
optimizer = nn.Adam(net.trainable_params(), learning_rate=10e-5)
loss_fn = nn.CrossEntropyLoss()
# define callbacks
timer_callback_epochs = TimerCallback(print_steps=-1)
earlystop_callback = EarlyStopCallback(patience=2)
bestmodel_callback = BestModelCallback()
callbacks = [timer_callback_epochs, earlystop_callback, bestmodel_callback]
# define metrics
metric = Accuracy()
定义训练步骤
from mindnlp.engine.trainer import Trainer
trainer = Trainer(network=net, train_dataset=multi30k_train, eval_dataset=multi30k_valid, metrics=metric,
epochs=10, loss_fn=loss_fn, optimizer=optimizer)
训练过程
trainer.run(tgt_columns="de", jit=True)
print("end train")
Epoch 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 454/454 [05:39<00:00, 1.34it/s, loss=3.2271016]
Evaluate: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:10<00:00, 1.49it/s]
Evaluate Score: {'Accuracy': 0.6223496055226825}
Epoch 1: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 454/454 [01:28<00:00, 5.13it/s, loss=2.1794753]
Evaluate: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:10<00:00, 1.50it/s]
Evaluate Score: {'Accuracy': 0.6646942800788954}
Epoch 2: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 454/454 [01:28<00:00, 5.12it/s, loss=1.8816497]
Evaluate: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:11<00:00, 1.39it/s]
Evaluate Score: {'Accuracy': 0.6863597140039448}
Epoch 3: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 454/454 [01:28<00:00, 5.11it/s, loss=1.6710395]
Evaluate: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:11<00:00, 1.39it/s]
Evaluate Score: {'Accuracy': 0.7070081360946746}
Epoch 4: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 454/454 [01:29<00:00, 5.10it/s, loss=1.5266166]
Evaluate: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:11<00:00, 1.39it/s]
Evaluate Score: {'Accuracy': 0.7174248027613412}
Epoch 5: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 454/454 [01:29<00:00, 5.10it/s, loss=1.4266685]
Evaluate: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:11<00:00, 1.38it/s]
Evaluate Score: {'Accuracy': 0.7320019723865878}
Epoch 6: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 454/454 [01:29<00:00, 5.09it/s, loss=1.3493056]
Evaluate: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:11<00:00, 1.37it/s]
Evaluate Score: {'Accuracy': 0.7478427021696252}
Epoch 7: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 454/454 [01:29<00:00, 5.09it/s, loss=1.2893807]
Evaluate: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:11<00:00, 1.38it/s]
Evaluate Score: {'Accuracy': 0.766857741617357}
Epoch 8: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 454/454 [01:29<00:00, 5.09it/s, loss=1.2387483]
Evaluate: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:11<00:00, 1.40it/s]
Evaluate Score: {'Accuracy': 0.777120315581854}
Epoch 9: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 454/454 [01:29<00:00, 5.09it/s, loss=1.1957376]
Evaluate: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:11<00:00, 1.38it/s]
Evaluate Score: {'Accuracy': 0.782482741617357}
end train