commit 0beb38fdb0
@ -0,0 +1,83 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GRU evaluation script."""

import argparse
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore import context
from src.dataset import create_gru_dataset
from src.seq2seq import Seq2Seq
from src.gru_for_infer import GRUInferCell
from src.config import config

def run_gru_eval():
    """
    GRU evaluation.
    """
    parser = argparse.ArgumentParser(description='GRU eval')
    parser.add_argument("--device_target", type=str, default="Ascend",
                        help="device where the code will be implemented, default is Ascend")
    parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend, default is 0')
    parser.add_argument('--device_num', type=int, default=1, help='Use device nums, default is 1')
    parser.add_argument('--ckpt_file', type=str, default="", help='ckpt file path')
    parser.add_argument("--dataset_path", type=str, default="",
                        help="Dataset path.")
    args = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, reserve_class_name_in_scope=False, \
        device_id=args.device_id, save_graphs=False)
    dataset = create_gru_dataset(epoch_count=config.num_epochs, batch_size=config.eval_batch_size, \
        dataset_path=args.dataset_path, rank_size=args.device_num, rank_id=0, do_shuffle=False, is_training=False)
    dataset_size = dataset.get_dataset_size()
    print("dataset size is {}".format(dataset_size))
    network = Seq2Seq(config, is_training=False)
    network = GRUInferCell(network)
    network.set_train(False)
    if args.ckpt_file != "":
        parameter_dict = load_checkpoint(args.ckpt_file)
        load_param_into_net(network, parameter_dict)
    model = Model(network)

    predictions = []
    source_sents = []
    target_sents = []
    eval_text_len = 0
    for batch in dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
        source_sents.append(batch["source_ids"])
        target_sents.append(batch["target_ids"])
        source_ids = Tensor(batch["source_ids"], mstype.int32)
        target_ids = Tensor(batch["target_ids"], mstype.int32)
        predicted_ids = model.predict(source_ids, target_ids)
        print("predicts is ", predicted_ids.asnumpy())
        print("target_ids is ", target_ids)
        predictions.append(predicted_ids.asnumpy())
        eval_text_len = eval_text_len + 1

    f_output = open(config.output_file, 'w')
    f_target = open(config.target_file, "w")
    for batch_out, true_sentence in zip(predictions, target_sents):
        for i in range(config.eval_batch_size):
            target_ids = [str(x) for x in true_sentence[i].tolist()]
            f_target.write(" ".join(target_ids) + "\n")
            token_ids = [str(x) for x in batch_out[i].tolist()]
            f_output.write(" ".join(token_ids) + "\n")
    f_output.close()
    f_target.close()

if __name__ == "__main__":
    run_gru_eval()
@ -0,0 +1,48 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "sh create_dataset.sh DATASET_PATH OUTPUT_PATH"
|
||||
echo "for example: sh create_dataset.sh /path/multi30k/ /path/multi30k/mindrecord/"
|
||||
echo "DATASET_NAME including ag, dbpedia, and yelp_p"
|
||||
echo "It is better to use absolute path."
|
||||
echo "=============================================================================================================="
|
||||
ulimit -u unlimited
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
DATASET_PATH=$(get_real_path $1)
|
||||
echo $DATASET_PATH
|
||||
if [ ! -d $DATASET_PATH ]
|
||||
then
|
||||
echo "error: DATASET_PATH=$DATASET_PATH is not valid"
|
||||
exit 1
|
||||
fi
|
||||
OUTPUT_PATH=$(get_real_path $2)
|
||||
echo $OUTPUT_PATH
|
||||
if [ ! -d $OUTPUT_PATH ]
|
||||
then
|
||||
echo "error: OUTPUT_PATH=$OUTPUT_PATH is not valid"
|
||||
exit 1
|
||||
fi
|
||||
paste $DATASET_PATH/train.de.tok $DATASET_PATH/train.en.tok > $DATASET_PATH/train.all
|
||||
python ../src/create_data.py --input_file $DATASET_PATH/train.all --num_splits 8 --src_vocab_file $DATASET_PATH/vocab.de --trg_vocab_file $DATASET_PATH/vocab.en --output_file $OUTPUT_PATH/multi30k_train_mindrecord --max_seq_length 32 --bucket [32]
|
||||
paste $DATASET_PATH/test.de.tok $DATASET_PATH/test.en.tok > $DATASET_PATH/test.all
|
||||
python ../src/create_data.py --input_file $DATASET_PATH/test.all --num_splits 1 --src_vocab_file $DATASET_PATH/vocab.de --trg_vocab_file $DATASET_PATH/vocab.en --output_file $OUTPUT_PATH/multi30k_test_mindrecord --max_seq_length 32 --bucket [32]
|
@ -0,0 +1,33 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "sh process_output.sh REF_DATA EVAL_OUTPUT VOCAB_FILE"
|
||||
echo "for example: sh parse_output.sh target.txt output.txt vocab.en"
|
||||
echo "It is better to use absolute path."
|
||||
echo "=============================================================================================================="
|
||||
ref_data=$1
|
||||
eval_output=$2
|
||||
vocab_file=$3
|
||||
|
||||
cat $ref_data \
|
||||
| python ../src/parse_output.py --vocab_file $vocab_file \
|
||||
| sed 's/@@ //g' > ${ref_data}.forbleu
|
||||
|
||||
cat $eval_output \
|
||||
| python ../src/parse_output.py --vocab_file $vocab_file \
|
||||
| sed 's/@@ //g' > ${eval_output}.forbleu
|
@ -0,0 +1,17 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
DATASET_DIR=$1
|
||||
python ../src/preprocess.py --dataset_path=$DATASET_DIR
|
@ -0,0 +1,68 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# -ne 2 ]
|
||||
then
|
||||
echo "Usage: sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [DATASET_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
echo $PATH1
|
||||
|
||||
if [ ! -f $PATH1 ]
|
||||
then
|
||||
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
DATASET_PATH=$(get_real_path $2)
|
||||
echo $DATASET_PATH
|
||||
|
||||
if [ ! -f $DATASET_PATH ]
|
||||
then
|
||||
echo "error: DATASET_PATH=$DATASET_PATH is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=8
|
||||
export RANK_SIZE=8
|
||||
export RANK_TABLE_FILE=$PATH1
|
||||
|
||||
for((i=0; i<${DEVICE_NUM}; i++))
|
||||
do
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$i
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp ../*.py ./train_parallel$i
|
||||
cp *.sh ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
cd ./train_parallel$i || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
env > env.log
|
||||
python train.py --device_id=$i --rank_id=$i --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$DATASET_PATH &> log &
|
||||
cd ..
|
||||
done
|
@ -0,0 +1,58 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
if [ $# -ne 2 ]
|
||||
then
|
||||
echo "Usage: sh run_eval.sh [CKPT_FILE] [DATASET_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=0
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
CKPT_FILE=$(get_real_path $1)
|
||||
echo $CKPT_FILE
|
||||
if [ ! -f $CKPT_FILE ]
|
||||
then
|
||||
echo "error: CKPT_FILE=$CKPT_FILE is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
DATASET_PATH=$(get_real_path $2)
|
||||
echo $DATASET_PATH
|
||||
if [ ! -f $DATASET_PATH ]
|
||||
then
|
||||
echo "error: DATASET_PATH=$DATASET_PATH is not a file"
|
||||
exit 1
|
||||
fi
|
||||
rm -rf ./eval
|
||||
mkdir ./eval
|
||||
cp ../*.py ./eval
|
||||
cp *.sh ./eval
|
||||
cp -r ../src ./eval
|
||||
cd ./eval || exit
|
||||
echo "start eval for device $DEVICE_ID"
|
||||
env > env.log
|
||||
python eval.py --ckpt_file=$CKPT_FILE --dataset_path=$DATASET_PATH &> log &
|
||||
cd ..
|
@ -0,0 +1,51 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
if [ $# -ne 1 ]
|
||||
then
|
||||
echo "Usage: sh run_distribute_train_ascend.sh [DATASET_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=4
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
DATASET_PATH=$(get_real_path $1)
|
||||
echo $DATASET_PATH
|
||||
if [ ! -f $DATASET_PATH ]
|
||||
then
|
||||
echo "error: DATASET_PATH=$DATASET_PATH is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
rm -rf ./train
|
||||
mkdir ./train
|
||||
cp ../*.py ./train
|
||||
cp *.sh ./train
|
||||
cp -r ../src ./train
|
||||
cd ./train || exit
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env > env.log
|
||||
python train.py --device_id=$DEVICE_ID --dataset_path=$DATASET_PATH &> log &
|
||||
cd ..
|
@ -0,0 +1,41 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""GRU config"""
|
||||
from easydict import EasyDict
|
||||
|
||||
config = EasyDict({
|
||||
"batch_size": 16,
|
||||
"eval_batch_size": 1,
|
||||
"src_vocab_size": 8154,
|
||||
"trg_vocab_size": 6113,
|
||||
"encoder_embedding_size": 256,
|
||||
"decoder_embedding_size": 256,
|
||||
"hidden_size": 512,
|
||||
"max_length": 32,
|
||||
"num_epochs": 30,
|
||||
"save_checkpoint": True,
|
||||
"ckpt_epoch": 10,
|
||||
"target_file": "target.txt",
|
||||
"output_file": "output.txt",
|
||||
"keep_checkpoint_max": 30,
|
||||
"base_lr": 0.001,
|
||||
"warmup_step": 300,
|
||||
"momentum": 0.9,
|
||||
"init_loss_scale_value": 1024,
|
||||
'scale_factor': 2,
|
||||
'scale_window': 2000,
|
||||
"warmup_ratio": 1/3.0,
|
||||
"teacher_force_ratio": 0.5
|
||||
})
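# Illustrative sketch, not part of the original commit: EasyDict exposes the keys
# above as attributes, which is how train.py and eval.py read them
# (e.g. config.batch_size, config.hidden_size, config.teacher_force_ratio).
if __name__ == "__main__":
    print(config.batch_size, config.hidden_size, config.teacher_force_ratio)
    assert config.batch_size == config["batch_size"]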
|
@ -0,0 +1,196 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Create training instances for Transformer."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import collections
|
||||
import logging
|
||||
import numpy as np
|
||||
import tokenization
|
||||
from mindspore.mindrecord import FileWriter
|
||||
|
||||
class SampleInstance():
|
||||
"""A single sample instance (sentence pair)."""
|
||||
|
||||
def __init__(self, source_tokens, target_tokens):
|
||||
self.source_tokens = source_tokens
|
||||
self.target_tokens = target_tokens
|
||||
|
||||
def __str__(self):
|
||||
s = ""
|
||||
s += "source_tokens: %s\n" % (" ".join(
|
||||
[tokenization.convert_to_printable(x) for x in self.source_tokens]))
|
||||
s += "target tokens: %s\n" % (" ".join(
|
||||
[tokenization.convert_to_printable(x) for x in self.target_tokens]))
|
||||
s += "\n"
|
||||
return s
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
def get_instance_features(instance, tokenizer_src, tokenizer_trg, max_seq_length, bucket):
|
||||
"""Get features from `SampleInstance`s."""
|
||||
def _find_bucket_length(source_tokens, target_tokens):
|
||||
source_ids = tokenizer_src.convert_tokens_to_ids(source_tokens)
|
||||
target_ids = tokenizer_trg.convert_tokens_to_ids(target_tokens)
|
||||
num = max(len(source_ids), len(target_ids))
|
||||
assert num <= bucket[-1]
|
||||
for index in range(1, len(bucket)):
|
||||
if bucket[index - 1] < num <= bucket[index]:
|
||||
return bucket[index]
|
||||
return bucket[0]
|
||||
|
||||
def _convert_ids_and_mask(tokenizer, input_tokens, seq_max_bucket_length):
|
||||
input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
|
||||
input_mask = [1] * len(input_ids)
|
||||
assert len(input_ids) <= max_seq_length
|
||||
|
||||
while len(input_ids) < seq_max_bucket_length:
|
||||
input_ids.append(1)
|
||||
input_mask.append(0)
|
||||
|
||||
assert len(input_ids) == seq_max_bucket_length
|
||||
assert len(input_mask) == seq_max_bucket_length
|
||||
|
||||
return input_ids, input_mask
|
||||
|
||||
seq_max_bucket_length = _find_bucket_length(instance.source_tokens, instance.target_tokens)
|
||||
source_ids, source_mask = _convert_ids_and_mask(tokenizer_src, instance.source_tokens, seq_max_bucket_length)
|
||||
target_ids, target_mask = _convert_ids_and_mask(tokenizer_trg, instance.target_tokens, seq_max_bucket_length)
|
||||
|
||||
features = collections.OrderedDict()
|
||||
features["source_ids"] = np.asarray(source_ids)
|
||||
features["source_mask"] = np.asarray(source_mask)
|
||||
features["target_ids"] = np.asarray(target_ids)
|
||||
features["target_mask"] = np.asarray(target_mask)
|
||||
|
||||
return features, seq_max_bucket_length
|
||||
|
||||
def create_training_instance(source_words, target_words, max_seq_length, clip_to_max_len):
|
||||
"""Creates `SampleInstance`s for a single sentence pair."""
|
||||
EOS = "<eos>"
|
||||
SOS = "<sos>"
|
||||
|
||||
if len(source_words) >= max_seq_length-1 or len(target_words) >= max_seq_length-1:
|
||||
if clip_to_max_len:
|
||||
source_words = source_words[:min(len(source_words), max_seq_length-2)]
|
||||
target_words = target_words[:min(len(target_words), max_seq_length-2)]
|
||||
else:
|
||||
return None
|
||||
source_tokens = [SOS] + source_words + [EOS]
|
||||
target_tokens = [SOS] + target_words + [EOS]
|
||||
instance = SampleInstance(
|
||||
source_tokens=source_tokens,
|
||||
target_tokens=target_tokens)
|
||||
return instance
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--input_file", type=str, required=True,
|
||||
help='Input raw text file (or comma-separated list of files).')
|
||||
parser.add_argument("--output_file", type=str, required=True, help='Output MindRecord file.')
|
||||
parser.add_argument("--num_splits", type=int, default=16,
|
||||
help='The number of partitions that the MindRecord file will be split into.')
|
||||
parser.add_argument("--src_vocab_file", type=str, required=True,
|
||||
help='The source vocabulary file that the model is trained on.')
|
||||
parser.add_argument("--trg_vocab_file", type=str, required=True,
|
||||
help='The target vocabulary file that the model is trained on.')
|
||||
parser.add_argument("--clip_to_max_len", type=ast.literal_eval, default=False,
|
||||
help='clip sequences to maximum sequence length.')
|
||||
parser.add_argument("--max_seq_length", type=int, default=32, help='Maximum sequence length.')
|
||||
parser.add_argument("--bucket", type=ast.literal_eval, default=[32],
|
||||
help='bucket sequence length')
|
||||
args = parser.parse_args()
|
||||
tokenizer_src = tokenization.WhiteSpaceTokenizer(vocab_file=args.src_vocab_file)
|
||||
tokenizer_trg = tokenization.WhiteSpaceTokenizer(vocab_file=args.trg_vocab_file)
|
||||
input_files = []
|
||||
for input_pattern in args.input_file.split(","):
|
||||
input_files.append(input_pattern)
|
||||
logging.info("*** Read from input files ***")
|
||||
output_file = args.output_file
|
||||
logging.info("*** Write to output files ***")
|
||||
logging.info(" %s", output_file)
|
||||
total_written = 0
|
||||
total_read = 0
|
||||
feature_dict = {}
|
||||
for i in args.bucket:
|
||||
feature_dict[i] = []
|
||||
for input_file in input_files:
|
||||
logging.info("*** Reading from %s ***", input_file)
|
||||
with open(input_file, "r") as reader:
|
||||
while True:
|
||||
line = tokenization.convert_to_unicode(reader.readline())
|
||||
if not line:
|
||||
break
|
||||
total_read += 1
|
||||
if total_read % 100000 == 0:
|
||||
logging.info("Read %d ...", total_read)
|
||||
if line.strip() == "":
|
||||
continue
|
||||
source_line, target_line = line.strip().split("\t")
|
||||
source_tokens = tokenizer_src.tokenize(source_line)
|
||||
target_tokens = tokenizer_trg.tokenize(target_line)
|
||||
if len(source_tokens) >= args.max_seq_length or len(target_tokens) >= args.max_seq_length:
|
||||
logging.info("ignore long sentence!")
|
||||
continue
|
||||
instance = create_training_instance(source_tokens, target_tokens, args.max_seq_length,
|
||||
clip_to_max_len=args.clip_to_max_len)
|
||||
if instance is None:
|
||||
continue
|
||||
features, seq_max_bucket_length = get_instance_features(instance, tokenizer_src, tokenizer_trg,
|
||||
args.max_seq_length, args.bucket)
|
||||
for key in feature_dict:
|
||||
if key == seq_max_bucket_length:
|
||||
feature_dict[key].append(features)
|
||||
if total_read <= 10:
|
||||
logging.info("*** Example ***")
|
||||
logging.info("source tokens: %s", " ".join(
|
||||
[tokenization.convert_to_printable(x) for x in instance.source_tokens]))
|
||||
logging.info("target tokens: %s", " ".join(
|
||||
[tokenization.convert_to_printable(x) for x in instance.target_tokens]))
|
||||
|
||||
for feature_name in features.keys():
|
||||
feature = features[feature_name]
|
||||
logging.info("%s: %s", feature_name, feature)
|
||||
for i in args.bucket:
|
||||
if args.num_splits == 1:
|
||||
output_file_name = output_file + '_' + str(i)
|
||||
else:
|
||||
output_file_name = output_file + '_' + str(i) + '_'
|
||||
writer = FileWriter(output_file_name, args.num_splits)
|
||||
data_schema = {"source_ids": {"type": "int64", "shape": [-1]},
|
||||
"source_mask": {"type": "int64", "shape": [-1]},
|
||||
"target_ids": {"type": "int64", "shape": [-1]},
|
||||
"target_mask": {"type": "int64", "shape": [-1]}
|
||||
}
|
||||
writer.add_schema(data_schema, "gru")
|
||||
features_ = feature_dict[i]
|
||||
logging.info("Bucket length %d has %d samples, start writing...", i, len(features_))
|
||||
for item in features_:
|
||||
writer.write_raw_data([item])
|
||||
total_written += 1
|
||||
writer.commit()
|
||||
logging.info("Wrote %d total instances", total_written)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
main()
|
@ -0,0 +1,48 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Data operations, will be used in train.py."""
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset as de
|
||||
import mindspore.dataset.transforms.c_transforms as deC
|
||||
from src.config import config
|
||||
import numpy as np
|
||||
de.config.set_seed(1)
|
||||
|
||||
def random_teacher_force(source_ids, target_ids, target_mask):
    """Randomly enable teacher forcing for a sample with probability config.teacher_force_ratio."""
    teacher_force = np.random.random() < config.teacher_force_ratio
    teacher_force_array = np.array([teacher_force], dtype=bool)
    return source_ids, target_ids, teacher_force_array
|
||||
|
||||
def create_gru_dataset(epoch_count=1, batch_size=1, rank_size=1, rank_id=0, do_shuffle=True, dataset_path=None,
|
||||
is_training=True):
|
||||
"""create dataset"""
|
||||
ds = de.MindDataset(dataset_path,
|
||||
columns_list=["source_ids", "target_ids",
|
||||
"target_mask"],
|
||||
shuffle=do_shuffle, num_parallel_workers=10, num_shards=rank_size, shard_id=rank_id)
|
||||
operations = random_teacher_force
|
||||
ds = ds.map(operations=operations, input_columns=["source_ids", "target_ids", "target_mask"],
|
||||
output_columns=["source_ids", "target_ids", "teacher_force"],
|
||||
column_order=["source_ids", "target_ids", "teacher_force"])
|
||||
type_cast_op = deC.TypeCast(mstype.int32)
|
||||
type_cast_op_bool = deC.TypeCast(mstype.bool_)
|
||||
ds = ds.map(operations=type_cast_op, input_columns="source_ids")
|
||||
ds = ds.map(operations=type_cast_op, input_columns="target_ids")
|
||||
ds = ds.map(operations=type_cast_op_bool, input_columns="teacher_force")
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
ds = ds.repeat(1)
|
||||
return ds
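# Illustrative usage sketch, not part of the original commit. The MindRecord path
# below is a placeholder for a file produced by src/create_data.py; every batch
# yields "source_ids", "target_ids" and the per-sample "teacher_force" flag.
if __name__ == "__main__":
    demo_ds = create_gru_dataset(batch_size=config.batch_size,
                                 dataset_path="/path/to/multi30k_train_mindrecord_32",
                                 is_training=True)
    for item in demo_ds.create_dict_iterator(output_numpy=True, num_epochs=1):
        print(item["source_ids"].shape, item["target_ids"].shape, item["teacher_force"].shape)
        break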
|
@ -0,0 +1,104 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""GRU cell"""
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as P
|
||||
import mindspore.common.dtype as mstype
|
||||
from src.weight_init import gru_default_state
|
||||
|
||||
class BidirectionGRU(nn.Cell):
|
||||
'''
|
||||
BidirectionGRU model
|
||||
|
||||
Args:
|
||||
config: config of network
|
||||
'''
|
||||
def __init__(self, config, is_training=True):
|
||||
super(BidirectionGRU, self).__init__()
|
||||
if is_training:
|
||||
self.batch_size = config.batch_size
|
||||
else:
|
||||
self.batch_size = config.eval_batch_size
|
||||
self.embedding_size = config.encoder_embedding_size
|
||||
self.hidden_size = config.hidden_size
|
||||
self.weight_i, self.weight_h, self.bias_i, self.bias_h, self.init_h = gru_default_state(self.batch_size,
|
||||
self.embedding_size,
|
||||
self.hidden_size)
|
||||
self.weight_bw_i, self.weight_bw_h, self.bias_bw_i, self.bias_bw_h, self.init_bw_h = \
|
||||
gru_default_state(self.batch_size, self.embedding_size, self.hidden_size)
|
||||
self.reverse = P.ReverseV2(axis=[1])
|
||||
self.concat = P.Concat(axis=2)
|
||||
self.squeeze = P.Squeeze(axis=0)
|
||||
self.rnn = P.DynamicGRUV2()
|
||||
self.text_len = config.max_length
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, x):
|
||||
'''
|
||||
BidirectionGRU construction
|
||||
|
||||
Args:
|
||||
x(Tensor): BidirectionGRU input
|
||||
|
||||
Returns:
|
||||
output(Tensor): rnn output
|
||||
hidden(Tensor): hidden state
|
||||
'''
|
||||
x = self.cast(x, mstype.float16)
|
||||
y1, _, _, _, _, _ = self.rnn(x, self.weight_i, self.weight_h, self.bias_i, self.bias_h, None, self.init_h)
|
||||
bw_x = self.reverse(x)
|
||||
y1_bw, _, _, _, _, _ = self.rnn(bw_x, self.weight_bw_i,
|
||||
self.weight_bw_h, self.bias_bw_i, self.bias_bw_h, None, self.init_bw_h)
|
||||
y1_bw = self.reverse(y1_bw)
|
||||
output = self.concat((y1, y1_bw))
|
||||
hidden = self.concat((y1[self.text_len-1:self.text_len:1, ::, ::],
|
||||
y1_bw[self.text_len-1:self.text_len:1, ::, ::]))
|
||||
hidden = self.squeeze(hidden)
|
||||
return output, hidden
|
||||
|
||||
class GRU(nn.Cell):
|
||||
'''
|
||||
GRU model
|
||||
|
||||
Args:
|
||||
config: config of network
|
||||
'''
|
||||
def __init__(self, config, is_training=True):
|
||||
super(GRU, self).__init__()
|
||||
if is_training:
|
||||
self.batch_size = config.batch_size
|
||||
else:
|
||||
self.batch_size = config.eval_batch_size
|
||||
self.embedding_size = config.encoder_embedding_size
|
||||
self.hidden_size = config.hidden_size
|
||||
self.weight_i, self.weight_h, self.bias_i, self.bias_h, self.init_h = \
|
||||
gru_default_state(self.batch_size, self.embedding_size + self.hidden_size*2, self.hidden_size)
|
||||
self.rnn = P.DynamicGRUV2()
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, x):
|
||||
'''
|
||||
GRU construction
|
||||
|
||||
Args:
|
||||
x(Tensor): GRU input
|
||||
|
||||
Returns:
|
||||
output(Tensor): rnn output
|
||||
hidden(Tensor): hidden state
|
||||
'''
|
||||
x = self.cast(x, mstype.float16)
|
||||
y1, h1, _, _, _, _ = self.rnn(x, self.weight_i, self.weight_h, self.bias_i, self.bias_h, None, self.init_h)
|
||||
return y1, h1
|
@ -0,0 +1,42 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""GRU Infer cell"""
|
||||
import numpy as np
|
||||
from mindspore import Tensor
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as P
|
||||
import mindspore.common.dtype as mstype
|
||||
from src.config import config
|
||||
|
||||
class GRUInferCell(nn.Cell):
|
||||
'''
|
||||
GRU inference construction
|
||||
|
||||
Args:
|
||||
network: gru network
|
||||
'''
|
||||
def __init__(self, network):
|
||||
super(GRUInferCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.argmax = P.ArgMaxWithValue(axis=2)
|
||||
self.transpose = P.Transpose()
|
||||
self.teacher_force = Tensor(np.zeros((config.eval_batch_size)), mstype.bool_)
|
||||
def construct(self,
|
||||
encoder_inputs,
|
||||
decoder_inputs):
|
||||
predict_probs = self.network(encoder_inputs, decoder_inputs, self.teacher_force)
|
||||
predict_probs = self.transpose(predict_probs, (1, 0, 2))
|
||||
predict_ids, _ = self.argmax(predict_probs)
|
||||
return predict_ids
|
@ -0,0 +1,243 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""GRU train cell"""
|
||||
from mindspore import Tensor, Parameter, ParameterTuple, context
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
|
||||
from src.config import config
|
||||
from src.loss import NLLLoss
|
||||
|
||||
class GRUWithLossCell(nn.Cell):
|
||||
"""
|
||||
GRU network connect with loss function.
|
||||
|
||||
Args:
|
||||
network: The training network.
|
||||
|
||||
Returns:
|
||||
the output of loss function.
|
||||
"""
|
||||
def __init__(self, network):
|
||||
super(GRUWithLossCell, self).__init__()
|
||||
self.network = network
|
||||
self.loss = NLLLoss()
|
||||
self.logits_shape = (-1, config.src_vocab_size)
|
||||
self.reshape = P.Reshape()
|
||||
self.cast = P.Cast()
|
||||
self.mean = P.ReduceMean()
|
||||
self.text_len = config.max_length
|
||||
self.split = P.Split(axis=0, output_num=config.max_length-1)
|
||||
self.squeeze = P.Squeeze()
|
||||
self.add = P.AddN()
|
||||
self.transpose = P.Transpose()
|
||||
self.shape = P.Shape()
|
||||
def construct(self, encoder_inputs, decoder_inputs, teacher_force):
|
||||
'''
|
||||
GRU loss cell
|
||||
|
||||
Args:
|
||||
encoder_inputs(Tensor): encoder inputs
|
||||
decoder_inputs(Tensor): decoder inputs
|
||||
teacher_force(Tensor): teacher force flag
|
||||
|
||||
Returns:
|
||||
loss(scalar): loss output
|
||||
'''
|
||||
logits = self.network(encoder_inputs, decoder_inputs, teacher_force)
|
||||
logits = self.cast(logits, mstype.float32)
|
||||
loss_total = ()
|
||||
decoder_targets = decoder_inputs
|
||||
decoder_output = logits
|
||||
for i in range(1, self.text_len):
|
||||
loss = self.loss(self.squeeze(decoder_output[i-1:i:1, ::, ::]), decoder_targets[:, i])
|
||||
loss_total += (loss,)
|
||||
loss = self.add(loss_total) / self.text_len
|
||||
return loss
|
||||
|
||||
GRADIENT_CLIP_TYPE = 1
|
||||
GRADIENT_CLIP_VALUE = 1.0
|
||||
class ClipGradients(nn.Cell):
|
||||
"""
|
||||
Clip gradients.
|
||||
|
||||
Args:
|
||||
grads (list): List of gradient tuples.
|
||||
clip_type (Tensor): The way to clip, 'value' or 'norm'.
|
||||
clip_value (Tensor): Specifies how much to clip.
|
||||
|
||||
Returns:
|
||||
List, a list of clipped_grad tuples.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(ClipGradients, self).__init__()
|
||||
self.clip_by_norm = nn.ClipByNorm()
|
||||
self.cast = P.Cast()
|
||||
self.dtype = P.DType()
|
||||
def construct(self,
|
||||
grads,
|
||||
clip_type,
|
||||
clip_value):
|
||||
"""Defines the gradients clip."""
|
||||
if clip_type not in (0, 1):
|
||||
return grads
|
||||
new_grads = ()
|
||||
for grad in grads:
|
||||
dt = self.dtype(grad)
|
||||
if clip_type == 0:
|
||||
t = C.clip_by_value(grad, self.cast(F.tuple_to_array((-clip_value,)), dt),
|
||||
self.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
else:
|
||||
t = self.clip_by_norm(grad, self.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
t = self.cast(t, dt)
|
||||
new_grads = new_grads + (t,)
|
||||
return new_grads
|
||||
|
||||
grad_scale = C.MultitypeFuncGraph("grad_scale")
|
||||
reciprocal = P.Reciprocal()
|
||||
|
||||
@grad_scale.register("Tensor", "Tensor")
|
||||
def tensor_grad_scale(scale, grad):
|
||||
return grad * F.cast(reciprocal(scale), F.dtype(grad))
|
||||
|
||||
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
|
||||
grad_overflow = P.FloatStatus()
|
||||
|
||||
@_grad_overflow.register("Tensor")
|
||||
def _tensor_grad_overflow(grad):
|
||||
return grad_overflow(grad)
|
||||
|
||||
class GRUTrainOneStepWithLossScaleCell(nn.Cell):
|
||||
"""
|
||||
Encapsulation class of GRU network training.
|
||||
|
||||
Append an optimizer to the training network. After that, the construct
function can be called to create the backward graph.
|
||||
|
||||
Args:
|
||||
network (Cell): The training network. Note that loss function should have been added.
|
||||
optimizer (Optimizer): Optimizer for updating the weights.
|
||||
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
|
||||
"""
|
||||
def __init__(self, network, optimizer, scale_update_cell=None):
|
||||
super(GRUTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.set_grad()
|
||||
self.network.add_flags(defer_inline=True)
|
||||
self.weights = ParameterTuple(network.trainable_params())
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation(get_by_list=True,
|
||||
sens_param=True)
|
||||
self.reducer_flag = False
|
||||
self.allreduce = P.AllReduce()
|
||||
|
||||
self.parallel_mode = _get_parallel_mode()
|
||||
if self.parallel_mode not in ParallelMode.MODE_LIST:
|
||||
raise ValueError("Parallel mode does not support: ", self.parallel_mode)
|
||||
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
|
||||
self.reducer_flag = True
|
||||
self.grad_reducer = None
|
||||
if self.reducer_flag:
|
||||
mean = _get_gradients_mean()
|
||||
degree = _get_device_num()
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
|
||||
self.clip_gradients = ClipGradients()
|
||||
self.cast = P.Cast()
|
||||
if context.get_context("device_target") == "GPU":
|
||||
self.gpu_target = True
|
||||
self.float_status = P.FloatStatus()
|
||||
self.addn = P.AddN()
|
||||
self.reshape = P.Reshape()
|
||||
else:
|
||||
self.gpu_target = False
|
||||
self.alloc_status = P.NPUAllocFloatStatus()
|
||||
self.get_status = P.NPUGetFloatStatus()
|
||||
self.clear_before_grad = P.NPUClearFloatStatus()
|
||||
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||
self.depend_parameter_use = P.ControlDepend(depend_mode=1)
|
||||
self.base = Tensor(1, mstype.float32)
|
||||
self.less_equal = P.LessEqual()
|
||||
self.hyper_map = C.HyperMap()
|
||||
|
||||
self.loss_scale = None
|
||||
self.loss_scaling_manager = scale_update_cell
|
||||
if scale_update_cell:
|
||||
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
|
||||
|
||||
@C.add_flags(has_effect=True)
|
||||
def construct(self,
|
||||
encoder_inputs,
|
||||
decoder_inputs,
|
||||
teacher_force,
|
||||
sens=None):
|
||||
"""Defines the computation performed."""
|
||||
|
||||
weights = self.weights
|
||||
loss = self.network(encoder_inputs,
|
||||
decoder_inputs,
|
||||
teacher_force)
|
||||
init = False
|
||||
if not self.gpu_target:
|
||||
# alloc status
|
||||
init = self.alloc_status()
|
||||
# clear overflow buffer
|
||||
self.clear_before_grad(init)
|
||||
if sens is None:
|
||||
scaling_sens = self.loss_scale
|
||||
else:
|
||||
scaling_sens = sens
|
||||
grads = self.grad(self.network, weights)(encoder_inputs,
|
||||
decoder_inputs,
|
||||
teacher_force,
|
||||
self.cast(scaling_sens,
|
||||
mstype.float32))
|
||||
|
||||
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
|
||||
grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
|
||||
if self.reducer_flag:
|
||||
# apply grad reducer on grads
|
||||
grads = self.grad_reducer(grads)
|
||||
|
||||
if not self.gpu_target:
|
||||
self.get_status(init)
|
||||
# sum overflow buffer elements, 0: not overflow, >0: overflow
|
||||
flag_sum = self.reduce_sum(init, (0,))
|
||||
else:
|
||||
flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
|
||||
flag_sum = self.addn(flag_sum)
|
||||
# convert flag_sum to scalar
|
||||
flag_sum = self.reshape(flag_sum, (()))
|
||||
|
||||
if self.is_distributed:
|
||||
# sum overflow flag over devices
|
||||
flag_reduce = self.allreduce(flag_sum)
|
||||
cond = self.less_equal(self.base, flag_reduce)
|
||||
else:
|
||||
cond = self.less_equal(self.base, flag_sum)
|
||||
overflow = cond
|
||||
if sens is None:
|
||||
overflow = self.loss_scaling_manager(self.loss_scale, cond)
|
||||
if overflow:
|
||||
succ = False
|
||||
else:
|
||||
succ = self.optimizer(grads)
|
||||
ret = (loss, cond, scaling_sens)
|
||||
return F.depend(ret, succ)
|
@ -0,0 +1,32 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""NLLLoss cell"""
|
||||
import mindspore.ops.operations as P
|
||||
from mindspore.nn.loss.loss import _Loss
|
||||
from mindspore.ops import functional as F
|
||||
|
||||
class NLLLoss(_Loss):
|
||||
'''
|
||||
NLLLoss function
|
||||
'''
|
||||
def __init__(self, reduction='mean'):
|
||||
super(NLLLoss, self).__init__(reduction)
|
||||
self.one_hot = P.OneHot()
|
||||
self.reduce_sum = P.ReduceSum()
|
||||
|
||||
def construct(self, logits, label):
|
||||
label_one_hot = self.one_hot(label, F.shape(logits)[-1], F.scalar_to_array(1.0), F.scalar_to_array(0.0))
|
||||
loss = self.reduce_sum(-1.0 * logits * label_one_hot, (1,))
|
||||
return self.get_loss(loss)
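# Minimal usage sketch, not part of the original commit: the decoder in
# src/seq2seq.py applies LogSoftmax before this loss, so `logits` are expected to
# be log-probabilities of shape (batch, vocab_size) and `label` holds class ids.
if __name__ == "__main__":
    import numpy as np
    from mindspore import Tensor

    log_probs = Tensor(np.log(np.full((2, 4), 0.25, dtype=np.float32)))
    labels = Tensor(np.array([1, 3], dtype=np.int32))
    # uniform log-probabilities give a mean NLL of -log(0.25) ~= 1.386
    print(NLLLoss()(log_probs, labels))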
|
@ -0,0 +1,45 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""lr generator for deeptext"""
|
||||
import math
|
||||
|
||||
def rsqrt_decay(warmup_steps, current_step):
|
||||
return float(max([current_step, warmup_steps])) ** -0.5
|
||||
|
||||
def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
|
||||
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
|
||||
learning_rate = float(init_lr) + lr_inc * current_step
|
||||
return learning_rate
|
||||
|
||||
def a_cosine_learning_rate(current_step, base_lr, warmup_steps, total_steps):
|
||||
decay_steps = total_steps - warmup_steps
|
||||
linear_decay = (total_steps - current_step) / decay_steps
|
||||
cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * current_step / decay_steps))
|
||||
decayed = linear_decay * cosine_decay + 0.00001
|
||||
learning_rate = decayed * base_lr
|
||||
return learning_rate
|
||||
|
||||
def dynamic_lr(config, base_step):
|
||||
"""dynamic learning rate generator"""
|
||||
base_lr = config.base_lr
|
||||
total_steps = int(base_step * config.num_epochs)
|
||||
warmup_steps = int(config.warmup_step)
|
||||
lr = []
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr.append(linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * config.warmup_ratio))
|
||||
else:
|
||||
lr.append(a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps))
|
||||
return lr
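# Illustrative sketch, not part of the original commit: with the values in
# src/config.py (base_lr=0.001, warmup_step=300, warmup_ratio=1/3, num_epochs=30)
# the schedule warms up linearly from base_lr/3 and then follows the cosine-style
# decay implemented above. steps_per_epoch is a placeholder for the real
# dataset_size // batch_size of the training set.
if __name__ == "__main__":
    from src.config import config

    steps_per_epoch = 1000  # placeholder
    lrs = dynamic_lr(config, steps_per_epoch)
    print(len(lrs), lrs[0], lrs[config.warmup_step - 1], lrs[-1])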
|
@ -0,0 +1,47 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Convert ids to tokens."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
import tokenization
|
||||
|
||||
# Explicitly set the encoding
|
||||
sys.stdin = open(sys.stdin.fileno(), mode='r', encoding='utf-8', buffering=True)
|
||||
sys.stdout = open(sys.stdout.fileno(), mode='w', encoding='utf-8', buffering=True)
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="recore nbest with smoothed sentence-level bleu.")
|
||||
parser.add_argument("--vocab_file", type=str, default="", required=True, help="vocab file path.")
|
||||
args = parser.parse_args()
|
||||
|
||||
tokenizer = tokenization.WhiteSpaceTokenizer(vocab_file=args.vocab_file)
|
||||
|
||||
for line in sys.stdin:
|
||||
token_ids = [int(x) for x in line.strip().split()]
|
||||
tokens = tokenizer.convert_ids_to_tokens(token_ids)
|
||||
sent = " ".join(tokens)
|
||||
sent = sent.split("<sos>")[-1]
|
||||
sent = sent.split("<eos>")[0]
|
||||
print(sent.strip())
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,105 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
'''Dataset preprocess'''
|
||||
import os
|
||||
import argparse
|
||||
from collections import Counter
|
||||
from nltk.tokenize import word_tokenize
|
||||
|
||||
def create_tokenized_sentences(input_files, language):
|
||||
'''
|
||||
Create tokenized sentences files.
|
||||
|
||||
Args:
|
||||
input_files: input files.
|
||||
language: text language
|
||||
'''
|
||||
sentence = []
|
||||
total_lines = open(input_files, "r").read().splitlines()
|
||||
for line in total_lines:
|
||||
line = line.strip('\r\n ')
|
||||
line = line.lower()
|
||||
tokenize_sentence = word_tokenize(line, language)
|
||||
str_sentence = " ".join(tokenize_sentence)
|
||||
sentence.append(str_sentence)
|
||||
tokenize_file = input_files + ".tok"
|
||||
f = open(tokenize_file, "w")
|
||||
for line in sentence:
|
||||
f.write(line)
|
||||
f.write("\n")
|
||||
f.close()
|
||||
|
||||
def get_dataset_vocab(text_file, vocab_file):
|
||||
'''
|
||||
Create dataset vocab files.
|
||||
|
||||
Args:
|
||||
text_file: dataset text files.
|
||||
vocab_file: vocab file
|
||||
'''
|
||||
counter = Counter()
|
||||
text_lines = open(text_file, "r").read().splitlines()
|
||||
for line in text_lines:
|
||||
for word in line.strip('\r\n ').split(' '):
|
||||
if word:
|
||||
counter[word] += 1
|
||||
vocab = open(vocab_file, "w")
|
||||
basic_label = ["<unk>", "<pad>", "<sos>", "<eos>"]
|
||||
for label in basic_label:
|
||||
vocab.write(label + "\n")
|
||||
for key, f in sorted(counter.items(), key=lambda x: x[1], reverse=True):
|
||||
if f < 2:
|
||||
continue
|
||||
vocab.write(key + "\n")
|
||||
vocab.close()
|
||||
|
||||
def MergeText(root_dir, file_list, output_file):
|
||||
'''
|
||||
Merge text files together.
|
||||
|
||||
Args:
|
||||
root_dir: root dir
|
||||
file_list: dataset files list.
|
||||
output_file: output file after merge
|
||||
'''
|
||||
output_file = os.path.join(root_dir, output_file)
|
||||
f_output = open(output_file, "w")
|
||||
for file_name in file_list:
|
||||
text_path = os.path.join(root_dir, file_name) + ".tok"
|
||||
f = open(text_path)
|
||||
f_output.write(f.read() + "\n")
|
||||
f_output.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='gru_dataset')
|
||||
parser.add_argument("--dataset_path", type=str, default="", help="Dataset path, default: f`sns.")
|
||||
args = parser.parse_args()
|
||||
dataset_path = args.dataset_path
|
||||
src_file_list = ["train.de", "test.de", "val.de"]
|
||||
dst_file_list = ["train.en", "test.en", "val.en"]
|
||||
for file in src_file_list:
|
||||
file_path = os.path.join(dataset_path, file)
|
||||
create_tokenized_sentences(file_path, "german")
|
||||
for file in dst_file_list:
|
||||
file_path = os.path.join(dataset_path, file)
|
||||
create_tokenized_sentences(file_path, "english")
|
||||
src_all_file = "all.de.tok"
|
||||
dst_all_file = "all.en.tok"
|
||||
MergeText(dataset_path, src_file_list, src_all_file)
|
||||
MergeText(dataset_path, dst_file_list, dst_all_file)
|
||||
src_vocab = os.path.join(dataset_path, "vocab.de")
|
||||
dst_vocab = os.path.join(dataset_path, "vocab.en")
|
||||
get_dataset_vocab(os.path.join(dataset_path, src_all_file), src_vocab)
|
||||
get_dataset_vocab(os.path.join(dataset_path, dst_all_file), dst_vocab)
|
@ -0,0 +1,223 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Seq2Seq construction"""
|
||||
import numpy as np
|
||||
from mindspore import Tensor
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as P
|
||||
import mindspore.common.dtype as mstype
|
||||
from src.gru import BidirectionGRU, GRU
|
||||
from src.weight_init import dense_default_state
|
||||
|
||||
class Attention(nn.Cell):
|
||||
'''
|
||||
Attention model
|
||||
'''
|
||||
def __init__(self, config):
|
||||
super(Attention, self).__init__()
|
||||
self.text_len = config.max_length
|
||||
self.attn = nn.Dense(in_channels=config.hidden_size * 3,
|
||||
out_channels=config.hidden_size).to_float(mstype.float16)
|
||||
self.fc = nn.Dense(config.hidden_size, 1, has_bias=False).to_float(mstype.float16)
|
||||
self.expandims = P.ExpandDims()
|
||||
self.tanh = P.Tanh()
|
||||
self.softmax = P.Softmax()
|
||||
self.tile = P.Tile()
|
||||
self.transpose = P.Transpose()
|
||||
self.concat = P.Concat(axis=2)
|
||||
self.squeeze = P.Squeeze(axis=2)
|
||||
self.cast = P.Cast()
|
||||
def construct(self, hidden, encoder_outputs):
|
||||
'''
|
||||
Attention construction
|
||||
|
||||
Args:
|
||||
hidden(Tensor): hidden state
|
||||
encoder_outputs(Tensor): the output of encoder
|
||||
|
||||
Returns:
|
||||
Tensor, attention output
|
||||
'''
|
||||
hidden = self.expandims(hidden, 1)
|
||||
hidden = self.tile(hidden, (1, self.text_len, 1))
|
||||
encoder_outputs = self.transpose(encoder_outputs, (1, 0, 2))
|
||||
out = self.concat((hidden, encoder_outputs))
|
||||
out = self.attn(out)
|
||||
energy = self.tanh(out)
|
||||
attention = self.fc(energy)
|
||||
attention = self.squeeze(attention)
|
||||
attention = self.cast(attention, mstype.float32)
|
||||
attention = self.softmax(attention)
|
||||
attention = self.cast(attention, mstype.float16)
|
||||
return attention
|
||||
|
||||
class Encoder(nn.Cell):
|
||||
'''
|
||||
Encoder model
|
||||
|
||||
Args:
|
||||
config: config of network
|
||||
'''
|
||||
def __init__(self, config, is_training=True):
|
||||
super(Encoder, self).__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.vocab_size = config.src_vocab_size
|
||||
self.embedding_size = config.encoder_embedding_size
|
||||
self.embedding = nn.Embedding(self.vocab_size, self.embedding_size)
|
||||
self.rnn = BidirectionGRU(config, is_training=is_training).to_float(mstype.float16)
|
||||
self.fc = nn.Dense(2*self.hidden_size, self.hidden_size).to_float(mstype.float16)
|
||||
self.shape = P.Shape()
|
||||
self.transpose = P.Transpose()
|
||||
self.p = P.Print()
|
||||
self.cast = P.Cast()
|
||||
self.text_len = config.max_length
|
||||
self.squeeze = P.Squeeze(axis=0)
|
||||
self.tanh = P.Tanh()
|
||||
|
||||
def construct(self, src):
|
||||
'''
|
||||
Encoder construction
|
||||
|
||||
Args:
|
||||
src(Tensor): source sentences
|
||||
|
||||
Returns:
|
||||
output(Tensor): output of rnn
|
||||
hidden(Tensor): output hidden
|
||||
'''
|
||||
embedded = self.embedding(src)
|
||||
embedded = self.transpose(embedded, (1, 0, 2))
|
||||
embedded = self.cast(embedded, mstype.float16)
|
||||
output, hidden = self.rnn(embedded)
|
||||
hidden = self.fc(hidden)
|
||||
hidden = self.tanh(hidden)
|
||||
return output, hidden
|
||||
|
||||
class Decoder(nn.Cell):
    '''
    Decoder model

    Args:
        config: config of network
    '''
    def __init__(self, config, is_training=True):
        super(Decoder, self).__init__()
        self.hidden_size = config.hidden_size
        self.vocab_size = config.trg_vocab_size
        self.embedding_size = config.decoder_embedding_size
        self.embedding = nn.Embedding(self.vocab_size, self.embedding_size)
        self.rnn = GRU(config, is_training=is_training).to_float(mstype.float16)
        self.text_len = config.max_length
        self.shape = P.Shape()
        self.transpose = P.Transpose()
        self.p = P.Print()
        self.cast = P.Cast()
        self.concat = P.Concat(axis=2)
        self.squeeze = P.Squeeze(axis=0)
        self.expandims = P.ExpandDims()
        self.log_softmax = P.LogSoftmax(axis=1)
        weight, bias = dense_default_state(self.embedding_size+self.hidden_size*3, self.vocab_size)
        self.fc = nn.Dense(self.embedding_size+self.hidden_size*3, self.vocab_size,
                           weight_init=weight, bias_init=bias).to_float(mstype.float16)
        self.attention = Attention(config)
        self.bmm = P.BatchMatMul()
        self.dropout = nn.Dropout(0.7)

    def construct(self, inputs, hidden, encoder_outputs):
        '''
        Decoder construction

        Args:
            inputs(Tensor): decoder input
            hidden(Tensor): hidden state
            encoder_outputs(Tensor): encoder output

        Returns:
            pred_prob(Tensor): decoder predict probability
            hidden(Tensor): hidden state
        '''
        embedded = self.embedding(inputs)
        embedded = self.transpose(embedded, (1, 0, 2))
        embedded = self.cast(embedded, mstype.float16)
        attn = self.attention(hidden, encoder_outputs)
        attn = self.expandims(attn, 1)
        encoder_outputs = self.transpose(encoder_outputs, (1, 0, 2))
        weight = self.bmm(attn, encoder_outputs)
        weight = self.transpose(weight, (1, 0, 2))
        emd_con = self.concat((embedded, weight))
        output, hidden = self.rnn(emd_con)
        out = self.concat((embedded, output, weight))
        out = self.squeeze(out)
        hidden = self.squeeze(hidden)
        prediction = self.fc(out)
        prediction = self.dropout(prediction)
        prediction = self.cast(prediction, mstype.float32)
        pred_prob = self.log_softmax(prediction)
        pred_prob = self.expandims(pred_prob, 0)
        return pred_prob, hidden

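# --- Illustrative sketch (not part of the network) ---------------------------
# In Decoder.construct, `weight = self.bmm(attn, encoder_outputs)` turns the
# attention weights into a context vector. The NumPy lines below, with made-up
# sizes, mirror that batched matmul purely as a shape check; they are not the
# model's code.
def _context_vector_shape_demo():
    import numpy as np
    batch, seq_len, hidden = 2, 5, 4
    attn = np.full((batch, 1, seq_len), 1.0 / seq_len)           # uniform attention weights
    encoder_outputs = np.random.randn(batch, seq_len, 2 * hidden)
    context = attn @ encoder_outputs                              # (batch, 1, 2*hidden)
    return context.shape
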
class Seq2Seq(nn.Cell):
    '''
    Seq2Seq model

    Args:
        config: config of network
    '''
    def __init__(self, config, is_training=True):
        super(Seq2Seq, self).__init__()
        if is_training:
            self.batch_size = config.batch_size
        else:
            self.batch_size = config.eval_batch_size
        self.encoder = Encoder(config, is_training=is_training)
        self.decoder = Decoder(config, is_training=is_training)
        self.expandims = P.ExpandDims()
        self.dropout = nn.Dropout()
        self.shape = P.Shape()
        self.concat = P.Concat(axis=0)
        self.argmax = P.ArgMaxWithValue(axis=1, keep_dims=True)
        self.squeeze = P.Squeeze(axis=0)
        self.sos = Tensor(np.ones((self.batch_size, 1))*2, mstype.int32)
        self.select = P.Select()
        self.text_len = config.max_length

    def construct(self, encoder_inputs, decoder_inputs, teacher_force):
        '''
        Seq2Seq construction

        Args:
            encoder_inputs(Tensor): encoder input sentences
            decoder_inputs(Tensor): decoder input sentences
            teacher_force(Tensor): teacher force flag

        Returns:
            outputs(Tensor): total predict probability
        '''
        decoder_input = self.sos
        encoder_output, hidden = self.encoder(encoder_inputs)
        decoder_hidden = hidden
        decoder_outputs = ()
        for i in range(1, self.text_len):
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_output)
            decoder_outputs += (decoder_output,)
            if self.training:
                decoder_input_force = decoder_inputs[::, i:i+1]
                decoder_input_top1, _ = self.argmax(self.squeeze(decoder_output))
                decoder_input = self.select(teacher_force, decoder_input_force, decoder_input_top1)
            else:
                decoder_input, _ = self.argmax(self.squeeze(decoder_output))
        outputs = self.concat(decoder_outputs)
        return outputs

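# --- Illustrative sketch (not part of the network) ---------------------------
# What the teacher-forcing branch in Seq2Seq.construct does at each decode
# step, rewritten in NumPy with made-up sizes: pick either the ground-truth
# token or the model's own greedy prediction, depending on the teacher_force
# flag (playing the role of P.Select above). Names here are illustrative.
def _teacher_forcing_demo():
    import numpy as np
    batch, vocab, seq_len, step = 2, 10, 6, 3
    step_log_probs = np.random.randn(batch, vocab)            # decoder output for this step
    target_ids = np.random.randint(0, vocab, (batch, seq_len))
    teacher_force = np.array([[True], [False]])               # per-example flag, (batch, 1)
    predicted = step_log_probs.argmax(axis=1)[:, None]        # greedy choice, (batch, 1)
    ground_truth = target_ids[:, step:step + 1]               # (batch, 1)
    return np.where(teacher_force, ground_truth, predicted)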
@ -0,0 +1,155 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tokenization utilities."""

import sys
import collections
import unicodedata


def convert_to_printable(text):
    """
    Converts `text` to a printable coding format.
    """
    if sys.version_info[0] == 3:
        if isinstance(text, str):
            return text
        if isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        raise ValueError("Only support type `str` or `bytes`, while text type is `%s`" % (type(text)))
    raise ValueError("Only supported when running on Python3.")


def convert_to_unicode(text):
    """
    Converts `text` to Unicode format.
    """
    if sys.version_info[0] == 3:
        if isinstance(text, str):
            return text
        if isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        raise ValueError("Only support type `str` or `bytes`, while text type is `%s`" % (type(text)))
    if sys.version_info[0] == 2:
        if isinstance(text, str):
            return text.decode("utf-8", "ignore")
        if isinstance(text, unicode):
            return text
        raise ValueError("Only support type `str` or `unicode`, while text type is `%s`" % (type(text)))
    raise ValueError("Only supported when running on Python2 or Python3.")


def load_vocab_file(vocab_file):
    """
    Loads a vocabulary file and turns into a {token:id} dictionary.
    """
    vocab_dict = collections.OrderedDict()
    index = 0
    with open(vocab_file, "r") as vocab:
        while True:
            token = convert_to_unicode(vocab.readline())
            if not token:
                break
            token = token.strip()
            vocab_dict[token] = index
            index += 1
    return vocab_dict


def convert_by_vocab_dict(vocab_dict, items):
    """
    Converts a sequence of [tokens|ids] according to the vocab dict.
    """
    output = []
    for item in items:
        if item in vocab_dict:
            output.append(vocab_dict[item])
        else:
            output.append(vocab_dict["<unk>"])
    return output


class WhiteSpaceTokenizer():
    """
    Whitespace tokenizer.
    """
    def __init__(self, vocab_file):
        self.vocab_dict = load_vocab_file(vocab_file)
        self.inv_vocab_dict = {index: token for token, index in self.vocab_dict.items()}

    def _is_whitespace_char(self, char):
        """
        Checks if it is a whitespace character (regard "\t", "\n", "\r" as whitespace here).
        """
        if char in (" ", "\t", "\n", "\r"):
            return True
        uni = unicodedata.category(char)
        if uni == "Zs":
            return True
        return False

    def _is_control_char(self, char):
        """
        Checks if it is a control character.
        """
        if char in ("\t", "\n", "\r"):
            return False
        uni = unicodedata.category(char)
        if uni in ("Cc", "Cf"):
            return True
        return False

    def _clean_text(self, text):
        """
        Removes invalid characters and cleans up whitespace.
        """
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xfffd or self._is_control_char(char):
                continue
            if self._is_whitespace_char(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _whitespace_tokenize(self, text):
        """
        Cleans whitespace and splits text into tokens.
        """
        text = text.strip()
        text = text.lower()
        if text.endswith("."):
            text = text.replace(".", " .")
        if not text:
            tokens = []
        else:
            tokens = text.split()
        return tokens

    def tokenize(self, text):
        """
        Tokenizes text.
        """
        text = convert_to_unicode(text)
        text = self._clean_text(text)
        tokens = self._whitespace_tokenize(text)
        return tokens

    def convert_tokens_to_ids(self, tokens):
        return convert_by_vocab_dict(self.vocab_dict, tokens)

    def convert_ids_to_tokens(self, ids):
        return convert_by_vocab_dict(self.inv_vocab_dict, ids)

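# --- Usage sketch (illustrative only) -----------------------------------------
# Round-trip a sentence through WhiteSpaceTokenizer using a tiny throwaway
# vocabulary file. The file name and tokens below are made up for this example
# and are not part of the original dataset tooling.
if __name__ == "__main__":
    with open("demo_vocab.txt", "w") as demo_vocab:
        demo_vocab.write("<unk>\n<pad>\n<sos>\n<eos>\nhello\nworld\n.\n")
    demo_tokenizer = WhiteSpaceTokenizer("demo_vocab.txt")
    demo_tokens = demo_tokenizer.tokenize("Hello world.")           # ['hello', 'world', '.']
    demo_ids = demo_tokenizer.convert_tokens_to_ids(demo_tokens)    # unknown tokens map to <unk>
    print(demo_tokens, demo_ids, demo_tokenizer.convert_ids_to_tokens(demo_ids))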
@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""weight init"""
import math
import numpy as np
from mindspore import Tensor, Parameter


def gru_default_state(batch_size, input_size, hidden_size, num_layers=1, bidirectional=False):
    '''Weight init for gru cell'''
    stdv = 1 / math.sqrt(hidden_size)
    weight_i = Parameter(Tensor(
        np.random.uniform(-stdv, stdv, (input_size, 3*hidden_size)).astype(np.float32)), name='weight_i')
    weight_h = Parameter(Tensor(
        np.random.uniform(-stdv, stdv, (hidden_size, 3*hidden_size)).astype(np.float32)), name='weight_h')
    bias_i = Parameter(Tensor(
        np.random.uniform(-stdv, stdv, (3*hidden_size)).astype(np.float32)), name='bias_i')
    bias_h = Parameter(Tensor(
        np.random.uniform(-stdv, stdv, (3*hidden_size)).astype(np.float32)), name='bias_h')
    init_h = Tensor(np.zeros((batch_size, hidden_size)).astype(np.float16))
    return weight_i, weight_h, bias_i, bias_h, init_h


def dense_default_state(in_channel, out_channel):
    '''Weight init for dense cell'''
    stdv = 1 / math.sqrt(in_channel)
    weight = Tensor(np.random.uniform(-stdv, stdv, (out_channel, in_channel)).astype(np.float32))
    bias = Tensor(np.random.uniform(-stdv, stdv, (out_channel)).astype(np.float32))
    return weight, bias

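# --- Illustrative check (not used by the network) ------------------------------
# Both initializers above sample from U(-1/sqrt(fan_in), +1/sqrt(fan_in)). The
# quick check below uses made-up layer sizes just to show the bound and the
# resulting shapes.
if __name__ == "__main__":
    demo_weight, demo_bias = dense_default_state(in_channel=1024, out_channel=30000)
    print("bound:", 1 / math.sqrt(1024))                   # 0.03125
    print("shapes:", demo_weight.shape, demo_bias.shape)   # (30000, 1024) and (30000,)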
@ -0,0 +1,130 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train script"""
import os
import time
import argparse
import ast
from mindspore.context import ParallelMode
from mindspore import context
from mindspore.communication.management import init
from mindspore.train.callback import Callback, CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore.train import Model
from mindspore.common import set_seed
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindspore.nn.optim import Adam
from src.config import config
from src.seq2seq import Seq2Seq
from src.gru_for_train import GRUWithLossCell, GRUTrainOneStepWithLossScaleCell
from src.dataset import create_gru_dataset
from src.lr_schedule import dynamic_lr

set_seed(1)

parser = argparse.ArgumentParser(description="GRU training")
parser.add_argument("--run_distribute", type=ast.literal_eval, default=False, help="Run distribute, default: false.")
parser.add_argument("--dataset_path", type=str, default=None, help="Dataset path")
parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained file path.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default: 0.")
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default: 1.")
parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default: 0.")
parser.add_argument('--ckpt_path', type=str, default='outputs/', help='Checkpoint save location. Default: outputs/')
parser.add_argument('--outputs_dir', type=str, default='./', help='Output directory for checkpoints. Default: ./')
args = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id, save_graphs=False)


def get_ms_timestamp():
    t = time.time()
    return int(round(t * 1000))


time_stamp_init = False
time_stamp_first = 0


class LossCallBack(Callback):
    """
    Monitor the loss in training.

    If the loss is NAN or INF, training is terminated.

    Note:
        If per_print_times is 0, the loss is not printed.

    Args:
        per_print_times (int): Print the loss every `per_print_times` steps. Default: 1.
    """
    def __init__(self, per_print_times=1, rank_id=0):
        super(LossCallBack, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("per_print_times must be an int and >= 0.")
        self._per_print_times = per_print_times
        self.rank_id = rank_id
        global time_stamp_init, time_stamp_first
        if not time_stamp_init:
            time_stamp_first = get_ms_timestamp()
            time_stamp_init = True

    def step_end(self, run_context):
        """Monitor the loss in training."""
        global time_stamp_first
        time_stamp_current = get_ms_timestamp()
        cb_params = run_context.original_args()
        print("time: {}, epoch: {}, step: {}, outputs are {}".format(time_stamp_current - time_stamp_first,
                                                                     cb_params.cur_epoch_num,
                                                                     cb_params.cur_step_num,
                                                                     str(cb_params.net_outputs)))
        with open("./loss_{}.log".format(self.rank_id), "a+") as f:
            f.write("time: {}, epoch: {}, step: {}, loss: {}, overflow: {}, loss_scale: {}".format(
                time_stamp_current - time_stamp_first,
                cb_params.cur_epoch_num,
                cb_params.cur_step_num,
                str(cb_params.net_outputs[0].asnumpy()),
                str(cb_params.net_outputs[1].asnumpy()),
                str(cb_params.net_outputs[2].asnumpy())))
            f.write('\n')

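# --- Illustrative helper (not called by this script) ---------------------------
# A minimal sketch of pulling the loss values back out of the per-rank log
# written by LossCallBack above, e.g. for quick plotting. The default path
# assumes rank 0.
def read_loss_log(path="./loss_0.log"):
    losses = []
    with open(path) as log_file:
        for line in log_file:
            for field in line.strip().split(", "):
                if field.startswith("loss: "):
                    losses.append(float(field[len("loss: "):]))
    return losses
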
if __name__ == '__main__':
    if args.run_distribute:
        rank = args.rank_id
        device_num = args.device_num
        context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
        init()
    else:
        rank = 0
        device_num = 1
    dataset = create_gru_dataset(epoch_count=config.num_epochs, batch_size=config.batch_size,
                                 dataset_path=args.dataset_path, rank_size=device_num, rank_id=rank)
    dataset_size = dataset.get_dataset_size()
    print("dataset size is {}".format(dataset_size))
    network = Seq2Seq(config)
    network = GRUWithLossCell(network)
    lr = dynamic_lr(config, dataset_size)
    opt = Adam(network.trainable_params(), learning_rate=lr)
    scale_manager = DynamicLossScaleManager(init_loss_scale=config.init_loss_scale_value,
                                            scale_factor=config.scale_factor,
                                            scale_window=config.scale_window)
    update_cell = scale_manager.get_update_cell()
    netwithgrads = GRUTrainOneStepWithLossScaleCell(network, opt, update_cell)

    time_cb = TimeMonitor(data_size=dataset_size)
    loss_cb = LossCallBack(rank_id=rank)
    cb = [time_cb, loss_cb]
    # Save checkpoint
    if config.save_checkpoint:
        ckpt_config = CheckpointConfig(save_checkpoint_steps=config.ckpt_epoch*dataset_size,
                                       keep_checkpoint_max=config.keep_checkpoint_max)
        save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_'+str(args.rank_id)+'/')
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=save_ckpt_path,
                                  prefix='{}'.format(args.rank_id))
        cb += [ckpt_cb]
    netwithgrads.set_train(True)
    model = Model(netwithgrads)
    model.train(config.num_epochs, dataset, callbacks=cb, dataset_sink_mode=True)