!11491 Add GRU network

From: @qujianwei
Reviewed-by: 
Signed-off-by:
pull/11491/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 0beb38fdb0


@@ -0,0 +1,83 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Transformer evaluation script."""
import argparse
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore import context
from src.dataset import create_gru_dataset
from src.seq2seq import Seq2Seq
from src.gru_for_infer import GRUInferCell
from src.config import config
def run_gru_eval():
"""
GRU evaluation.
"""
parser = argparse.ArgumentParser(description='GRU eval')
parser.add_argument("--device_target", type=str, default="Ascend",
help="device where the code will be implemented, default is Ascend")
parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend, default is 0')
parser.add_argument('--device_num', type=int, default=1, help='Use device nums, default is 1')
parser.add_argument('--ckpt_file', type=str, default="", help='ckpt file path')
parser.add_argument("--dataset_path", type=str, default="",
help="Dataset path, default: f`sns.")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, reserve_class_name_in_scope=False, \
device_id=args.device_id, save_graphs=False)
dataset = create_gru_dataset(epoch_count=config.num_epochs, batch_size=config.eval_batch_size, \
dataset_path=args.dataset_path, rank_size=args.device_num, rank_id=0, do_shuffle=False, is_training=False)
dataset_size = dataset.get_dataset_size()
print("dataset size is {}".format(dataset_size))
network = Seq2Seq(config, is_training=False)
network = GRUInferCell(network)
network.set_train(False)
if args.ckpt_file != "":
parameter_dict = load_checkpoint(args.ckpt_file)
load_param_into_net(network, parameter_dict)
model = Model(network)
predictions = []
source_sents = []
target_sents = []
eval_text_len = 0
for batch in dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
source_sents.append(batch["source_ids"])
target_sents.append(batch["target_ids"])
source_ids = Tensor(batch["source_ids"], mstype.int32)
target_ids = Tensor(batch["target_ids"], mstype.int32)
predicted_ids = model.predict(source_ids, target_ids)
print("predicts is ", predicted_ids.asnumpy())
print("target_ids is ", target_ids)
predictions.append(predicted_ids.asnumpy())
eval_text_len = eval_text_len + 1
f_output = open(config.output_file, 'w')
f_target = open(config.target_file, "w")
for batch_out, true_sentence in zip(predictions, target_sents):
for i in range(config.eval_batch_size):
target_ids = [str(x) for x in true_sentence[i].tolist()]
f_target.write(" ".join(target_ids) + "\n")
token_ids = [str(x) for x in batch_out[i].tolist()]
f_output.write(" ".join(token_ids) + "\n")
f_output.close()
f_target.close()
if __name__ == "__main__":
run_gru_eval()

@@ -0,0 +1,48 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh create_dataset.sh DATASET_PATH OUTPUT_PATH"
echo "for example: sh create_dataset.sh /path/multi30k/ /path/multi30k/mindrecord/"
echo "DATASET_NAME including ag, dbpedia, and yelp_p"
echo "It is better to use absolute path."
echo "=============================================================================================================="
ulimit -u unlimited
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET_PATH=$(get_real_path $1)
echo $DATASET_PATH
if [ ! -d $DATASET_PATH ]
then
echo "error: DATASET_PATH=$DATASET_PATH is not valid"
exit 1
fi
OUTPUT_PATH=$(get_real_path $2)
echo $OUTPUT_PATH
if [ ! -d $OUTPUT_PATH ]
then
echo "error: OUTPUT_PATH=$OUTPUT_PATH is not valid"
exit 1
fi
paste $DATASET_PATH/train.de.tok $DATASET_PATH/train.en.tok > $DATASET_PATH/train.all
python ../src/create_data.py --input_file $DATASET_PATH/train.all --num_splits 8 --src_vocab_file $DATASET_PATH/vocab.de --trg_vocab_file $DATASET_PATH/vocab.en --output_file $OUTPUT_PATH/multi30k_train_mindrecord --max_seq_length 32 --bucket [32]
paste $DATASET_PATH/test.de.tok $DATASET_PATH/test.en.tok > $DATASET_PATH/test.all
python ../src/create_data.py --input_file $DATASET_PATH/test.all --num_splits 1 --src_vocab_file $DATASET_PATH/vocab.de --trg_vocab_file $DATASET_PATH/vocab.en --output_file $OUTPUT_PATH/multi30k_test_mindrecord --max_seq_length 32 --bucket [32]
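The two python commands above write bucketed MindRecord files into OUTPUT_PATH. As a quick sanity check, the records can be read back with MindDataset; a minimal sketch, assuming MindSpore is installed and that the test set was converted with --num_splits 1, so the file is named multi30k_test_mindrecord_32 (the _32 suffix comes from the single bucket; path and file name are placeholders from the example invocation):

import mindspore.dataset as ds

# Placeholder output path taken from the example invocation above.
data = ds.MindDataset("/path/multi30k/mindrecord/multi30k_test_mindrecord_32",
                      columns_list=["source_ids", "source_mask", "target_ids", "target_mask"])
for item in data.create_dict_iterator(output_numpy=True, num_epochs=1):
    print(item["source_ids"], item["target_ids"])   # padded id sequences of length 32
    break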

@@ -0,0 +1,33 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh process_output.sh REF_DATA EVAL_OUTPUT VOCAB_FILE"
echo "for example: sh parse_output.sh target.txt output.txt vocab.en"
echo "It is better to use absolute path."
echo "=============================================================================================================="
ref_data=$1
eval_output=$2
vocab_file=$3
cat $ref_data \
| python ../src/parse_output.py --vocab_file $vocab_file \
| sed 's/@@ //g' > ${ref_data}.forbleu
cat $eval_output \
| python ../src/parse_output.py --vocab_file $vocab_file \
| sed 's/@@ //g' > ${eval_output}.forbleu
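The script produces target.txt.forbleu and output.txt.forbleu containing whitespace-separated tokens with subword markers removed. The PR itself does not ship a scoring step; a hedged sketch of one way to compute corpus BLEU on those two files with NLTK (not part of this PR) is:

from nltk.translate.bleu_score import corpus_bleu

# File names follow config.target_file / config.output_file plus the .forbleu suffix above.
with open("target.txt.forbleu") as f:
    references = [[line.split()] for line in f]
with open("output.txt.forbleu") as f:
    hypotheses = [line.split() for line in f]
print("BLEU: %.2f" % (corpus_bleu(references, hypotheses) * 100))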

@@ -0,0 +1,17 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DATASET_DIR=$1
python ../src/preprocess.py --dataset_path=$DATASET_DIR

@@ -0,0 +1,68 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -ne 2 ]
then
echo "Usage: sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [DATASET_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
echo $PATH1
if [ ! -f $PATH1 ]
then
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
exit 1
fi
DATASET_PATH=$(get_real_path $2)
echo $DATASET_PATH
if [ ! -f $DATASET_PATH ]
then
echo "error: DATASET_PATH=$DATASET_PATH is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=$i
export RANK_ID=$i
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i
cp *.sh ./train_parallel$i
cp -r ../src ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
python train.py --device_id=$i --rank_id=$i --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$DATASET_PATH &> log &
cd ..
done

@@ -0,0 +1,58 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -ne 2 ]
then
echo "Usage: sh run_eval.sh [CKPT_FILE] [DATASET_PATH]"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
CKPT_FILE=$(get_real_path $1)
echo $CKPT_FILE
if [ ! -f $CKPT_FILE ]
then
echo "error: CKPT_FILE=$CKPT_FILE is not a file"
exit 1
fi
DATASET_PATH=$(get_real_path $2)
echo $DATASET_PATH
if [ ! -f $DATASET_PATH ]
then
echo "error: DATASET_PATH=$DATASET_PATH is not a file"
exit 1
fi
rm -rf ./eval
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
echo "start eval for device $DEVICE_ID"
env > env.log
python eval.py --ckpt_file=$CKPT_FILE --dataset_path=$DATASET_PATH &> log &
cd ..

@@ -0,0 +1,51 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -ne 1 ]
then
echo "Usage: sh run_distribute_train_ascend.sh [DATASET_PATH]"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=4
export RANK_ID=0
export RANK_SIZE=1
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET_PATH=$(get_real_path $1)
echo $DATASET_PATH
if [ ! -f $DATASET_PATH ]
then
echo "error: DATASET_PATH=$DATASET_PATH is not a file"
exit 1
fi
rm -rf ./train
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python train.py --device_id=$DEVICE_ID --dataset_path=$DATASET_PATH &> log &
cd ..

@@ -0,0 +1,41 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GRU config"""
from easydict import EasyDict
config = EasyDict({
"batch_size": 16,
"eval_batch_size": 1,
"src_vocab_size": 8154,
"trg_vocab_size": 6113,
"encoder_embedding_size": 256,
"decoder_embedding_size": 256,
"hidden_size": 512,
"max_length": 32,
"num_epochs": 30,
"save_checkpoint": True,
"ckpt_epoch": 10,
"target_file": "target.txt",
"output_file": "output.txt",
"keep_checkpoint_max": 30,
"base_lr": 0.001,
"warmup_step": 300,
"momentum": 0.9,
"init_loss_scale_value": 1024,
'scale_factor': 2,
'scale_window': 2000,
"warmup_ratio": 1/3.0,
"teacher_force_ratio": 0.5
})
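config is an EasyDict, so the rest of the code reads these values as attributes; a minimal sketch of that access pattern (assuming the package is imported from the model root):

from src.config import config

# Attribute access and dict-style access refer to the same values.
assert config.hidden_size == config["hidden_size"] == 512
print(config.batch_size, config.eval_batch_size, config.num_epochs)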

@@ -0,0 +1,196 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Create training instances for Transformer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import ast
import collections
import logging
import numpy as np
import tokenization
from mindspore.mindrecord import FileWriter
class SampleInstance():
"""A single sample instance (sentence pair)."""
def __init__(self, source_tokens, target_tokens):
self.source_tokens = source_tokens
self.target_tokens = target_tokens
def __str__(self):
s = ""
s += "source_tokens: %s\n" % (" ".join(
[tokenization.convert_to_printable(x) for x in self.source_tokens]))
s += "target tokens: %s\n" % (" ".join(
[tokenization.convert_to_printable(x) for x in self.target_tokens]))
s += "\n"
return s
def __repr__(self):
return self.__str__()
def get_instance_features(instance, tokenizer_src, tokenizer_trg, max_seq_length, bucket):
"""Get features from `SampleInstance`s."""
def _find_bucket_length(source_tokens, target_tokens):
source_ids = tokenizer_src.convert_tokens_to_ids(source_tokens)
target_ids = tokenizer_trg.convert_tokens_to_ids(target_tokens)
num = max(len(source_ids), len(target_ids))
assert num <= bucket[-1]
for index in range(1, len(bucket)):
if bucket[index - 1] < num <= bucket[index]:
return bucket[index]
return bucket[0]
def _convert_ids_and_mask(tokenizer, input_tokens, seq_max_bucket_length):
input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
input_mask = [1] * len(input_ids)
assert len(input_ids) <= max_seq_length
while len(input_ids) < seq_max_bucket_length:
input_ids.append(1)
input_mask.append(0)
assert len(input_ids) == seq_max_bucket_length
assert len(input_mask) == seq_max_bucket_length
return input_ids, input_mask
seq_max_bucket_length = _find_bucket_length(instance.source_tokens, instance.target_tokens)
source_ids, source_mask = _convert_ids_and_mask(tokenizer_src, instance.source_tokens, seq_max_bucket_length)
target_ids, target_mask = _convert_ids_and_mask(tokenizer_trg, instance.target_tokens, seq_max_bucket_length)
features = collections.OrderedDict()
features["source_ids"] = np.asarray(source_ids)
features["source_mask"] = np.asarray(source_mask)
features["target_ids"] = np.asarray(target_ids)
features["target_mask"] = np.asarray(target_mask)
return features, seq_max_bucket_length
def create_training_instance(source_words, target_words, max_seq_length, clip_to_max_len):
"""Creates `SampleInstance`s for a single sentence pair."""
EOS = "<eos>"
SOS = "<sos>"
if len(source_words) >= max_seq_length-1 or len(target_words) >= max_seq_length-1:
if clip_to_max_len:
source_words = source_words[:min(len(source_words), max_seq_length - 2)]
target_words = target_words[:min(len(target_words), max_seq_length - 2)]
else:
return None
source_tokens = [SOS] + source_words + [EOS]
target_tokens = [SOS] + target_words + [EOS]
instance = SampleInstance(
source_tokens=source_tokens,
target_tokens=target_tokens)
return instance
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--input_file", type=str, required=True,
help='Input raw text file (or comma-separated list of files).')
parser.add_argument("--output_file", type=str, required=True, help='Output MindRecord file.')
parser.add_argument("--num_splits", type=int, default=16,
help='The MindRecord file will be split into the number of partition.')
parser.add_argument("--src_vocab_file", type=str, required=True,
help='The vocabulary file that the Transformer model was trained on.')
parser.add_argument("--trg_vocab_file", type=str, required=True,
help='The vocabulary file that the Transformer model was trained on.')
parser.add_argument("--clip_to_max_len", type=ast.literal_eval, default=False,
help='clip sequences to maximum sequence length.')
parser.add_argument("--max_seq_length", type=int, default=32, help='Maximum sequence length.')
parser.add_argument("--bucket", type=ast.literal_eval, default=[32],
help='bucket sequence length')
args = parser.parse_args()
tokenizer_src = tokenization.WhiteSpaceTokenizer(vocab_file=args.src_vocab_file)
tokenizer_trg = tokenization.WhiteSpaceTokenizer(vocab_file=args.trg_vocab_file)
input_files = []
for input_pattern in args.input_file.split(","):
input_files.append(input_pattern)
logging.info("*** Read from input files ***")
output_file = args.output_file
logging.info("*** Write to output files ***")
logging.info(" %s", output_file)
total_written = 0
total_read = 0
feature_dict = {}
for i in args.bucket:
feature_dict[i] = []
for input_file in input_files:
logging.info("*** Reading from %s ***", input_file)
with open(input_file, "r") as reader:
while True:
line = tokenization.convert_to_unicode(reader.readline())
if not line:
break
total_read += 1
if total_read % 100000 == 0:
logging.info("Read %d ...", total_read)
if line.strip() == "":
continue
source_line, target_line = line.strip().split("\t")
source_tokens = tokenizer_src.tokenize(source_line)
target_tokens = tokenizer_trg.tokenize(target_line)
if len(source_tokens) >= args.max_seq_length or len(target_tokens) >= args.max_seq_length:
logging.info("ignore long sentence!")
continue
instance = create_training_instance(source_tokens, target_tokens, args.max_seq_length,
clip_to_max_len=args.clip_to_max_len)
if instance is None:
continue
features, seq_max_bucket_length = get_instance_features(instance, tokenizer_src, tokenizer_trg,
args.max_seq_length, args.bucket)
for key in feature_dict:
if key == seq_max_bucket_length:
feature_dict[key].append(features)
if total_read <= 10:
logging.info("*** Example ***")
logging.info("source tokens: %s", " ".join(
[tokenization.convert_to_printable(x) for x in instance.source_tokens]))
logging.info("target tokens: %s", " ".join(
[tokenization.convert_to_printable(x) for x in instance.target_tokens]))
for feature_name in features.keys():
feature = features[feature_name]
logging.info("%s: %s", feature_name, feature)
for i in args.bucket:
if args.num_splits == 1:
output_file_name = output_file + '_' + str(i)
else:
output_file_name = output_file + '_' + str(i) + '_'
writer = FileWriter(output_file_name, args.num_splits)
data_schema = {"source_ids": {"type": "int64", "shape": [-1]},
"source_mask": {"type": "int64", "shape": [-1]},
"target_ids": {"type": "int64", "shape": [-1]},
"target_mask": {"type": "int64", "shape": [-1]}
}
writer.add_schema(data_schema, "gru")
features_ = feature_dict[i]
logging.info("Bucket length %d has %d samples, start writing...", i, len(features_))
for item in features_:
writer.write_raw_data([item])
total_written += 1
writer.commit()
logging.info("Wrote %d total instances", total_written)
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
main()

@@ -0,0 +1,48 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Data operations, will be used in train.py."""
import mindspore.common.dtype as mstype
import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as deC
from src.config import config
import numpy as np
de.config.set_seed(1)
def random_teacher_force(source_ids, target_ids, target_mask):
teacher_force = np.random.random() < config.teacher_force_ratio
teacher_force_array = np.array([teacher_force], dtype=bool)
return source_ids, target_ids, teacher_force_array
def create_gru_dataset(epoch_count=1, batch_size=1, rank_size=1, rank_id=0, do_shuffle=True, dataset_path=None,
is_training=True):
"""create dataset"""
ds = de.MindDataset(dataset_path,
columns_list=["source_ids", "target_ids",
"target_mask"],
shuffle=do_shuffle, num_parallel_workers=10, num_shards=rank_size, shard_id=rank_id)
operations = random_teacher_force
ds = ds.map(operations=operations, input_columns=["source_ids", "target_ids", "target_mask"],
output_columns=["source_ids", "target_ids", "teacher_force"],
column_order=["source_ids", "target_ids", "teacher_force"])
type_cast_op = deC.TypeCast(mstype.int32)
type_cast_op_bool = deC.TypeCast(mstype.bool_)
ds = ds.map(operations=type_cast_op, input_columns="source_ids")
ds = ds.map(operations=type_cast_op, input_columns="target_ids")
ds = ds.map(operations=type_cast_op_bool, input_columns="teacher_force")
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(1)
return ds
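A hedged usage sketch of the pipeline above for evaluation (the MindRecord path is a placeholder): it should yield batches with the columns source_ids, target_ids, and the per-sample teacher_force flag.

from src.config import config
from src.dataset import create_gru_dataset

ds_eval = create_gru_dataset(batch_size=config.eval_batch_size,
                             dataset_path="/path/multi30k/mindrecord/multi30k_test_mindrecord_32",
                             do_shuffle=False, is_training=False)
for batch in ds_eval.create_dict_iterator(output_numpy=True, num_epochs=1):
    print(batch["source_ids"].shape, batch["target_ids"].shape, batch["teacher_force"])
    break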

@@ -0,0 +1,104 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GRU cell"""
import mindspore.nn as nn
import mindspore.ops.operations as P
import mindspore.common.dtype as mstype
from src.weight_init import gru_default_state
class BidirectionGRU(nn.Cell):
'''
BidirectionGRU model
Args:
config: config of network
'''
def __init__(self, config, is_training=True):
super(BidirectionGRU, self).__init__()
if is_training:
self.batch_size = config.batch_size
else:
self.batch_size = config.eval_batch_size
self.embedding_size = config.encoder_embedding_size
self.hidden_size = config.hidden_size
self.weight_i, self.weight_h, self.bias_i, self.bias_h, self.init_h = gru_default_state(self.batch_size,
self.embedding_size,
self.hidden_size)
self.weight_bw_i, self.weight_bw_h, self.bias_bw_i, self.bias_bw_h, self.init_bw_h = \
gru_default_state(self.batch_size, self.embedding_size, self.hidden_size)
self.reverse = P.ReverseV2(axis=[1])
self.concat = P.Concat(axis=2)
self.squeeze = P.Squeeze(axis=0)
self.rnn = P.DynamicGRUV2()
self.text_len = config.max_length
self.cast = P.Cast()
def construct(self, x):
'''
BidirectionGRU construction
Args:
x(Tensor): BidirectionGRU input
Returns:
output(Tensor): rnn output
hidden(Tensor): hidden state
'''
x = self.cast(x, mstype.float16)
y1, _, _, _, _, _ = self.rnn(x, self.weight_i, self.weight_h, self.bias_i, self.bias_h, None, self.init_h)
bw_x = self.reverse(x)
y1_bw, _, _, _, _, _ = self.rnn(bw_x, self.weight_bw_i,
self.weight_bw_h, self.bias_bw_i, self.bias_bw_h, None, self.init_bw_h)
y1_bw = self.reverse(y1_bw)
output = self.concat((y1, y1_bw))
hidden = self.concat((y1[self.text_len-1:self.text_len:1, ::, ::],
y1_bw[self.text_len-1:self.text_len:1, ::, ::]))
hidden = self.squeeze(hidden)
return output, hidden
class GRU(nn.Cell):
'''
GRU model
Args:
config: config of network
'''
def __init__(self, config, is_training=True):
super(GRU, self).__init__()
if is_training:
self.batch_size = config.batch_size
else:
self.batch_size = config.eval_batch_size
self.embedding_size = config.encoder_embedding_size
self.hidden_size = config.hidden_size
self.weight_i, self.weight_h, self.bias_i, self.bias_h, self.init_h = \
gru_default_state(self.batch_size, self.embedding_size + self.hidden_size*2, self.hidden_size)
self.rnn = P.DynamicGRUV2()
self.cast = P.Cast()
def construct(self, x):
'''
GRU construction
Args:
x(Tensor): GRU input
Returns:
output(Tensor): rnn output
hidden(Tensor): hidden state
'''
x = self.cast(x, mstype.float16)
y1, h1, _, _, _, _ = self.rnn(x, self.weight_i, self.weight_h, self.bias_i, self.bias_h, None, self.init_h)
return y1, h1
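Running DynamicGRUV2 itself requires an Ascend device, but the parameter layout set up by gru_default_state can be inspected on any backend; a small sketch, assuming the default config:

from src.config import config
from src.gru import BidirectionGRU

enc_rnn = BidirectionGRU(config, is_training=False)
# weight_i: (embedding_size, 3*hidden_size), weight_h: (hidden_size, 3*hidden_size),
# init_h: (eval_batch_size, hidden_size)
print(enc_rnn.weight_i.shape, enc_rnn.weight_h.shape, enc_rnn.init_h.shape)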

@@ -0,0 +1,42 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GRU Infer cell"""
import numpy as np
from mindspore import Tensor
import mindspore.nn as nn
import mindspore.ops.operations as P
import mindspore.common.dtype as mstype
from src.config import config
class GRUInferCell(nn.Cell):
'''
GRU infer construction
Args:
network: gru network
'''
def __init__(self, network):
super(GRUInferCell, self).__init__(auto_prefix=False)
self.network = network
self.argmax = P.ArgMaxWithValue(axis=2)
self.transpose = P.Transpose()
self.teacher_force = Tensor(np.zeros((config.eval_batch_size)), mstype.bool_)
def construct(self,
encoder_inputs,
decoder_inputs):
predict_probs = self.network(encoder_inputs, decoder_inputs, self.teacher_force)
predict_probs = self.transpose(predict_probs, (1, 0, 2))
predict_ids, _ = self.argmax(predict_probs)
return predict_ids

@@ -0,0 +1,243 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GRU train cell"""
from mindspore import Tensor, Parameter, ParameterTuple, context
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
import mindspore.common.dtype as mstype
from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
from src.config import config
from src.loss import NLLLoss
class GRUWithLossCell(nn.Cell):
"""
GRU network connected with the loss function.
Args:
network: The training network.
Returns:
the output of the loss function.
"""
def __init__(self, network):
super(GRUWithLossCell, self).__init__()
self.network = network
self.loss = NLLLoss()
self.logits_shape = (-1, config.src_vocab_size)
self.reshape = P.Reshape()
self.cast = P.Cast()
self.mean = P.ReduceMean()
self.text_len = config.max_length
self.split = P.Split(axis=0, output_num=config.max_length-1)
self.squeeze = P.Squeeze()
self.add = P.AddN()
self.transpose = P.Transpose()
self.shape = P.Shape()
def construct(self, encoder_inputs, decoder_inputs, teacher_force):
'''
GRU loss cell
Args:
encoder_inputs(Tensor): encoder inputs
decoder_inputs(Tensor): decoder inputs
teacher_force(Tensor): teacher force flag
Returns:
loss(scalar): loss output
'''
logits = self.network(encoder_inputs, decoder_inputs, teacher_force)
logits = self.cast(logits, mstype.float32)
loss_total = ()
decoder_targets = decoder_inputs
decoder_output = logits
for i in range(1, self.text_len):
loss = self.loss(self.squeeze(decoder_output[i-1:i:1, ::, ::]), decoder_targets[:, i])
loss_total += (loss,)
loss = self.add(loss_total) / self.text_len
return loss
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0
class ClipGradients(nn.Cell):
"""
Clip gradients.
Args:
grads (list): List of gradient tuples.
clip_type (int): The way to clip: 0 for clipping by value, 1 for clipping by norm.
clip_value (float): Specifies how much to clip.
Returns:
List, a list of clipped_grad tuples.
"""
def __init__(self):
super(ClipGradients, self).__init__()
self.clip_by_norm = nn.ClipByNorm()
self.cast = P.Cast()
self.dtype = P.DType()
def construct(self,
grads,
clip_type,
clip_value):
"""Defines the gradients clip."""
if clip_type not in (0, 1):
return grads
new_grads = ()
for grad in grads:
dt = self.dtype(grad)
if clip_type == 0:
t = C.clip_by_value(grad, self.cast(F.tuple_to_array((-clip_value,)), dt),
self.cast(F.tuple_to_array((clip_value,)), dt))
else:
t = self.clip_by_norm(grad, self.cast(F.tuple_to_array((clip_value,)), dt))
t = self.cast(t, dt)
new_grads = new_grads + (t,)
return new_grads
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()
@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
return grad * F.cast(reciprocal(scale), F.dtype(grad))
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
grad_overflow = P.FloatStatus()
@_grad_overflow.register("Tensor")
def _tensor_grad_overflow(grad):
return grad_overflow(grad)
class GRUTrainOneStepWithLossScaleCell(nn.Cell):
"""
Encapsulation class of GRU network training.
Appends an optimizer to the training network; after that, the construct
function can be called to create the backward graph.
Args:
network (Cell): The training network. Note that loss function should have been added.
optimizer (Optimizer): Optimizer for updating the weights.
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
"""
def __init__(self, network, optimizer, scale_update_cell=None):
super(GRUTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.network.add_flags(defer_inline=True)
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True,
sens_param=True)
self.reducer_flag = False
self.allreduce = P.AllReduce()
self.parallel_mode = _get_parallel_mode()
if self.parallel_mode not in ParallelMode.MODE_LIST:
raise ValueError("Parallel mode does not support: ", self.parallel_mode)
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
self.grad_reducer = None
if self.reducer_flag:
mean = _get_gradients_mean()
degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
self.clip_gradients = ClipGradients()
self.cast = P.Cast()
if context.get_context("device_target") == "GPU":
self.gpu_target = True
self.float_status = P.FloatStatus()
self.addn = P.AddN()
self.reshape = P.Reshape()
else:
self.gpu_target = False
self.alloc_status = P.NPUAllocFloatStatus()
self.get_status = P.NPUGetFloatStatus()
self.clear_before_grad = P.NPUClearFloatStatus()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.depend_parameter_use = P.ControlDepend(depend_mode=1)
self.base = Tensor(1, mstype.float32)
self.less_equal = P.LessEqual()
self.hyper_map = C.HyperMap()
self.loss_scale = None
self.loss_scaling_manager = scale_update_cell
if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
@C.add_flags(has_effect=True)
def construct(self,
encoder_inputs,
decoder_inputs,
teacher_force,
sens=None):
"""Defines the computation performed."""
weights = self.weights
loss = self.network(encoder_inputs,
decoder_inputs,
teacher_force)
init = False
if not self.gpu_target:
# alloc status
init = self.alloc_status()
# clear overflow buffer
self.clear_before_grad(init)
if sens is None:
scaling_sens = self.loss_scale
else:
scaling_sens = sens
grads = self.grad(self.network, weights)(encoder_inputs,
decoder_inputs,
teacher_force,
self.cast(scaling_sens,
mstype.float32))
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
if self.reducer_flag:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
if not self.gpu_target:
self.get_status(init)
# sum overflow buffer elements, 0: not overflow, >0: overflow
flag_sum = self.reduce_sum(init, (0,))
else:
flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
flag_sum = self.addn(flag_sum)
# convert flag_sum to scalar
flag_sum = self.reshape(flag_sum, (()))
if self.is_distributed:
# sum overflow flag over devices
flag_reduce = self.allreduce(flag_sum)
cond = self.less_equal(self.base, flag_reduce)
else:
cond = self.less_equal(self.base, flag_sum)
overflow = cond
if sens is None:
overflow = self.loss_scaling_manager(self.loss_scale, cond)
if overflow:
succ = False
else:
succ = self.optimizer(grads)
ret = (loss, cond, scaling_sens)
return F.depend(ret, succ)

@@ -0,0 +1,32 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""NLLLoss cell"""
import mindspore.ops.operations as P
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import functional as F
class NLLLoss(_Loss):
'''
NLLLoss function
'''
def __init__(self, reduction='mean'):
super(NLLLoss, self).__init__(reduction)
self.one_hot = P.OneHot()
self.reduce_sum = P.ReduceSum()
def construct(self, logits, label):
label_one_hot = self.one_hot(label, F.shape(logits)[-1], F.scalar_to_array(1.0), F.scalar_to_array(0.0))
loss = self.reduce_sum(-1.0 * logits * label_one_hot, (1,))
return self.get_loss(loss)
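The cell one-hot encodes the label and sums -logits over the vocabulary axis, so with logits that are already log-probabilities (the decoder applies LogSoftmax) it reduces to standard negative log-likelihood. A NumPy-only check of that arithmetic, not using MindSpore:

import numpy as np

log_probs = np.log(np.array([[0.7, 0.2, 0.1],
                             [0.1, 0.8, 0.1]]))      # (batch=2, vocab=3)
labels = np.array([0, 1])
one_hot = np.eye(3)[labels]
loss = (-log_probs * one_hot).sum(axis=1).mean()     # reduce_sum over vocab + 'mean' reduction
print(loss)                                          # (-ln 0.7 - ln 0.8) / 2 ≈ 0.29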

@@ -0,0 +1,45 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""lr generator for deeptext"""
import math
def rsqrt_decay(warmup_steps, current_step):
return float(max([current_step, warmup_steps])) ** -0.5
def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
learning_rate = float(init_lr) + lr_inc * current_step
return learning_rate
def a_cosine_learning_rate(current_step, base_lr, warmup_steps, total_steps):
decay_steps = total_steps - warmup_steps
linear_decay = (total_steps - current_step) / decay_steps
cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * current_step / decay_steps))
decayed = linear_decay * cosine_decay + 0.00001
learning_rate = decayed * base_lr
return learning_rate
def dynamic_lr(config, base_step):
"""dynamic learning rate generator"""
base_lr = config.base_lr
total_steps = int(base_step * config.num_epochs)
warmup_steps = int(config.warmup_step)
lr = []
for i in range(total_steps):
if i < warmup_steps:
lr.append(linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * config.warmup_ratio))
else:
lr.append(a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps))
return lr
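A small sketch of the resulting schedule, using a made-up 1800 steps per epoch: the first warmup_step entries ramp linearly from base_lr * warmup_ratio up to base_lr, and the remaining steps follow the cosine-style decay above.

from src.config import config
from src.lr_schedule import dynamic_lr

lr = dynamic_lr(config, base_step=1800)           # 1800 steps/epoch is an arbitrary example value
print(len(lr))                                    # num_epochs * base_step entries
print(lr[0], lr[config.warmup_step - 1], lr[-1])  # warmup start, warmup end, final decayed value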

@@ -0,0 +1,47 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Convert ids to tokens."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import tokenization
# Explicitly set the encoding
sys.stdin = open(sys.stdin.fileno(), mode='r', encoding='utf-8', buffering=True)
sys.stdout = open(sys.stdout.fileno(), mode='w', encoding='utf-8', buffering=True)
def main():
parser = argparse.ArgumentParser(
description="Convert predicted token ids back to tokens using the given vocabulary file.")
parser.add_argument("--vocab_file", type=str, default="", required=True, help="vocab file path.")
args = parser.parse_args()
tokenizer = tokenization.WhiteSpaceTokenizer(vocab_file=args.vocab_file)
for line in sys.stdin:
token_ids = [int(x) for x in line.strip().split()]
tokens = tokenizer.convert_ids_to_tokens(token_ids)
sent = " ".join(tokens)
sent = sent.split("<sos>")[-1]
sent = sent.split("<eos>")[0]
print(sent.strip())
if __name__ == "__main__":
main()

@@ -0,0 +1,105 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''Dataset preprocess'''
import os
import argparse
from collections import Counter
from nltk.tokenize import word_tokenize
def create_tokenized_sentences(input_files, language):
'''
Create tokenized sentences files.
Args:
input_files: input files.
language: text language
'''
sentence = []
total_lines = open(input_files, "r").read().splitlines()
for line in total_lines:
line = line.strip('\r\n ')
line = line.lower()
tokenize_sentence = word_tokenize(line, language)
str_sentence = " ".join(tokenize_sentence)
sentence.append(str_sentence)
tokenize_file = input_files + ".tok"
f = open(tokenize_file, "w")
for line in sentence:
f.write(line)
f.write("\n")
f.close()
def get_dataset_vocab(text_file, vocab_file):
'''
Create dataset vocab files.
Args:
text_file: dataset text files.
vocab_file: vocab file
'''
counter = Counter()
text_lines = open(text_file, "r").read().splitlines()
for line in text_lines:
for word in line.strip('\r\n ').split(' '):
if word:
counter[word] += 1
vocab = open(vocab_file, "w")
basic_label = ["<unk>", "<pad>", "<sos>", "<eos>"]
for label in basic_label:
vocab.write(label + "\n")
for key, f in sorted(counter.items(), key=lambda x: x[1], reverse=True):
if f < 2:
continue
vocab.write(key + "\n")
vocab.close()
def MergeText(root_dir, file_list, output_file):
'''
Merge text files together.
Args:
root_dir: root dir
file_list: dataset files list.
output_file: output file after merge
'''
output_file = os.path.join(root_dir, output_file)
f_output = open(output_file, "w")
for file_name in file_list:
text_path = os.path.join(root_dir, file_name) + ".tok"
f = open(text_path)
f_output.write(f.read() + "\n")
f_output.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='gru_dataset')
parser.add_argument("--dataset_path", type=str, default="", help="Dataset path, default: f`sns.")
args = parser.parse_args()
dataset_path = args.dataset_path
src_file_list = ["train.de", "test.de", "val.de"]
dst_file_list = ["train.en", "test.en", "val.en"]
for file in src_file_list:
file_path = os.path.join(dataset_path, file)
create_tokenized_sentences(file_path, "english")
for file in dst_file_list:
file_path = os.path.join(dataset_path, file)
create_tokenized_sentences(file_path, "german")
src_all_file = "all.de.tok"
dst_all_file = "all.en.tok"
MergeText(dataset_path, src_file_list, src_all_file)
MergeText(dataset_path, dst_file_list, dst_all_file)
src_vocab = os.path.join(dataset_path, "vocab.de")
dst_vocab = os.path.join(dataset_path, "vocab.en")
get_dataset_vocab(os.path.join(dataset_path, src_all_file), src_vocab)
get_dataset_vocab(os.path.join(dataset_path, dst_all_file), dst_vocab)

@@ -0,0 +1,223 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Seq2Seq construction"""
import numpy as np
from mindspore import Tensor
import mindspore.nn as nn
import mindspore.ops.operations as P
import mindspore.common.dtype as mstype
from src.gru import BidirectionGRU, GRU
from src.weight_init import dense_default_state
class Attention(nn.Cell):
'''
Attention model
'''
def __init__(self, config):
super(Attention, self).__init__()
self.text_len = config.max_length
self.attn = nn.Dense(in_channels=config.hidden_size * 3,
out_channels=config.hidden_size).to_float(mstype.float16)
self.fc = nn.Dense(config.hidden_size, 1, has_bias=False).to_float(mstype.float16)
self.expandims = P.ExpandDims()
self.tanh = P.Tanh()
self.softmax = P.Softmax()
self.tile = P.Tile()
self.transpose = P.Transpose()
self.concat = P.Concat(axis=2)
self.squeeze = P.Squeeze(axis=2)
self.cast = P.Cast()
def construct(self, hidden, encoder_outputs):
'''
Attention construction
Args:
hidden(Tensor): hidden state
encoder_outputs(Tensor): the output of encoder
Returns:
Tensor, attention output
'''
hidden = self.expandims(hidden, 1)
hidden = self.tile(hidden, (1, self.text_len, 1))
encoder_outputs = self.transpose(encoder_outputs, (1, 0, 2))
out = self.concat((hidden, encoder_outputs))
out = self.attn(out)
energy = self.tanh(out)
attention = self.fc(energy)
attention = self.squeeze(attention)
attention = self.cast(attention, mstype.float32)
attention = self.softmax(attention)
attention = self.cast(attention, mstype.float16)
return attention
class Encoder(nn.Cell):
'''
Encoder model
Args:
config: config of network
'''
def __init__(self, config, is_training=True):
super(Encoder, self).__init__()
self.hidden_size = config.hidden_size
self.vocab_size = config.src_vocab_size
self.embedding_size = config.encoder_embedding_size
self.embedding = nn.Embedding(self.vocab_size, self.embedding_size)
self.rnn = BidirectionGRU(config, is_training=is_training).to_float(mstype.float16)
self.fc = nn.Dense(2*self.hidden_size, self.hidden_size).to_float(mstype.float16)
self.shape = P.Shape()
self.transpose = P.Transpose()
self.p = P.Print()
self.cast = P.Cast()
self.text_len = config.max_length
self.squeeze = P.Squeeze(axis=0)
self.tanh = P.Tanh()
def construct(self, src):
'''
Encoder construction
Args:
src(Tensor): source sentences
Returns:
output(Tensor): output of rnn
hidden(Tensor): output hidden
'''
embedded = self.embedding(src)
embedded = self.transpose(embedded, (1, 0, 2))
embedded = self.cast(embedded, mstype.float16)
output, hidden = self.rnn(embedded)
hidden = self.fc(hidden)
hidden = self.tanh(hidden)
return output, hidden
class Decoder(nn.Cell):
'''
Decoder model
Args:
config: config of network
'''
def __init__(self, config, is_training=True):
super(Decoder, self).__init__()
self.hidden_size = config.hidden_size
self.vocab_size = config.trg_vocab_size
self.embedding_size = config.decoder_embedding_size
self.embedding = nn.Embedding(self.vocab_size, self.embedding_size)
self.rnn = GRU(config, is_training=is_training).to_float(mstype.float16)
self.text_len = config.max_length
self.shape = P.Shape()
self.transpose = P.Transpose()
self.p = P.Print()
self.cast = P.Cast()
self.concat = P.Concat(axis=2)
self.squeeze = P.Squeeze(axis=0)
self.expandims = P.ExpandDims()
self.log_softmax = P.LogSoftmax(axis=1)
weight, bias = dense_default_state(self.embedding_size+self.hidden_size*3, self.vocab_size)
self.fc = nn.Dense(self.embedding_size+self.hidden_size*3, self.vocab_size,
weight_init=weight, bias_init=bias).to_float(mstype.float16)
self.attention = Attention(config)
self.bmm = P.BatchMatMul()
self.dropout = nn.Dropout(0.7)
self.expandims = P.ExpandDims()
def construct(self, inputs, hidden, encoder_outputs):
'''
Decoder construction
Args:
inputs(Tensor): decoder input
hidden(Tensor): hidden state
encoder_outputs(Tensor): encoder output
Returns:
pred_prob(Tensor): decoder predicted probability
hidden(Tensor): hidden state
'''
embedded = self.embedding(inputs)
embedded = self.transpose(embedded, (1, 0, 2))
embedded = self.cast(embedded, mstype.float16)
attn = self.attention(hidden, encoder_outputs)
attn = self.expandims(attn, 1)
encoder_outputs = self.transpose(encoder_outputs, (1, 0, 2))
weight = self.bmm(attn, encoder_outputs)
weight = self.transpose(weight, (1, 0, 2))
emd_con = self.concat((embedded, weight))
output, hidden = self.rnn(emd_con)
out = self.concat((embedded, output, weight))
out = self.squeeze(out)
hidden = self.squeeze(hidden)
prediction = self.fc(out)
prediction = self.dropout(prediction)
prediction = self.cast(prediction, mstype.float32)
pred_prob = self.log_softmax(prediction)
pred_prob = self.expandims(pred_prob, 0)
return pred_prob, hidden
class Seq2Seq(nn.Cell):
'''
Seq2Seq model
Args:
config: config of network
'''
def __init__(self, config, is_training=True):
super(Seq2Seq, self).__init__()
if is_training:
self.batch_size = config.batch_size
else:
self.batch_size = config.eval_batch_size
self.encoder = Encoder(config, is_training=is_training)
self.decoder = Decoder(config, is_training=is_training)
self.expandims = P.ExpandDims()
self.dropout = nn.Dropout()
self.shape = P.Shape()
self.concat = P.Concat(axis=0)
self.argmax = P.ArgMaxWithValue(axis=1, keep_dims=True)
self.squeeze = P.Squeeze(axis=0)
self.sos = Tensor(np.ones((self.batch_size, 1))*2, mstype.int32)
self.select = P.Select()
self.text_len = config.max_length
def construct(self, encoder_inputs, decoder_inputs, teacher_force):
'''
Seq2Seq construction
Args:
encoder_inputs(Tensor): encoder input sentences
decoder_inputs(Tensor): decoder input sentences
teacher_force(Tensor): teacher force flag
Returns:
outputs(Tensor): concatenated predicted probabilities
'''
decoder_input = self.sos
encoder_output, hidden = self.encoder(encoder_inputs)
decoder_hidden = hidden
decoder_outputs = ()
for i in range(1, self.text_len):
decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_output)
decoder_outputs += (decoder_output,)
if self.training:
decoder_input_force = decoder_inputs[::, i:i+1]
decoder_input_top1, _ = self.argmax(self.squeeze(decoder_output))
decoder_input = self.select(teacher_force, decoder_input_force, decoder_input_top1)
else:
decoder_input, _ = self.argmax(self.squeeze(decoder_output))
outputs = self.concat(decoder_outputs)
return outputs

@@ -0,0 +1,155 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tokenization utilities."""
import sys
import collections
import unicodedata
def convert_to_printable(text):
"""
Converts `text` to a printable coding format.
"""
if sys.version_info[0] == 3:
if isinstance(text, str):
return text
if isinstance(text, bytes):
return text.decode("utf-8", "ignore")
raise ValueError("Only support type `str` or `bytes`, while text type is `%s`" % (type(text)))
raise ValueError("Only supported when running on Python3.")
def convert_to_unicode(text):
"""
Converts `text` to Unicode format.
"""
if sys.version_info[0] == 3:
if isinstance(text, str):
return text
if isinstance(text, bytes):
return text.decode("utf-8", "ignore")
raise ValueError("Only support type `str` or `bytes`, while text type is `%s`" % (type(text)))
if sys.version_info[0] == 2:
if isinstance(text, str):
return text.decode("utf-8", "ignore")
if isinstance(text, unicode):
return text
raise ValueError("Only support type `str` or `unicode`, while text type is `%s`" % (type(text)))
raise ValueError("Only supported when running on Python2 or Python3.")
def load_vocab_file(vocab_file):
"""
Loads a vocabulary file and turns into a {token:id} dictionary.
"""
vocab_dict = collections.OrderedDict()
index = 0
with open(vocab_file, "r") as vocab:
while True:
token = convert_to_unicode(vocab.readline())
if not token:
break
token = token.strip()
vocab_dict[token] = index
index += 1
return vocab_dict
def convert_by_vocab_dict(vocab_dict, items):
"""
Converts a sequence of [tokens|ids] according to the vocab dict.
"""
output = []
for item in items:
if item in vocab_dict:
output.append(vocab_dict[item])
else:
output.append(vocab_dict["<unk>"])
return output
class WhiteSpaceTokenizer():
"""
Whitespace tokenizer.
"""
def __init__(self, vocab_file):
self.vocab_dict = load_vocab_file(vocab_file)
self.inv_vocab_dict = {index: token for token, index in self.vocab_dict.items()}
def _is_whitespace_char(self, char):
"""
Checks if it is a whitespace character(regard "\t", "\n", "\r" as whitespace here).
"""
if char in (" ", "\t", "\n", "\r"):
return True
uni = unicodedata.category(char)
if uni == "Zs":
return True
return False
def _is_control_char(self, char):
"""
Checks if it is a control character.
"""
if char in ("\t", "\n", "\r"):
return False
uni = unicodedata.category(char)
if uni in ("Cc", "Cf"):
return True
return False
def _clean_text(self, text):
"""
Remove invalid characters and cleanup whitespace.
"""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or self._is_control_char(char):
continue
if self._is_whitespace_char(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
def _whitespace_tokenize(self, text):
"""
Clean whitespace and split text into tokens.
"""
text = text.strip()
text = text.lower()
if text.endswith("."):
text = text.replace(".", " .")
if not text:
tokens = []
else:
tokens = text.split()
return tokens
def tokenize(self, text):
"""
Tokenizes text.
"""
text = convert_to_unicode(text)
text = self._clean_text(text)
tokens = self._whitespace_tokenize(text)
return tokens
def convert_tokens_to_ids(self, tokens):
return convert_by_vocab_dict(self.vocab_dict, tokens)
def convert_ids_to_tokens(self, ids):
return convert_by_vocab_dict(self.inv_vocab_dict, ids)
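A round-trip sketch of the tokenizer with a throwaway vocabulary file (the module is imported as src.tokenization from the model root; inside src/ the scripts import it as plain tokenization):

from src.tokenization import WhiteSpaceTokenizer

with open("tiny_vocab.txt", "w") as f:
    f.write("\n".join(["<unk>", "<pad>", "<sos>", "<eos>", "a", "dog", "runs", "."]) + "\n")
tok = WhiteSpaceTokenizer(vocab_file="tiny_vocab.txt")
tokens = tok.tokenize("A dog runs.")        # lower-cased, trailing '.' split off
ids = tok.convert_tokens_to_ids(tokens)     # out-of-vocabulary words map to <unk>
print(tokens, ids, tok.convert_ids_to_tokens(ids))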

@@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""weight init"""
import math
import numpy as np
from mindspore import Tensor, Parameter
def gru_default_state(batch_size, input_size, hidden_size, num_layers=1, bidirectional=False):
'''Weight init for gru cell'''
stdv = 1 / math.sqrt(hidden_size)
weight_i = Parameter(Tensor(
np.random.uniform(-stdv, stdv, (input_size, 3*hidden_size)).astype(np.float32)), name='weight_i')
weight_h = Parameter(Tensor(
np.random.uniform(-stdv, stdv, (hidden_size, 3*hidden_size)).astype(np.float32)), name='weight_h')
bias_i = Parameter(Tensor(
np.random.uniform(-stdv, stdv, (3*hidden_size)).astype(np.float32)), name='bias_i')
bias_h = Parameter(Tensor(
np.random.uniform(-stdv, stdv, (3*hidden_size)).astype(np.float32)), name='bias_h')
init_h = Tensor(np.zeros((batch_size, hidden_size)).astype(np.float16))
return weight_i, weight_h, bias_i, bias_h, init_h
def dense_default_state(in_channel, out_channel):
'''Weight init for dense cell'''
stdv = 1 / math.sqrt(in_channel)
weight = Tensor(np.random.uniform(-stdv, stdv, (out_channel, in_channel)).astype(np.float32))
bias = Tensor(np.random.uniform(-stdv, stdv, (out_channel)).astype(np.float32))
return weight, bias
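A quick shape check for the encoder GRU parameters produced by gru_default_state (batch_size 16, input_size 256, hidden_size 512 match the training config):

from src.weight_init import gru_default_state

w_i, w_h, b_i, b_h, h0 = gru_default_state(batch_size=16, input_size=256, hidden_size=512)
print(w_i.shape, w_h.shape, b_i.shape, b_h.shape, h0.shape)
# expected: (256, 1536) (512, 1536) (1536,) (1536,) (16, 512)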

@@ -0,0 +1,130 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train script"""
import os
import time
import argparse
import ast
from mindspore.context import ParallelMode
from mindspore import context
from mindspore.communication.management import init
from mindspore.train.callback import Callback, CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore.train import Model
from mindspore.common import set_seed
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindspore.nn.optim import Adam
from src.config import config
from src.seq2seq import Seq2Seq
from src.gru_for_train import GRUWithLossCell, GRUTrainOneStepWithLossScaleCell
from src.dataset import create_gru_dataset
from src.lr_schedule import dynamic_lr
set_seed(1)
parser = argparse.ArgumentParser(description="GRU training")
parser.add_argument("--run_distribute", type=ast.literal_eval, default=False, help="Run distribute, default: false.")
parser.add_argument("--dataset_path", type=str, default=None, help="Dataset path")
parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained file path.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default: 0.")
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default: 1.")
parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default: 0.")
parser.add_argument('--ckpt_path', type=str, default='outputs/', help='Checkpoint save location. Default: outputs/')
parser.add_argument('--outputs_dir', type=str, default='./', help='Output directory for checkpoints. Default: ./')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id, save_graphs=False)
def get_ms_timestamp():
t = time.time()
return int(round(t * 1000))
time_stamp_init = False
time_stamp_first = 0
class LossCallBack(Callback):
"""
Monitor the loss in training.
If the loss is NAN or INF, training is terminated.
Note:
If per_print_times is 0, the loss is not printed.
Args:
per_print_times (int): Print the loss every per_print_times steps. Default: 1.
"""
def __init__(self, per_print_times=1, rank_id=0):
super(LossCallBack, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("print_step must be int and >= 0.")
self._per_print_times = per_print_times
self.rank_id = rank_id
global time_stamp_init, time_stamp_first
if not time_stamp_init:
time_stamp_first = get_ms_timestamp()
time_stamp_init = True
def step_end(self, run_context):
"""Monitor the loss in training."""
global time_stamp_first
time_stamp_current = get_ms_timestamp()
cb_params = run_context.original_args()
print("time: {}, epoch: {}, step: {}, outputs are {}".format(time_stamp_current - time_stamp_first,
cb_params.cur_epoch_num,
cb_params.cur_step_num,
str(cb_params.net_outputs)))
with open("./loss_{}.log".format(self.rank_id), "a+") as f:
f.write("time: {}, epoch: {}, step: {}, loss: {}, overflow: {}, loss_scale: {}".format(
time_stamp_current - time_stamp_first,
cb_params.cur_epoch_num,
cb_params.cur_step_num,
str(cb_params.net_outputs[0].asnumpy()),
str(cb_params.net_outputs[1].asnumpy()),
str(cb_params.net_outputs[2].asnumpy())))
f.write('\n')
if __name__ == '__main__':
if args.run_distribute:
rank = args.rank_id
device_num = args.device_num
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
init()
else:
rank = 0
device_num = 1
dataset = create_gru_dataset(epoch_count=config.num_epochs, batch_size=config.batch_size,
dataset_path=args.dataset_path, rank_size=device_num, rank_id=rank)
dataset_size = dataset.get_dataset_size()
print("dataset size is {}".format(dataset_size))
network = Seq2Seq(config)
network = GRUWithLossCell(network)
lr = dynamic_lr(config, dataset_size)
opt = Adam(network.trainable_params(), learning_rate=lr)
scale_manager = DynamicLossScaleManager(init_loss_scale=config.init_loss_scale_value,
scale_factor=config.scale_factor,
scale_window=config.scale_window)
update_cell = scale_manager.get_update_cell()
netwithgrads = GRUTrainOneStepWithLossScaleCell(network, opt, update_cell)
time_cb = TimeMonitor(data_size=dataset_size)
loss_cb = LossCallBack(rank_id=rank)
cb = [time_cb, loss_cb]
#Save Checkpoint
if config.save_checkpoint:
ckpt_config = CheckpointConfig(save_checkpoint_steps=config.ckpt_epoch*dataset_size,
keep_checkpoint_max=config.keep_checkpoint_max)
save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_'+str(args.rank_id)+'/')
ckpt_cb = ModelCheckpoint(config=ckpt_config,
directory=save_ckpt_path,
prefix='{}'.format(args.rank_id))
cb += [ckpt_cb]
netwithgrads.set_train(True)
model = Model(netwithgrads)
model.train(config.num_epochs, dataset, callbacks=cb, dataset_sink_mode=True)