mindnlp.engine.evaluator 源代码

# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Evaluator for testing.
"""
from inspect import signature
from tqdm import tqdm
from mindspore import log, mutable
from mindnlp import ms_jit
from mindnlp.abc import Metric
from mindnlp.engine.callbacks.callback_manager import CallbackManager, RunContext


class Evaluator:
    r"""
    Evaluator to test the model.

    Args:
        network (Cell): A network for evaluating.
        eval_dataset (Dataset): An evaluating dataset iterator. Required, although
            the parameter carries a ``None`` default in the signature.
        metrics (Optional[list[Metric], Metric]): List of metric objects which should be used
            while evaluating. Must not be None or empty.
        callbacks (Optional[list[Callback], Callback]): List of callback objects which should be
            executed while evaluating. Default: None.
        jit (bool): Whether use Just-In-Time compile. Default: False.

    Raises:
        ValueError: If `metrics` is None/empty or `eval_dataset` is None.
        TypeError: If `metrics` is neither a Metric nor a list of Metric objects.
    """

    def __init__(self, network, eval_dataset=None, metrics=None, callbacks=None, jit=False):
        self.network = network
        self.callbacks = callbacks
        self.earlystop = False
        self._check_metric_type(metrics)

        # Fail with a clear message instead of an AttributeError on
        # `None.get_dataset_size()` below.
        if eval_dataset is None:
            raise ValueError("For Evaluator, the argument 'eval_dataset' can not be None.")
        self.eval_dataset = eval_dataset
        self.total = eval_dataset.get_dataset_size()

        self.callback_manager = CallbackManager(callbacks=self.callbacks)
        self.eval_func = self._prepare_eval_func(network, jit)

    def _prepare_eval_func(self, network, jit):
        """Build the per-step forward function, wrapped with `ms_jit` when `jit` is set."""
        def _run_step(inputs):
            """Core process of each step."""
            outputs = network(*inputs)
            return outputs

        if jit:
            return ms_jit(_run_step)
        return _run_step

    def _check_metric_type(self, metrics):
        """Check metrics type and store the validated metrics in `self.metrics`."""
        self.metrics = []
        if not metrics:
            raise ValueError("For Evaluator, the model argument 'metrics' can not be None or empty, "
                             "you should set the argument 'metrics' for model.")
        if isinstance(metrics, Metric):
            self.metrics.append(metrics)
        elif isinstance(metrics, list):
            for metric in metrics:
                if not isinstance(metric, Metric):
                    # Report the type of the offending element itself.
                    raise TypeError(f"Expect sub-classes of Metrics. Got {type(metric)}")
            self.metrics = metrics
        else:
            raise TypeError(f"Expect metrics to be list or Metrics. Got {type(metrics)}.")

    def _check_amp_level_arg(self, optimizer, amp_level):
        """Check mixed-precision argument rules."""
        raise NotImplementedError

    def _check_for_graph_cell(self, kwargs):
        """Check network rules of GraphCell."""
        raise NotImplementedError

    def _build_boost_network(self, *kwargs):
        """Build boost network."""
        raise NotImplementedError

    def _check_reuse_dataset(self, dataset):
        """Check if dataset is being used by other models under the data sink mode."""
        if not hasattr(dataset, '__model_hash__'):
            dataset.__model_hash__ = hash(self)
        # The attribute is guaranteed to exist at this point.
        if dataset.__model_hash__ != hash(self):
            raise RuntimeError("The dataset object had been used in other model by model.train(...), "
                               "please create a new dataset.")

    def run(self, tgt_columns=None):
        """
        Evaluating function entry.

        Args:
            tgt_columns (Optional[list[str], str]): Target label column names for loss function.
        """
        args_dict = vars(self)
        run_context = RunContext(args_dict)
        self.callback_manager.evaluate_begin(run_context)
        self.clear_metrics()
        self._run(tgt_columns)
        self.callback_manager.evaluate_end(run_context)
        # Callbacks may set `earlystop` on the run context.
        self.earlystop = getattr(run_context, 'earlystop', False)

    def _run(self, tgt_columns=None):
        """Evaluating process for non-data sinking mode. The data would be passed to network directly."""
        self.network.set_train(False)
        # Normalise the target columns once, so validation (and the warning
        # emitted for `None`) happens a single time rather than on every batch.
        tgt_columns = self._prepare_tgt_columns(tgt_columns)
        with tqdm(total=self.total) as progress:
            progress.set_description('Evaluate')
            for data in self.eval_dataset.create_dict_iterator():
                inputs, tgts = self._data_process(data, tgt_columns)
                outputs = self.eval_func(inputs)
                self._update_metrics(outputs, *tgts)
                progress.update(1)
        # The `with` block closes the progress bar; no explicit close() needed.

        metrics_result, metrics_names, metrics_values = self._get_metrics()
        print(f'Evaluate Score: {metrics_result}')
        return metrics_result, metrics_names, metrics_values

    def _run_ds_sink(self):
        """Evaluating process for data sinking mode."""
        raise NotImplementedError

    def _get_metrics(self):
        """
        Get all metrics values.

        Returns:
            tuple, a dict mapping metric name to value, the list of metric
            names and the list of metric values (both in `self.metrics` order).
        """
        metrics = {}
        metrics_names = []
        metrics_values = []
        for metric in self.metrics:
            key = metric.get_metric_name()
            metrics_names.append(key)
            metrics[key] = metric.eval()
            metrics_values.append(metrics[key])
        return metrics, metrics_names, metrics_values

    def clear_metrics(self):
        """Clear metrics values."""
        for metric in self.metrics:
            metric.clear()

    def _update_metrics(self, outputs, *tgts):
        """Update metrics values with the network outputs and the targets."""
        for metric in self.metrics:
            metric.update(outputs, *tgts)
        return True

    def _data_process(self, data, tgt_columns):
        """
        Process data to match the network construct.

        Args:
            data (dict): One batch, keyed by dataset column name.
            tgt_columns (Optional[list[str], str]): Target label column names.

        Returns:
            tuple, the mutable network inputs (ordered like the parameters of
            `network.construct`) and the mutable targets.

        Raises:
            KeyError: If a construct parameter whose default is not None has
                no matching column in `data`.
        """
        # prepare input dataset.
        net_args = signature(self.network.construct).parameters
        inputs = ()
        for name, param in net_args.items():
            if name == 'self':
                continue
            if name not in data:
                # Parameters defaulting to None may be absent from the dataset.
                if param.default is None:
                    continue
                raise KeyError(f"Network input '{name}' is not a column of the evaluation dataset.")
            inputs = inputs + (data[name],)

        # process target dataset.
        tgt_columns = self._prepare_tgt_columns(tgt_columns)
        tgts = tuple(data[tgt_column] for tgt_column in tgt_columns)
        return mutable(inputs), mutable(tgts)

    def _prepare_tgt_columns(self, tgt_columns):
        """
        Check and prepare target columns.

        Args:
            tgt_columns (Optional[list[str], str]): Target label column names.

        Returns:
            list[str], the target column names; empty when `tgt_columns` is None.

        Raises:
            TypeError: If `tgt_columns` is neither a str nor a list of str.
        """
        if tgt_columns is None:
            log.warning("In the process of training model, tgt_column can not be None.")
            return []
        if isinstance(tgt_columns, str):
            return [tgt_columns]
        if isinstance(tgt_columns, list):
            for tgt_column in tgt_columns:
                if not isinstance(tgt_column, str):
                    # Report the type of the offending element itself.
                    raise TypeError(f"Expect str of tgt_column. Got {type(tgt_column)}")
            return tgt_columns
        raise TypeError(f"Expect tgt_columns to be list or str. Got {type(tgt_columns)}.")