!10660 Delete the redundant file of GNMTv2

From: @gaojing22
Reviewed-by: @liangchenghui, @wuxuejian
Signed-off-by: @wuxuejian
pull/10660/MERGE
Committed by mindspore-ci-bot via Gitee, 4 years ago
commit 237faca57e

@@ -65,33 +65,28 @@ if __name__ == '__main__':
     for param in params:
         value = param.data
-        name = param.name
-        if name not in weights:
-            raise ValueError(f"{name} is not found in weights.")
-        with open("weight_after_deal.txt", "a+") as f:
-            weights_name = name
-            f.write(weights_name)
-            f.write("\n")
-            if isinstance(value, Tensor):
-                print(name, value.asnumpy().shape)
-                if weights_name in weights:
-                    assert weights_name in weights
-                    if isinstance(weights[weights_name], Parameter):
-                        if param.data.dtype == "Float32":
-                            param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32))
-                        elif param.data.dtype == "Float16":
-                            param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16))
-                    elif isinstance(weights[weights_name], Tensor):
-                        param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype))
-                    elif isinstance(weights[weights_name], np.ndarray):
-                        param.set_data(Tensor(weights[weights_name], config.dtype))
-                    else:
-                        param.set_data(weights[weights_name])
-                else:
-                    print("weight not found in checkpoint: " + weights_name)
-                    param.set_data(zero_weight(value.asnumpy().shape))
-        f.close()
+        weights_name = param.name
+        if weights_name not in weights:
+            raise ValueError(f"{weights_name} is not found in weights.")
+        if isinstance(value, Tensor):
+            if weights_name in weights:
+                assert weights_name in weights
+                if isinstance(weights[weights_name], Parameter):
+                    if param.data.dtype == "Float32":
+                        param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32))
+                    elif param.data.dtype == "Float16":
+                        param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16))
+                elif isinstance(weights[weights_name], Tensor):
+                    param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype))
+                elif isinstance(weights[weights_name], np.ndarray):
+                    param.set_data(Tensor(weights[weights_name], config.dtype))
+                else:
+                    param.set_data(weights[weights_name])
+            else:
+                print("weight not found in checkpoint: " + weights_name)
+                param.set_data(zero_weight(value.asnumpy().shape))
     print(" | Load weights successfully.")
     tfm_infer = GNMTInferCell(tfm_model)
     tfm_infer.set_train(False)
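
Side note: the dtype dispatch kept by this hunk can be tried in isolation. The snippet below is a minimal sketch assuming a local MindSpore install; the parameter name and values are invented for illustration and are not taken from GNMTv2.

import numpy as np
from mindspore import Tensor, Parameter
from mindspore.common import dtype as mstype

# Hypothetical stand-ins for `param` and `weights[weights_name]` in the loop above.
param = Parameter(Tensor(np.zeros((2, 3)), mstype.float16), name="decoder.weight")
ckpt_value = np.ones((2, 3), dtype=np.float32)

# Cast the checkpoint value to the parameter's own dtype before set_data,
# mirroring the Float32/Float16 branches kept by this hunk.
target_dtype = mstype.float32 if param.data.dtype == mstype.float32 else mstype.float16
param.set_data(Tensor(ckpt_value, target_dtype))
print(param.data.dtype)  # Float16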

@@ -23,7 +23,6 @@ from mindspore.common.tensor import Tensor
 from mindspore.ops import operations as P
 from mindspore import context, Parameter
 from mindspore.train.model import Model
-from mindspore.train.serialization import load_checkpoint
 from src.dataset import load_dataset
 from .gnmt import GNMT
@@ -37,18 +36,6 @@
     reserve_class_name_in_scope=False)
 
-
-def get_weight_and_variable(model_path, params):
-    print("model path is {}".format(model_path))
-    ms_ckpt = load_checkpoint(model_path)
-    with open("variable.txt", "w") as f:
-        for msname in ms_ckpt:
-            f.write(msname + "\n")
-    with open("weights.txt", "w") as f:
-        for param in params:
-            name = param.name
-            f.write(name + "\n")
-
 class GNMTInferCell(nn.Cell):
     """
     Encapsulation class of GNMT network infer.
@@ -92,38 +79,31 @@ def gnmt_infer(config, dataset):
         use_one_hot_embeddings=False)
     params = tfm_model.trainable_params()
-    get_weight_and_variable(config.existed_ckpt, params)
     weights = load_infer_weights(config)
     for param in params:
         value = param.data
-        name = param.name
-        if name not in weights:
-            raise ValueError(f"{name} is not found in weights.")
-        with open("weight_after_deal.txt", "a+") as f:
-            weights_name = name
-            f.write(weights_name)
-            f.write("\n")
-            if isinstance(value, Tensor):
-                print(name, value.asnumpy().shape)
-                if weights_name in weights:
-                    assert weights_name in weights
-                    if isinstance(weights[weights_name], Parameter):
-                        if param.data.dtype == "Float32":
-                            param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32))
-                        elif param.data.dtype == "Float16":
-                            param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16))
-                    elif isinstance(weights[weights_name], Tensor):
-                        param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype))
-                    elif isinstance(weights[weights_name], np.ndarray):
-                        param.set_data(Tensor(weights[weights_name], config.dtype))
-                    else:
-                        param.set_data(weights[weights_name])
-                else:
-                    print("weight not found in checkpoint: " + weights_name)
-                    param.set_data(zero_weight(value.asnumpy().shape))
-        f.close()
+        weights_name = param.name
+        if weights_name not in weights:
+            raise ValueError(f"{weights_name} is not found in weights.")
+        if isinstance(value, Tensor):
+            if weights_name in weights:
+                assert weights_name in weights
+                if isinstance(weights[weights_name], Parameter):
+                    if param.data.dtype == "Float32":
+                        param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32))
+                    elif param.data.dtype == "Float16":
+                        param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16))
+                elif isinstance(weights[weights_name], Tensor):
+                    param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype))
+                elif isinstance(weights[weights_name], np.ndarray):
+                    param.set_data(Tensor(weights[weights_name], config.dtype))
+                else:
+                    param.set_data(weights[weights_name])
+            else:
+                print("weight not found in checkpoint: " + weights_name)
+                param.set_data(zero_weight(value.asnumpy().shape))
     print(" | Load weights successfully.")
     tfm_infer = GNMTInferCell(tfm_model)
     model = Model(tfm_infer)

@@ -37,36 +37,26 @@ def load_infer_weights(config):
     ms_ckpt = load_checkpoint(model_path)
     is_npz = False
     weights = {}
-    with open("variable_after_deal.txt", "w") as f:
-        for param_name in ms_ckpt:
-            infer_name = param_name.replace("gnmt.gnmt.", "")
-            if infer_name.startswith("embedding_lookup."):
-                if is_npz:
-                    weights[infer_name] = ms_ckpt[param_name]
-                else:
-                    weights[infer_name] = ms_ckpt[param_name].data.asnumpy()
-                f.write(infer_name)
-                f.write("\n")
-                infer_name = "beam_decoder.decoder." + infer_name
-                if is_npz:
-                    weights[infer_name] = ms_ckpt[param_name]
-                else:
-                    weights[infer_name] = ms_ckpt[param_name].data.asnumpy()
-                f.write(infer_name)
-                f.write("\n")
-                continue
-            elif not infer_name.startswith("gnmt_encoder"):
-                if infer_name.startswith("gnmt_decoder."):
-                    infer_name = infer_name.replace("gnmt_decoder.", "decoder.")
-                infer_name = "beam_decoder.decoder." + infer_name
-            if is_npz:
-                weights[infer_name] = ms_ckpt[param_name]
-            else:
-                weights[infer_name] = ms_ckpt[param_name].data.asnumpy()
-            f.write(infer_name)
-            f.write("\n")
-    f.close()
+    for param_name in ms_ckpt:
+        infer_name = param_name.replace("gnmt.gnmt.", "")
+        if infer_name.startswith("embedding_lookup."):
+            if is_npz:
+                weights[infer_name] = ms_ckpt[param_name]
+            else:
+                weights[infer_name] = ms_ckpt[param_name].data.asnumpy()
+            infer_name = "beam_decoder.decoder." + infer_name
+            if is_npz:
+                weights[infer_name] = ms_ckpt[param_name]
+            else:
+                weights[infer_name] = ms_ckpt[param_name].data.asnumpy()
+            continue
+        elif not infer_name.startswith("gnmt_encoder"):
+            if infer_name.startswith("gnmt_decoder."):
+                infer_name = infer_name.replace("gnmt_decoder.", "decoder.")
+            infer_name = "beam_decoder.decoder." + infer_name
+        if is_npz:
+            weights[infer_name] = ms_ckpt[param_name]
+        else:
+            weights[infer_name] = ms_ckpt[param_name].data.asnumpy()
     return weights
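
For readers tracing the renaming logic that this hunk preserves in load_infer_weights, here is a dependency-free sketch of the same key mapping. The checkpoint names below are hypothetical examples; only the prefixes mirror the real GNMTv2 parameter names.

def rename_keys(ckpt_names):
    renamed = []
    for param_name in ckpt_names:
        infer_name = param_name.replace("gnmt.gnmt.", "")
        if infer_name.startswith("embedding_lookup."):
            # Embedding weights are shared: keep the original name and add a
            # second copy under the beam-search decoder scope.
            renamed.append(infer_name)
            renamed.append("beam_decoder.decoder." + infer_name)
            continue
        elif not infer_name.startswith("gnmt_encoder"):
            if infer_name.startswith("gnmt_decoder."):
                infer_name = infer_name.replace("gnmt_decoder.", "decoder.")
            infer_name = "beam_decoder.decoder." + infer_name
        renamed.append(infer_name)
    return renamed

print(rename_keys([
    "gnmt.gnmt.embedding_lookup.embedding_table",
    "gnmt.gnmt.gnmt_encoder.lstm.weight",
    "gnmt.gnmt.gnmt_decoder.attention.weight",
]))
# ['embedding_lookup.embedding_table',
#  'beam_decoder.decoder.embedding_lookup.embedding_table',
#  'gnmt_encoder.lstm.weight',
#  'beam_decoder.decoder.decoder.attention.weight']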
