# Source code for mindnlp.abc.backbones.base

# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Basic class for models"""

import mindspore
from mindspore import nn


class BaseModel(nn.Cell):
    """Base class for models.

    Declares the hooks concrete models are expected to override
    (``add_attrs``, ``load_parameters``, ``get_context``, ``load_pretrained``)
    and provides thin wrappers around ``mindspore.load_checkpoint`` and
    ``mindspore.save_checkpoint``.
    """

    @classmethod
    def add_attrs(cls, parser):
        """Add model-specific arguments to the parser.

        Args:
            parser: The argument parser to extend.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError

    def get_targets(self, sample):
        """Get targets from either the sample or the net's output.

        Args:
            sample (dict): The sample or the net's output. It must contain
                the key ``"target"``.

        Returns:
            Tensor, the net's target (the value stored under ``"target"``).

        Raises:
            KeyError: If ``sample`` has no ``"target"`` entry.
        """
        return sample["target"]

    def load_parameters(self, parameter_dict, strict_load, *args, **kwargs):
        """Copy parameters and buffers from parameter_dict into this module and its descendants.

        Args:
            parameter_dict (dict): A dictionary consisting of key: parameter's name,
                value: parameter.
            strict_load (bool): Whether to strictly load the parameter into net.
                If False, it will load parameter into net when parameter name's suffix
                in checkpoint file is the same as the parameter in the network. When the
                types are inconsistent, perform type conversion on the parameters of the
                same type, such as float32 to float16.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError

    def get_context(self, src_tokens, mask=None):
        """Get context from encoder.

        Args:
            src_tokens (Tensor): Tokens of source sentences.
            mask (Tensor): Its elements identify whether the corresponding input token
                is padding or not. If True, not padding token. If False, padding token.
                Defaults to None.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError

    @classmethod
    def load_pretrained(cls, model_name_or_path, ckpt_file_name="model.ckpt",
                        data_name_or_path="."):
        """Load model from a pretrained model file.

        Args:
            model_name_or_path: Name or path of the pretrained model.
            ckpt_file_name (str): Checkpoint file name. Default: "model.ckpt".
            data_name_or_path: Name or path of the data. Default: ".".

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError

    def load_checkpoint(self, ckpt_file_name, net=None, strict_load=False, filter_prefix=None,
                        dec_key=None, dec_mode="AES-GCM", specify_prefix=None):
        """Load checkpoint from a specified file.

        All arguments are forwarded unchanged to ``mindspore.load_checkpoint``.
        Note that ``net`` is NOT replaced by ``self`` when it is None; pass
        ``net=self`` explicitly to load the parameters into this model.

        Args:
            ckpt_file_name (str): Checkpoint file name.
            net (Cell): The network where the parameters will be loaded. Default: None.
            strict_load (bool): Whether to strictly load the parameter into net.
                If False, it will load parameter into net when parameter name's suffix
                in checkpoint file is the same as the parameter in the network. When the
                types are inconsistent, perform type conversion on the parameters of the
                same type, such as float32 to float16. Default: False.
            filter_prefix (Union[str, list[str], tuple[str]]): Parameters starting with
                the filter_prefix will not be loaded. Default: None.
            dec_key (Union[None, bytes]): Byte type key used for decryption. If the value
                is None, the decryption is not required. Default: None.
            dec_mode (str): This parameter is valid only when dec_key is not set to None.
                Specifies the decryption mode, currently supports 'AES-GCM' and 'AES-CBC'.
                Default: 'AES-GCM'.
            specify_prefix (Union[str, list[str], tuple[str]]): Parameters starting with
                the specify_prefix will be loaded. Default: None.

        Returns:
            The value returned by ``mindspore.load_checkpoint`` (documented by MindSpore
            as a dict mapping parameter names to parameters).
        """
        # Hand the loaded parameters back to the caller; previously the result was
        # dropped, which made the default ``net=None`` call produce nothing usable.
        return mindspore.load_checkpoint(
            ckpt_file_name,
            net=net,
            strict_load=strict_load,
            filter_prefix=filter_prefix,
            dec_key=dec_key,
            dec_mode=dec_mode,
            specify_prefix=specify_prefix,
        )

    def save_checkpoint(self, save_obj, ckpt_file_name, integrated_save=True, async_save=False,
                        append_dict=None, enc_key=None, enc_mode="AES-GCM"):
        """Save checkpoint to a specified file.

        All arguments are forwarded unchanged to ``mindspore.save_checkpoint``.

        Args:
            save_obj (Union[Cell, list]): The cell object or data list (each element is a
                dictionary, like [{"name": param_name, "data": param_data}, ...], the type
                of param_name would be string, and the type of param_data would be
                parameter or Tensor).
            ckpt_file_name (str): Checkpoint file name. If the file name already exists,
                it will be overwritten.
            integrated_save (bool): Whether to integrated save in automatic model parallel
                scene. Default: True.
            async_save (bool): Whether to open an independent thread to save the checkpoint
                file. Default: False.
            append_dict (dict): Additional information that needs to be saved. The key of
                dict must be str, the value of dict must be one of int, float, bool,
                string, Parameter or Tensor. Default: None.
            enc_key (Union[None, bytes]): Byte type key used for encryption. If the value
                is None, the encryption is not required. Default: None.
            enc_mode (str): This parameter is valid only when enc_key is not set to None.
                Specifies the encryption mode, currently supports 'AES-GCM' and 'AES-CBC'.
                Default: 'AES-GCM'.
        """
        mindspore.save_checkpoint(
            save_obj,
            ckpt_file_name,
            integrated_save=integrated_save,
            async_save=async_save,
            append_dict=append_dict,
            enc_key=enc_key,
            enc_mode=enc_mode,
        )