!8727 Clean code for gnmt_v2 network
From: @gaojing22 Reviewed-by: @yingjy,@guoqi1024 Signed-off-by: @yingjy pull/8727/MERGE
commit
7689062c7d
@ -0,0 +1,109 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Train and eval api."""
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import pickle
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
|
from config import GNMTConfig
|
||||||
|
from train import train_parallel
|
||||||
|
from src.gnmt_model import infer
|
||||||
|
from src.gnmt_model.bleu_calculate import bleu_calculate
|
||||||
|
from src.dataset.tokenizer import Tokenizer
|
||||||
|
|
||||||
|
def _build_parser():
    """Create the command-line parser used by the train-then-eval driver.

    Returns:
        argparse.ArgumentParser: Parser with ten required string options
        (training and evaluation inputs) plus the optional ``--output``.
    """
    cli = argparse.ArgumentParser(description='GNMT train and eval.')

    # (flag, help text) pairs; every entry is a required string option.
    # Registration order is kept so the generated --help text is unchanged.
    required_options = (
        # train
        ("--config_train", "model config json file path."),
        ("--dataset_schema_train", "dataset schema for train."),
        ("--pre_train_dataset", "pre-train dataset address."),
        # eval
        ("--config_test", "model config json file path."),
        ("--dataset_schema_test", "dataset schema for evaluation."),
        ("--test_dataset", "test dataset address."),
        ("--existed_ckpt", "existed checkpoint address."),
        ("--vocab", "Vocabulary to use."),
        ("--bpe_codes", "bpe codes to use."),
        # argparse already defaults to None, so no explicit default is needed.
        ("--test_tgt", "data file of the test target"),
    )
    for flag, description in required_options:
        cli.add_argument(flag, type=str, required=True, help=description)

    cli.add_argument("--output", type=str, required=False,
                     default="./output.npz",
                     help="result file path.")
    return cli


parser = _build_parser()
|
||||||
|
|
||||||
|
|
||||||
|
def get_config(config):
    """Load a GNMT configuration from a JSON file and pin its numeric types.

    Args:
        config (str): Path of the JSON configuration file.

    Returns:
        The object produced by ``GNMTConfig.from_json_file`` with
        ``compute_type`` set to ``mstype.float16`` and ``dtype`` set to
        ``mstype.float32``.
    """
    loaded = GNMTConfig.from_json_file(config)
    # Same assignment order as before: compute_type first, then dtype.
    for attr_name, ms_type in (("compute_type", mstype.float16),
                               ("dtype", mstype.float32)):
        setattr(loaded, attr_name, ms_type)
    return loaded
|
||||||
|
|
||||||
|
|
||||||
|
def _check_args(config):
    """Validate a configuration file path taken from the command line.

    Args:
        config (str): Path that is expected to point at an existing file
            system entry.

    Raises:
        ValueError: If `config` is not a ``str``.
        FileNotFoundError: If nothing exists at `config`.
    """
    # The type check must come first: os.path.exists() treats ints as file
    # descriptors and raises TypeError for None, so checking the type
    # afterwards could never produce the intended error.
    if not isinstance(config, str):
        raise ValueError("`config` must be type of str.")
    if not os.path.exists(config):
        raise FileNotFoundError("`config` is not existed.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # End-to-end check: train in parallel, run inference with the given
    # checkpoint, score BLEU, then assert accuracy and wall-clock limits.
    start_time = datetime.datetime.now()
    _rank_size = os.getenv('RANK_SIZE')
    args, _ = parser.parse_known_args()

    # train
    _check_args(args.config_train)
    _config_train = get_config(args.config_train)
    _config_train.dataset_schema = args.dataset_schema_train
    _config_train.pre_train_dataset = args.pre_train_dataset
    set_seed(_config_train.random_seed)
    # Only the multi-device path is exercised. This is an explicit check
    # (not a bare assert) so it still fires under `python -O`.
    if _rank_size is None or int(_rank_size) <= 1:
        raise AssertionError(
            f"RANK_SIZE must be an integer greater than 1, got {_rank_size!r}.")
    train_parallel(_config_train)

    # eval
    _check_args(args.config_test)
    _config_test = get_config(args.config_test)
    _config_test.dataset_schema = args.dataset_schema_test
    _config_test.test_dataset = args.test_dataset
    _config_test.existed_ckpt = args.existed_ckpt
    result = infer(_config_test)

    # NOTE: the result is pickled (protocol 1) even though the default
    # output name ends in ".npz"; bleu_calculate receives the same path.
    with open(args.output, "wb") as f:
        pickle.dump(result, f, 1)

    tokenizer = Tokenizer(args.vocab, args.bpe_codes, 'en', 'de')
    scores = bleu_calculate(tokenizer, args.output, args.test_tgt)
    print(f"BLEU scores is :{scores}")

    end_time = datetime.datetime.now()
    # total_seconds() includes whole days; `.seconds` wraps every 24h and
    # could let an over-long run pass the time limit below.
    cost_time = int((end_time - start_time).total_seconds())
    print(f"Cost time is {cost_time}s.")

    # Pass/fail thresholds of this test run.
    assert scores >= 23.8
    assert cost_time < 10800.0
    print("----done!----")
|
@ -0,0 +1,86 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# Print the usage banner. The script needs bash (it uses `for ((...))`),
# so the banner shows `bash`, not `sh`.
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_distributed_train_ascend.sh \
GNMT_ADDR RANK_TABLE_ADDR \
DATASET_SCHEMA_TRAIN PRE_TRAIN_DATASET \
DATASET_SCHEMA_TEST TEST_DATASET EXISTED_CKPT_PATH \
VOCAB_ADDR BPE_CODE_ADDR TEST_TARGET"
echo "for example:"
echo "bash run_distributed_train_ascend.sh \
/home/workspace/gnmt_v2 \
/home/workspace/rank_table_8p.json \
/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.json \
/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.tfrecord-001-of-001 \
/home/workspace/dataset_menu/newstest2014.en.json \
/home/workspace/dataset_menu/newstest2014.en.tfrecord-001-of-001 \
/home/workspace/gnmt_v2/gnmt-6_3452.ckpt \
/home/workspace/wmt16_de_en/vocab.bpe.32000 \
/home/workspace/wmt16_de_en/bpe.32000 \
/home/workspace/wmt16_de_en/newstest2014.de"
echo "It is better to use absolute path."
echo "=============================================================================================================="
|
||||||
|
|
||||||
|
# Stop unless all ten positional arguments are present: the launch loop
# below removes and copies directories built from these values, and an
# empty GNMT_ADDR would expand to paths such as /src.
if [ "$#" -lt 10 ]; then
    echo "Error: expected 10 arguments but got $#." >&2
    exit 1
fi

GNMT_ADDR=$1
RANK_TABLE_ADDR=$2
# train dataset addr
DATASET_SCHEMA_TRAIN=$3
PRE_TRAIN_DATASET=$4
# eval dataset addr
DATASET_SCHEMA_TEST=$5
TEST_DATASET=$6
EXISTED_CKPT_PATH=$7
VOCAB_ADDR=$8
BPE_CODE_ADDR=$9
TEST_TARGET=${10}

# Directory the script was started from; per-device folders are created here.
current_exec_path=$(pwd)
echo "${current_exec_path}"

# Both variables carry the same rank table path.
export RANK_TABLE_FILE="$RANK_TABLE_ADDR"
export MINDSPORE_HCCL_CONFIG_PATH="$RANK_TABLE_ADDR"

echo "$RANK_TABLE_FILE"
export RANK_SIZE=8
export GLOG_v=2
|
||||||
|
|
||||||
|
# Start one background test process per device, each in a freshly created
# working directory. The device count follows RANK_SIZE (8 if unset).
for ((i = 0; i < ${RANK_SIZE:-8}; i++)); do
    device_dir="${current_exec_path}/device${i}"
    rm -rf "${device_dir}"
    mkdir "${device_dir}"
    cd "${device_dir}" || exit
    # Copy the test driver(s) and the GNMT sources next to each other so the
    # driver's imports of `config`, `train` and `src` resolve locally.
    cp "${current_exec_path}"/*.py .
    cp "${GNMT_ADDR}"/*.py .
    cp -r "${GNMT_ADDR}/src" .
    cp -r "${GNMT_ADDR}/config" .
    export RANK_ID=$i
    export DEVICE_ID=$i
    # Paths are quoted so values containing spaces survive word splitting.
    python test_gnmt_v2.py \
        --config_train="${GNMT_ADDR}/config/config.json" \
        --dataset_schema_train="$DATASET_SCHEMA_TRAIN" \
        --pre_train_dataset="$PRE_TRAIN_DATASET" \
        --config_test="${GNMT_ADDR}/config/config_test.json" \
        --dataset_schema_test="$DATASET_SCHEMA_TEST" \
        --test_dataset="$TEST_DATASET" \
        --existed_ckpt="$EXISTED_CKPT_PATH" \
        --vocab="$VOCAB_ADDR" \
        --bpe_codes="$BPE_CODE_ADDR" \
        --test_tgt="$TEST_TARGET" > "log_gnmt_network${i}.log" 2>&1 &
    cd "${current_exec_path}" || exit
done
cd "${current_exec_path}" || exit
|
Loading…
Reference in new issue