mindnlp.engine.metrics.distinct 源代码

# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""""Class for Metric Distinct"""


from mindnlp.abc import Metric
from mindnlp.common.metrics import _check_value_type


[文档]class Distinct(Metric): """ Calculates the Distinct-N. Distinct-N is a metric that measures the diversity of a sentence. It focuses on the number of distinct n-gram of a sentence. The larger the number of distinct n-grams, the higher the diversity of the text. The function is shown as follows: Args: n_size (int): N_gram value. Defaults: 2. name (str): Name of the metric. Example: >>> from mindnlp.common.metrics import Distinct >>> cand_list = ["The", "cat", "The", "cat", "on", "the", "mat"] >>> metric = Distinct() >>> metric.update(cand_list) >>> distinct_score = metric.eval() >>> print(distinct_score) 0.8333333333333334 """ def __init__(self, n_size=2, name='Distinct'): super().__init__() self._name = name self.n_size = _check_value_type("n_size", n_size, [int]) self.diff_ngram = set() self.count = 0.0
[文档] def clear(self): """Clears the internal evaluation results.""" self.diff_ngram = set() self.count = 0.0
[文档] def update(self, *inputs): """ Updates local variables. Args: inputs: Input `cand_list`. - cand_list (list): A list of tokenized candidate sentence. Raises: ValueError: If the number of inputs is not 1. """ if len(inputs) != 1: raise ValueError(f'For `Distinct.update`, it needs 1 inputs (`cand_list`), ' f'but got {len(inputs)}.') cand_list = inputs[0] cand_list = _check_value_type("cand_list", cand_list, list) for i in range(0, len(cand_list) - self.n_size + 1): ngram = ' '.join(cand_list[i:(i + self.n_size)]) self.count += 1 self.diff_ngram.add(ngram)
[文档] def eval(self): """ Computes and returns the Distinct-N. Returns: - **distinct_score** (float) - The computed result. """ distinct_score = len(self.diff_ngram) / self.count return distinct_score
[文档] def get_metric_name(self): """ Returns the name of the metric. """ return self._name