@@ -23,7 +23,6 @@ from mindspore.common.tensor import Tensor
 from mindspore.ops import operations as P
 from mindspore import context, Parameter
 from mindspore.train.model import Model
-from mindspore.train.serialization import load_checkpoint
 
 from src.dataset import load_dataset
 from .gnmt import GNMT
@@ -37,18 +36,6 @@ context.set_context(
     reserve_class_name_in_scope=False)
 
 
-def get_weight_and_variable(model_path, params):
-    print("model path is {}".format(model_path))
-    ms_ckpt = load_checkpoint(model_path)
-    with open("variable.txt", "w") as f:
-        for msname in ms_ckpt:
-            f.write(msname + "\n")
-    with open("weights.txt", "w") as f:
-        for param in params:
-            name = param.name
-            f.write(name + "\n")
-
-
 class GNMTInferCell(nn.Cell):
     """
     Encapsulation class of GNMT network infer.
@@ -92,38 +79,31 @@ def gnmt_infer(config, dataset):
                      use_one_hot_embeddings=False)
 
     params = tfm_model.trainable_params()
-    get_weight_and_variable(config.existed_ckpt, params)
     weights = load_infer_weights(config)
 
     for param in params:
         value = param.data
-        name = param.name
-        if name not in weights:
-            raise ValueError(f"{name} is not found in weights.")
-        with open("weight_after_deal.txt", "a+") as f:
-            weights_name = name
-            f.write(weights_name)
-            f.write("\n")
-            if isinstance(value, Tensor):
-                print(name, value.asnumpy().shape)
-                if weights_name in weights:
-                    assert weights_name in weights
-                    if isinstance(weights[weights_name], Parameter):
-                        if param.data.dtype == "Float32":
-                            param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32))
-                        elif param.data.dtype == "Float16":
-                            param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16))
-                    elif isinstance(weights[weights_name], Tensor):
-                        param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype))
-                    elif isinstance(weights[weights_name], np.ndarray):
-                        param.set_data(Tensor(weights[weights_name], config.dtype))
-                    else:
-                        param.set_data(weights[weights_name])
-                else:
-                    print("weight not found in checkpoint: " + weights_name)
-                    param.set_data(zero_weight(value.asnumpy().shape))
-            f.close()
+        weights_name = param.name
+        if weights_name not in weights:
+            raise ValueError(f"{weights_name} is not found in weights.")
+        if isinstance(value, Tensor):
+            if weights_name in weights:
+                assert weights_name in weights
+                if isinstance(weights[weights_name], Parameter):
+                    if param.data.dtype == "Float32":
+                        param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32))
+                    elif param.data.dtype == "Float16":
+                        param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16))
+                elif isinstance(weights[weights_name], Tensor):
+                    param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype))
+                elif isinstance(weights[weights_name], np.ndarray):
+                    param.set_data(Tensor(weights[weights_name], config.dtype))
+                else:
+                    param.set_data(weights[weights_name])
+            else:
+                print("weight not found in checkpoint: " + weights_name)
+                param.set_data(zero_weight(value.asnumpy().shape))
 
     print(" | Load weights successfully.")
     tfm_infer = GNMTInferCell(tfm_model)
     model = Model(tfm_infer)