!10660 Delete the redundant file of GNMTv2

From: @gaojing22
Reviewed-by: @liangchenghui, @wuxuejian
Signed-off-by: @wuxuejian
pull/10660/MERGE
Committed by mindspore-ci-bot via Gitee, 4 years ago
commit 237faca57e

@@ -65,33 +65,28 @@ if __name__ == '__main__':
     for param in params:
         value = param.data
-        name = param.name
-        if name not in weights:
-            raise ValueError(f"{name} is not found in weights.")
-        with open("weight_after_deal.txt", "a+") as f:
-            weights_name = name
-            f.write(weights_name)
-            f.write("\n")
-            if isinstance(value, Tensor):
-                print(name, value.asnumpy().shape)
-                if weights_name in weights:
-                    assert weights_name in weights
-                    if isinstance(weights[weights_name], Parameter):
-                        if param.data.dtype == "Float32":
-                            param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32))
-                        elif param.data.dtype == "Float16":
-                            param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16))
-                    elif isinstance(weights[weights_name], Tensor):
-                        param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype))
-                    elif isinstance(weights[weights_name], np.ndarray):
-                        param.set_data(Tensor(weights[weights_name], config.dtype))
-                    else:
-                        param.set_data(weights[weights_name])
-                else:
-                    print("weight not found in checkpoint: " + weights_name)
-                    param.set_data(zero_weight(value.asnumpy().shape))
-        f.close()
+        weights_name = param.name
+        if weights_name not in weights:
+            raise ValueError(f"{weights_name} is not found in weights.")
+        if isinstance(value, Tensor):
+            if weights_name in weights:
+                assert weights_name in weights
+                if isinstance(weights[weights_name], Parameter):
+                    if param.data.dtype == "Float32":
+                        param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32))
+                    elif param.data.dtype == "Float16":
+                        param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16))
+                elif isinstance(weights[weights_name], Tensor):
+                    param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype))
+                elif isinstance(weights[weights_name], np.ndarray):
+                    param.set_data(Tensor(weights[weights_name], config.dtype))
+                else:
+                    param.set_data(weights[weights_name])
+            else:
+                print("weight not found in checkpoint: " + weights_name)
+                param.set_data(zero_weight(value.asnumpy().shape))
     print(" | Load weights successfully.")
     tfm_infer = GNMTInferCell(tfm_model)
     tfm_infer.set_train(False)
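The logic this hunk keeps boils down to: look each parameter up by name, cast the stored value to the parameter's dtype, and fall back to zeros when the name is missing. Below is a minimal standalone sketch of that pattern; the nn.Dense network and the fake_ckpt dict are stand-ins for the GNMT model and the load_infer_weights() output, not code from this repo.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

# Stand-ins: a tiny network and a hand-made "checkpoint" dict of numpy arrays.
net = nn.Dense(4, 2)
fake_ckpt = {"weight": np.ones((2, 4), np.float32)}  # "bias" left out on purpose

for param in net.trainable_params():
    value = param.data
    if param.name in fake_ckpt:
        # Cast the stored array to the parameter's own dtype before assignment.
        param.set_data(Tensor(fake_ckpt[param.name], value.dtype))
    else:
        # Same fallback as zero_weight() in the diff: fill with zeros.
        print("weight not found in checkpoint: " + param.name)
        param.set_data(Tensor(np.zeros(value.asnumpy().shape), value.dtype))

print({p.name: p.data.asnumpy().shape for p in net.trainable_params()})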

@@ -23,7 +23,6 @@ from mindspore.common.tensor import Tensor
 from mindspore.ops import operations as P
 from mindspore import context, Parameter
 from mindspore.train.model import Model
-from mindspore.train.serialization import load_checkpoint
 from src.dataset import load_dataset
 from .gnmt import GNMT
@@ -37,18 +36,6 @@
     reserve_class_name_in_scope=False)
-def get_weight_and_variable(model_path, params):
-    print("model path is {}".format(model_path))
-    ms_ckpt = load_checkpoint(model_path)
-    with open("variable.txt", "w") as f:
-        for msname in ms_ckpt:
-            f.write(msname + "\n")
-    with open("weights.txt", "w") as f:
-        for param in params:
-            name = param.name
-            f.write(name + "\n")
 class GNMTInferCell(nn.Cell):
     """
     Encapsulation class of GNMT network infer.
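The deleted get_weight_and_variable helper only dumped checkpoint keys and network parameter names into variable.txt and weights.txt. If that kind of inspection is ever needed again, a throwaway sketch like the one below (dump_names is a hypothetical name, not part of the repo) gives the same information without leaving files behind:

from mindspore.train.serialization import load_checkpoint

def dump_names(model_path, params):
    """Print checkpoint keys next to network parameter names for eyeballing."""
    ms_ckpt = load_checkpoint(model_path)  # dict: variable name -> Parameter
    print("checkpoint variables:", sorted(ms_ckpt))
    print("network parameters:", sorted(p.name for p in params))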
@@ -92,38 +79,31 @@ def gnmt_infer(config, dataset):
                      use_one_hot_embeddings=False)
     params = tfm_model.trainable_params()
-    get_weight_and_variable(config.existed_ckpt, params)
     weights = load_infer_weights(config)
     for param in params:
         value = param.data
-        name = param.name
-        if name not in weights:
-            raise ValueError(f"{name} is not found in weights.")
-        with open("weight_after_deal.txt", "a+") as f:
-            weights_name = name
-            f.write(weights_name)
-            f.write("\n")
-            if isinstance(value, Tensor):
-                print(name, value.asnumpy().shape)
-                if weights_name in weights:
-                    assert weights_name in weights
-                    if isinstance(weights[weights_name], Parameter):
-                        if param.data.dtype == "Float32":
-                            param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32))
-                        elif param.data.dtype == "Float16":
-                            param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16))
-                    elif isinstance(weights[weights_name], Tensor):
-                        param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype))
-                    elif isinstance(weights[weights_name], np.ndarray):
-                        param.set_data(Tensor(weights[weights_name], config.dtype))
-                    else:
-                        param.set_data(weights[weights_name])
-                else:
-                    print("weight not found in checkpoint: " + weights_name)
-                    param.set_data(zero_weight(value.asnumpy().shape))
-        f.close()
+        weights_name = param.name
+        if weights_name not in weights:
+            raise ValueError(f"{weights_name} is not found in weights.")
+        if isinstance(value, Tensor):
+            if weights_name in weights:
+                assert weights_name in weights
+                if isinstance(weights[weights_name], Parameter):
+                    if param.data.dtype == "Float32":
+                        param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32))
+                    elif param.data.dtype == "Float16":
+                        param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16))
+                elif isinstance(weights[weights_name], Tensor):
+                    param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype))
+                elif isinstance(weights[weights_name], np.ndarray):
+                    param.set_data(Tensor(weights[weights_name], config.dtype))
+                else:
+                    param.set_data(weights[weights_name])
+            else:
+                print("weight not found in checkpoint: " + weights_name)
+                param.set_data(zero_weight(value.asnumpy().shape))
     print(" | Load weights successfully.")
     tfm_infer = GNMTInferCell(tfm_model)
     model = Model(tfm_infer)

@@ -37,36 +37,26 @@ def load_infer_weights(config):
         ms_ckpt = load_checkpoint(model_path)
         is_npz = False
     weights = {}
-    with open("variable_after_deal.txt", "w") as f:
-        for param_name in ms_ckpt:
-            infer_name = param_name.replace("gnmt.gnmt.", "")
-            if infer_name.startswith("embedding_lookup."):
-                if is_npz:
-                    weights[infer_name] = ms_ckpt[param_name]
-                else:
-                    weights[infer_name] = ms_ckpt[param_name].data.asnumpy()
-                f.write(infer_name)
-                f.write("\n")
-                infer_name = "beam_decoder.decoder." + infer_name
-                if is_npz:
-                    weights[infer_name] = ms_ckpt[param_name]
-                else:
-                    weights[infer_name] = ms_ckpt[param_name].data.asnumpy()
-                f.write(infer_name)
-                f.write("\n")
-                continue
-            elif not infer_name.startswith("gnmt_encoder"):
-                if infer_name.startswith("gnmt_decoder."):
-                    infer_name = infer_name.replace("gnmt_decoder.", "decoder.")
-                infer_name = "beam_decoder.decoder." + infer_name
-            if is_npz:
-                weights[infer_name] = ms_ckpt[param_name]
-            else:
-                weights[infer_name] = ms_ckpt[param_name].data.asnumpy()
-            f.write(infer_name)
-            f.write("\n")
-    f.close()
+    for param_name in ms_ckpt:
+        infer_name = param_name.replace("gnmt.gnmt.", "")
+        if infer_name.startswith("embedding_lookup."):
+            if is_npz:
+                weights[infer_name] = ms_ckpt[param_name]
+            else:
+                weights[infer_name] = ms_ckpt[param_name].data.asnumpy()
+            infer_name = "beam_decoder.decoder." + infer_name
+            if is_npz:
+                weights[infer_name] = ms_ckpt[param_name]
+            else:
+                weights[infer_name] = ms_ckpt[param_name].data.asnumpy()
+            continue
+        elif not infer_name.startswith("gnmt_encoder"):
+            if infer_name.startswith("gnmt_decoder."):
+                infer_name = infer_name.replace("gnmt_decoder.", "decoder.")
+            infer_name = "beam_decoder.decoder." + infer_name
+        if is_npz:
+            weights[infer_name] = ms_ckpt[param_name]
+        else:
+            weights[infer_name] = ms_ckpt[param_name].data.asnumpy()
     return weights
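To make the renaming rules above easier to follow, here is a pure-Python walk-through of the kept loop on a few made-up key names (illustrative only, not real GNMT checkpoint keys): embedding weights are registered twice, under their stripped name and under a beam_decoder.decoder. alias; decoder weights are remapped under beam_decoder.decoder.; encoder weights keep their stripped name.

sample_keys = [
    "gnmt.gnmt.embedding_lookup.embedding_table",
    "gnmt.gnmt.gnmt_encoder.bilstm.weight",
    "gnmt.gnmt.gnmt_decoder.lstm.weight",
]
renamed = []
for param_name in sample_keys:
    infer_name = param_name.replace("gnmt.gnmt.", "")
    if infer_name.startswith("embedding_lookup."):
        renamed.append(infer_name)                            # original name
        renamed.append("beam_decoder.decoder." + infer_name)  # beam-search alias
        continue
    if not infer_name.startswith("gnmt_encoder"):
        if infer_name.startswith("gnmt_decoder."):
            infer_name = infer_name.replace("gnmt_decoder.", "decoder.")
        infer_name = "beam_decoder.decoder." + infer_name
    renamed.append(infer_name)
print(renamed)
# ['embedding_lookup.embedding_table',
#  'beam_decoder.decoder.embedding_lookup.embedding_table',
#  'gnmt_encoder.bilstm.weight',
#  'beam_decoder.decoder.decoder.lstm.weight']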
