diff --git a/model_zoo/official/nlp/gnmt_v2/export.py b/model_zoo/official/nlp/gnmt_v2/export.py index e50b722847..822ebee1c5 100644 --- a/model_zoo/official/nlp/gnmt_v2/export.py +++ b/model_zoo/official/nlp/gnmt_v2/export.py @@ -65,33 +65,28 @@ if __name__ == '__main__': for param in params: value = param.data - name = param.name - if name not in weights: - raise ValueError(f"{name} is not found in weights.") - with open("weight_after_deal.txt", "a+") as f: - weights_name = name - f.write(weights_name) - f.write("\n") - if isinstance(value, Tensor): - print(name, value.asnumpy().shape) - if weights_name in weights: - assert weights_name in weights - if isinstance(weights[weights_name], Parameter): - if param.data.dtype == "Float32": - param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32)) - elif param.data.dtype == "Float16": - param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16)) - - elif isinstance(weights[weights_name], Tensor): - param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype)) - elif isinstance(weights[weights_name], np.ndarray): - param.set_data(Tensor(weights[weights_name], config.dtype)) - else: - param.set_data(weights[weights_name]) + weights_name = param.name + if weights_name not in weights: + raise ValueError(f"{weights_name} is not found in weights.") + if isinstance(value, Tensor): + if weights_name in weights: + assert weights_name in weights + if isinstance(weights[weights_name], Parameter): + if param.data.dtype == "Float32": + param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32)) + elif param.data.dtype == "Float16": + param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16)) + + elif isinstance(weights[weights_name], Tensor): + param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype)) + elif isinstance(weights[weights_name], np.ndarray): + param.set_data(Tensor(weights[weights_name], config.dtype)) else: - print("weight not found in checkpoint: " + weights_name) - param.set_data(zero_weight(value.asnumpy().shape)) - f.close() + param.set_data(weights[weights_name]) + else: + print("weight not found in checkpoint: " + weights_name) + param.set_data(zero_weight(value.asnumpy().shape)) + print(" | Load weights successfully.") tfm_infer = GNMTInferCell(tfm_model) tfm_infer.set_train(False) diff --git a/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/gnmt_for_infer.py b/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/gnmt_for_infer.py index 079a654a19..2c62caf517 100644 --- a/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/gnmt_for_infer.py +++ b/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/gnmt_for_infer.py @@ -23,7 +23,6 @@ from mindspore.common.tensor import Tensor from mindspore.ops import operations as P from mindspore import context, Parameter from mindspore.train.model import Model -from mindspore.train.serialization import load_checkpoint from src.dataset import load_dataset from .gnmt import GNMT @@ -37,18 +36,6 @@ context.set_context( reserve_class_name_in_scope=False) -def get_weight_and_variable(model_path, params): - print("model path is {}".format(model_path)) - ms_ckpt = load_checkpoint(model_path) - with open("variable.txt", "w") as f: - for msname in ms_ckpt: - f.write(msname + "\n") - with open("weights.txt", "w") as f: - for param in params: - name = param.name - f.write(name + "\n") - - class GNMTInferCell(nn.Cell): """ Encapsulation class of GNMT network infer. @@ -92,38 +79,31 @@ def gnmt_infer(config, dataset): use_one_hot_embeddings=False) params = tfm_model.trainable_params() - get_weight_and_variable(config.existed_ckpt, params) weights = load_infer_weights(config) - for param in params: value = param.data - name = param.name - if name not in weights: - raise ValueError(f"{name} is not found in weights.") - with open("weight_after_deal.txt", "a+") as f: - weights_name = name - f.write(weights_name) - f.write("\n") - if isinstance(value, Tensor): - print(name, value.asnumpy().shape) - if weights_name in weights: - assert weights_name in weights - if isinstance(weights[weights_name], Parameter): - if param.data.dtype == "Float32": - param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32)) - elif param.data.dtype == "Float16": - param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16)) - - elif isinstance(weights[weights_name], Tensor): - param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype)) - elif isinstance(weights[weights_name], np.ndarray): - param.set_data(Tensor(weights[weights_name], config.dtype)) - else: - param.set_data(weights[weights_name]) + weights_name = param.name + if weights_name not in weights: + raise ValueError(f"{weights_name} is not found in weights.") + if isinstance(value, Tensor): + if weights_name in weights: + assert weights_name in weights + if isinstance(weights[weights_name], Parameter): + if param.data.dtype == "Float32": + param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32)) + elif param.data.dtype == "Float16": + param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16)) + + elif isinstance(weights[weights_name], Tensor): + param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype)) + elif isinstance(weights[weights_name], np.ndarray): + param.set_data(Tensor(weights[weights_name], config.dtype)) else: - print("weight not found in checkpoint: " + weights_name) - param.set_data(zero_weight(value.asnumpy().shape)) - f.close() + param.set_data(weights[weights_name]) + else: + print("weight not found in checkpoint: " + weights_name) + param.set_data(zero_weight(value.asnumpy().shape)) + print(" | Load weights successfully.") tfm_infer = GNMTInferCell(tfm_model) model = Model(tfm_infer) diff --git a/model_zoo/official/nlp/gnmt_v2/src/utils/load_weights.py b/model_zoo/official/nlp/gnmt_v2/src/utils/load_weights.py index a29c9b4d12..9add9b5def 100644 --- a/model_zoo/official/nlp/gnmt_v2/src/utils/load_weights.py +++ b/model_zoo/official/nlp/gnmt_v2/src/utils/load_weights.py @@ -37,36 +37,26 @@ def load_infer_weights(config): ms_ckpt = load_checkpoint(model_path) is_npz = False weights = {} - with open("variable_after_deal.txt", "w") as f: - for param_name in ms_ckpt: - infer_name = param_name.replace("gnmt.gnmt.", "") - if infer_name.startswith("embedding_lookup."): - if is_npz: - weights[infer_name] = ms_ckpt[param_name] - else: - weights[infer_name] = ms_ckpt[param_name].data.asnumpy() - f.write(infer_name) - f.write("\n") - infer_name = "beam_decoder.decoder." + infer_name - if is_npz: - weights[infer_name] = ms_ckpt[param_name] - else: - weights[infer_name] = ms_ckpt[param_name].data.asnumpy() - f.write(infer_name) - f.write("\n") - continue - - elif not infer_name.startswith("gnmt_encoder"): - if infer_name.startswith("gnmt_decoder."): - infer_name = infer_name.replace("gnmt_decoder.", "decoder.") - infer_name = "beam_decoder.decoder." + infer_name - + for param_name in ms_ckpt: + infer_name = param_name.replace("gnmt.gnmt.", "") + if infer_name.startswith("embedding_lookup."): if is_npz: weights[infer_name] = ms_ckpt[param_name] else: weights[infer_name] = ms_ckpt[param_name].data.asnumpy() - f.write(infer_name) - f.write("\n") - - f.close() + infer_name = "beam_decoder.decoder." + infer_name + if is_npz: + weights[infer_name] = ms_ckpt[param_name] + else: + weights[infer_name] = ms_ckpt[param_name].data.asnumpy() + continue + elif not infer_name.startswith("gnmt_encoder"): + if infer_name.startswith("gnmt_decoder."): + infer_name = infer_name.replace("gnmt_decoder.", "decoder.") + infer_name = "beam_decoder.decoder." + infer_name + + if is_npz: + weights[infer_name] = ms_ckpt[param_name] + else: + weights[infer_name] = ms_ckpt[param_name].data.asnumpy() return weights