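First Model
Overview
Sentiment classification is a classic task in natural language processing. Its goal is to identify and analyze the sentiment expressed in a given text, which makes it a typical classification problem. This section uses MindNLP to implement an RNN-based sentiment classification model that produces results like the following: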
Input: This film is terrible
Correct label: Negative
Predicted label: Negative
Input: This film is great
Correct label: Positive
Predicted label: Positive
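The model built later in this section ends in a single sigmoid output, so a label can be read off by thresholding the score. The following is a minimal sketch of that step. The helper name `score_to_label`, the 0.5 threshold, and the convention that scores near 1 mean positive are assumptions for illustration, not part of MindNLP.

```python
def score_to_label(score, threshold=0.5):
    """Map a sigmoid score in [0, 1] to a sentiment label.

    Hypothetical helper: the name, the threshold, and the convention that
    high scores mean "Positive" are illustrative assumptions.
    """
    if not 0.0 <= score <= 1.0:
        raise ValueError("score must lie in [0, 1]")
    return "Positive" if score >= threshold else "Negative"


# Made-up scores, not real model outputs.
examples = [("This film is terrible", 0.08), ("This film is great", 0.93)]
for text, score in examples:
    print(f"{text!r} -> {score_to_label(score)}")
```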
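Model Construction
Following the model architecture for this task, build the model with Seq2vecModel. The Seq2vecModel module extracts semantic features from the input sequence and computes a result vector. It consists of two parts, an encoder and a head: the encoder maps the input sentence to a semantic vector, and the head further processes the encoder output to produce the final result.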
from mindnlp.abc import Seq2vecModel

class SentimentClassification(Seq2vecModel):
    """
    Sentiment Classification model
    """

    def __init__(self, encoder, head):
        super().__init__(encoder, head)
        self.encoder = encoder
        self.head = head

    def construct(self, text):
        # Keep only the final hidden state from the encoder's output.
        _, (hidden, _), _ = self.encoder(text)
        # The head turns the hidden state into the final prediction.
        output = self.head(hidden)
        return output
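The encoder/head split can be tried out without MindSpore. The sketch below is a toy stand-in in plain Python, not the MindNLP implementation: `ToyEncoder`, `ToyHead`, `ToySeq2Vec`, and the hand-picked 2-dimensional embeddings are all hypothetical. It only shows how a sequence-to-vector model composes an encoder (sentence to vector) with a head (vector to score).

```python
import math


class ToyEncoder:
    """Stand-in encoder: averages per-token vectors into one sentence vector."""

    def __init__(self, embeddings):
        self.embeddings = embeddings  # token -> list of floats

    def __call__(self, tokens):
        vectors = [self.embeddings[t] for t in tokens]
        dim = len(vectors[0])
        return [sum(v[i] for v in vectors) / len(vectors) for i in range(dim)]


class ToyHead:
    """Stand-in head: a linear layer followed by a sigmoid."""

    def __init__(self, weights, bias=0.0):
        self.weights = weights
        self.bias = bias

    def __call__(self, vector):
        logit = sum(w * x for w, x in zip(self.weights, vector)) + self.bias
        return 1.0 / (1.0 + math.exp(-logit))


class ToySeq2Vec:
    """Composes an encoder and a head: tokens -> vector -> score."""

    def __init__(self, encoder, head):
        self.encoder = encoder
        self.head = head

    def __call__(self, tokens):
        return self.head(self.encoder(tokens))


# Hand-picked 2-d embeddings, purely illustrative.
embeddings = {
    "this": [0.0, 0.0], "film": [0.0, 0.0], "is": [0.0, 0.0],
    "great": [3.0, 0.0], "terrible": [-3.0, 0.0],
}
model = ToySeq2Vec(ToyEncoder(embeddings), ToyHead(weights=[4.0, 0.0]))
great_score = model(["this", "film", "is", "great"])
terrible_score = model(["this", "film", "is", "terrible"])
print(great_score, terrible_score)
```

Because the two parts only meet at the sentence vector, either one can be swapped out independently, which is the same reason the real model takes `encoder` and `head` as constructor arguments.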
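Model Instantiation
Initialize the encoder and head modules separately, then pass them to the model as arguments. We use the RNNEncoder provided by MindNLP as the model's encoder and a custom module as its head.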
import math

from mindspore import nn
from mindspore import ops
from mindspore.common.initializer import Uniform, HeUniform
from mindnlp.modules import Glove
from mindnlp.modules import RNNEncoder

class Head(nn.Cell):
    """
    Head for Sentiment Classification model
    """

    def __init__(self, hidden_dim, output_dim, dropout):
        super().__init__()
        weight_init = HeUniform(math.sqrt(5))
        bias_init = Uniform(1 / math.sqrt(hidden_dim * 2))
        # The input size is doubled because two hidden states are concatenated.
        self.fc = nn.Dense(hidden_dim * 2, output_dim, weight_init=weight_init, bias_init=bias_init)
        self.sigmoid = nn.Sigmoid()
        # nn.Dropout's first argument is treated as a keep probability here, hence 1 - dropout.
        self.dropout = nn.Dropout(1 - dropout)

    def construct(self, context):
        # Concatenate the last two hidden-state slices along the feature axis.
        context = ops.concat((context[-2, :, :], context[-1, :, :]), axis=1)
        context = self.dropout(context)
        return self.sigmoid(self.fc(context))
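The indexing in `Head.construct` picks slices -2 and -1 of the final hidden state. The sketch below shows what those slices would hold, assuming the common layout for a stacked bidirectional LSTM: shape (num_layers * num_directions, batch, hidden_size), ordered layer by layer with forward before backward. Plain Python lists stand in for tensors, and `hidden_layout` is a hypothetical helper, not a MindSpore function.

```python
def hidden_layout(num_layers, bidirectional):
    """List which (layer, direction) each slice of the final hidden state holds.

    Assumes layer-major ordering with forward before backward.
    """
    directions = ["forward", "backward"] if bidirectional else ["forward"]
    return [(layer, d) for layer in range(num_layers) for d in directions]


layout = hidden_layout(num_layers=2, bidirectional=True)
for index, (layer, direction) in enumerate(layout):
    print(f"hidden[{index}]: layer {layer}, {direction}")

# Fake final hidden state: batch of 1, hidden size 3, filled with the slice index.
hidden = [[[float(i)] * 3] for i in range(len(layout))]

# Concatenate the last two slices along the feature axis, one row per batch item.
context = [fwd + bwd for fwd, bwd in zip(hidden[-2], hidden[-1])]
print(len(context), len(context[0]))
```

Under that assumption the head sees the top layer's forward and backward states side by side, which is why its dense layer expects `hidden_dim * 2` input features.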
# Hyperparameters
hidden_size = 256
output_size = 1
num_layers = 2
bidirectional = True
drop = 0.5
lr = 0.001

# Pretrained 100-dimensional GloVe embeddings and the matching vocabulary
embedding, vocab = Glove.from_pretrained('6B', 100, special_tokens=["<unk>", "<pad>"], dropout=drop)

# The LSTM input size (100) must match the embedding dimension
lstm_layer = nn.LSTM(100, hidden_size, num_layers=num_layers, batch_first=True,
                     dropout=drop, bidirectional=bidirectional)

sentiment_encoder = RNNEncoder(embedding, lstm_layer)
sentiment_head = Head(hidden_size, output_size, drop)
net = SentimentClassification(sentiment_encoder, sentiment_head)
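As a quick check on these numbers, the head's dense layer receives hidden_size * 2 = 512 input features. The sketch below also works out the initialization bounds, assuming `HeUniform` follows the standard Kaiming-uniform formula with a leaky-ReLU gain, bound = sqrt(6 / ((1 + a^2) * fan_in)). Under that assumption, a = sqrt(5) reduces the weight bound to 1 / sqrt(fan_in), the same value the code passes to `Uniform` for the bias. The helper `kaiming_uniform_bound` is hypothetical and uses only the standard library.

```python
import math


def kaiming_uniform_bound(fan_in, negative_slope):
    """Bound b such that weights are drawn from U(-b, b).

    Standard Kaiming-uniform formula with leaky-ReLU gain (assumed here).
    """
    gain = math.sqrt(2.0 / (1.0 + negative_slope ** 2))
    return gain * math.sqrt(3.0 / fan_in)


hidden_size = 256
num_directions = 2  # bidirectional = True
fan_in = hidden_size * num_directions

weight_bound = kaiming_uniform_bound(fan_in, math.sqrt(5))
bias_bound = 1 / math.sqrt(fan_in)

print(f"head input features: {fan_in}")
print(f"weight bound: {weight_bound:.4f}")
print(f"bias bound:   {bias_bound:.4f}")
```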