!10290 add fasttext model to model zoo
From: @zhaojichen Reviewed-by: Signed-off-by:pull/10290/MERGE
commit
1150ae3376
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,118 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""FastText for Evaluation"""
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import mindspore.nn as nn
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
import mindspore.ops.operations as P
|
||||||
|
from mindspore.common.tensor import Tensor
|
||||||
|
from mindspore.train.model import Model
|
||||||
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
import mindspore.dataset.engine as de
|
||||||
|
import mindspore.dataset.transforms.c_transforms as deC
|
||||||
|
from mindspore import context
|
||||||
|
from src.fasttext_model import FastText
|
||||||
|
parser = argparse.ArgumentParser(description='fasttext')
|
||||||
|
parser.add_argument('--data_path', type=str, help='infer dataset path..')
|
||||||
|
parser.add_argument('--data_name', type=str, required=True, default='ag',
|
||||||
|
help='dataset name. eg. ag, dbpedia')
|
||||||
|
parser.add_argument("--model_ckpt", type=str, required=True,
|
||||||
|
help="existed checkpoint address.")
|
||||||
|
args = parser.parse_args()
|
||||||
|
if args.data_name == "ag":
|
||||||
|
from src.config import config_ag as config
|
||||||
|
target_label1 = ['0', '1', '2', '3']
|
||||||
|
elif args.data_name == 'dbpedia':
|
||||||
|
from src.config import config_db as config
|
||||||
|
target_label1 = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13']
|
||||||
|
elif args.data_name == 'yelp_p':
|
||||||
|
from src.config import config_yelpp as config
|
||||||
|
target_label1 = ['0', '1']
|
||||||
|
context.set_context(
|
||||||
|
mode=context.GRAPH_MODE,
|
||||||
|
save_graphs=False,
|
||||||
|
device_target="Ascend")
|
||||||
|
|
||||||
|
class FastTextInferCell(nn.Cell):
|
||||||
|
"""
|
||||||
|
Encapsulation class of FastText network infer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
network (nn.Cell): FastText model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[Tensor, Tensor], predicted_ids
|
||||||
|
"""
|
||||||
|
def __init__(self, network):
|
||||||
|
super(FastTextInferCell, self).__init__(auto_prefix=False)
|
||||||
|
self.network = network
|
||||||
|
self.argmax = P.ArgMaxWithValue(axis=1, keep_dims=True)
|
||||||
|
self.log_softmax = nn.LogSoftmax(axis=1)
|
||||||
|
|
||||||
|
def construct(self, src_tokens, src_tokens_lengths):
|
||||||
|
"""construct fasttext infer cell"""
|
||||||
|
prediction = self.network(src_tokens, src_tokens_lengths)
|
||||||
|
predicted_idx = self.log_softmax(prediction)
|
||||||
|
predicted_idx, _ = self.argmax(predicted_idx)
|
||||||
|
|
||||||
|
return predicted_idx
|
||||||
|
|
||||||
|
def load_infer_dataset(batch_size, datafile):
|
||||||
|
"""data loader for infer"""
|
||||||
|
ds = de.MindDataset(datafile, columns_list=['src_tokens', 'src_tokens_length', 'label_idx'])
|
||||||
|
|
||||||
|
type_cast_op = deC.TypeCast(mstype.int32)
|
||||||
|
ds = ds.map(operations=type_cast_op, input_columns="src_tokens")
|
||||||
|
ds = ds.map(operations=type_cast_op, input_columns="src_tokens_length")
|
||||||
|
ds = ds.map(operations=type_cast_op, input_columns="label_idx")
|
||||||
|
ds = ds.batch(batch_size=batch_size, drop_remainder=True)
|
||||||
|
|
||||||
|
return ds
|
||||||
|
|
||||||
|
def run_fasttext_infer():
|
||||||
|
"""run infer with FastText"""
|
||||||
|
dataset = load_infer_dataset(batch_size=config.batch_size, datafile=args.data_path)
|
||||||
|
fasttext_model = FastText(config.vocab_size, config.embedding_dims, config.num_class)
|
||||||
|
|
||||||
|
parameter_dict = load_checkpoint(args.model_ckpt)
|
||||||
|
load_param_into_net(fasttext_model, parameter_dict=parameter_dict)
|
||||||
|
|
||||||
|
ft_infer = FastTextInferCell(fasttext_model)
|
||||||
|
|
||||||
|
model = Model(ft_infer)
|
||||||
|
|
||||||
|
predictions = []
|
||||||
|
target_sens = []
|
||||||
|
|
||||||
|
for batch in dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
|
||||||
|
target_sens.append(batch['label_idx'])
|
||||||
|
src_tokens = Tensor(batch['src_tokens'], mstype.int32)
|
||||||
|
src_tokens_length = Tensor(batch['src_tokens_length'], mstype.int32)
|
||||||
|
predicted_idx = model.predict(src_tokens, src_tokens_length)
|
||||||
|
predictions.append(predicted_idx.asnumpy())
|
||||||
|
|
||||||
|
from sklearn.metrics import accuracy_score, classification_report
|
||||||
|
target_sens = np.array(target_sens).flatten()
|
||||||
|
predictions = np.array(predictions).flatten()
|
||||||
|
acc = accuracy_score(target_sens, predictions)
|
||||||
|
|
||||||
|
result_report = classification_report(target_sens, predictions, target_names=target_label1)
|
||||||
|
print("********Accuracy: ", acc)
|
||||||
|
print(result_report)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
run_fasttext_infer()
|
@ -0,0 +1,3 @@
|
|||||||
|
spacy
|
||||||
|
sklearn
|
||||||
|
en_core_web_lg
|
@ -0,0 +1,83 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
echo "=============================================================================================================="
|
||||||
|
echo "Please run the scipt as: "
|
||||||
|
echo "sh create_dataset.sh SOURCE_DATASET_PATH DATASET_NAME"
|
||||||
|
echo "for example: sh create_dataset.sh /home/workspace/ag_news_csv ag"
|
||||||
|
echo "DATASET_NAME including ag, dbpedia, and yelp_p"
|
||||||
|
echo "It is better to use absolute path."
|
||||||
|
echo "=============================================================================================================="
|
||||||
|
ulimit -u unlimited
|
||||||
|
get_real_path(){
|
||||||
|
if [ "${1:0:1}" == "/" ]; then
|
||||||
|
echo "$1"
|
||||||
|
else
|
||||||
|
echo "$(realpath -m $PWD/$1)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
SOURCE_DATASET_PATH=$(get_real_path $1)
|
||||||
|
DATASET_NAME=$2
|
||||||
|
|
||||||
|
export DEVICE_NUM=1
|
||||||
|
export DEVICE_ID=5
|
||||||
|
export RANK_ID=0
|
||||||
|
export RANK_SIZE=1
|
||||||
|
|
||||||
|
if [ $DATASET_NAME == 'ag' ];
|
||||||
|
then
|
||||||
|
echo "Begin to process ag news data"
|
||||||
|
if [ -d "ag" ];
|
||||||
|
then
|
||||||
|
rm -rf ./ag
|
||||||
|
fi
|
||||||
|
mkdir ./ag
|
||||||
|
cd ./ag || exit
|
||||||
|
echo "start data preprocess for device $DEVICE_ID"
|
||||||
|
python ../../src/dataset.py --train_file $SOURCE_DATASET_PATH/train.csv --test_file $SOURCE_DATASET_PATH/test.csv --class_num 4 --max_len 467 --bucket [64,128,467] --test_bucket [467]
|
||||||
|
cd ..
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $DATASET_NAME == 'dbpedia' ];
|
||||||
|
then
|
||||||
|
echo "Begin to process dbpedia data"
|
||||||
|
if [ -d "dbpedia" ];
|
||||||
|
then
|
||||||
|
rm -rf ./dbpedia
|
||||||
|
fi
|
||||||
|
mkdir ./dbpedia
|
||||||
|
cd ./dbpedia || exit
|
||||||
|
echo "start data preprocess for device $DEVICE_ID"
|
||||||
|
python ../../src/dataset.py --train_file $SOURCE_DATASET_PATH/train.csv --test_file $SOURCE_DATASET_PATH/test.csv --class_num 14 --max_len 3013 --bucket [128,512,3013] --test_bucket [1120]
|
||||||
|
cd ..
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $DATASET_NAME == 'yelp_p' ];
|
||||||
|
then
|
||||||
|
echo "Begin to process ag news data"
|
||||||
|
if [ -d "yelp_p" ];
|
||||||
|
then
|
||||||
|
rm -rf ./yelp_p
|
||||||
|
fi
|
||||||
|
mkdir ./yelp_p
|
||||||
|
cd ./yelp_p || exit
|
||||||
|
echo "start data preprocess for device $DEVICE_ID"
|
||||||
|
python ../../src/dataset.py --train_file $SOURCE_DATASET_PATH/train.csv --test_file $SOURCE_DATASET_PATH/test.csv --class_num 2 --max_len 2955 --bucket [64,128,256,512,2955] --test_bucket [2955]
|
||||||
|
cd ..
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -0,0 +1,66 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
echo "=============================================================================================================="
|
||||||
|
echo "Please run the scipt as: "
|
||||||
|
echo "sh run_distributed_train.sh DATASET_PATH RANK_TABLE_PATH"
|
||||||
|
echo "for example: sh run_distributed_train.sh /home/workspace/ag /home/workspace/rank_table_file.json"
|
||||||
|
echo "It is better to use absolute path."
|
||||||
|
echo "=============================================================================================================="
|
||||||
|
get_real_path(){
|
||||||
|
if [ "${1:0:1}" == "/" ]; then
|
||||||
|
echo "$1"
|
||||||
|
else
|
||||||
|
echo "$(realpath -m $PWD/$1)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
DATASET=$(get_real_path $1)
|
||||||
|
echo $DATASET
|
||||||
|
DATANAME=$(basename $DATASET)
|
||||||
|
RANK_TABLE_PATH=$(get_real_path $2)
|
||||||
|
echo $DATANAME
|
||||||
|
if [ ! -d $DATASET ]
|
||||||
|
then
|
||||||
|
echo "Error: DATA_PATH=$DATASET is not a file"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
current_exec_path=$(pwd)
|
||||||
|
echo ${current_exec_path}
|
||||||
|
|
||||||
|
export RANK_TABLE_FILE=$RANK_TABLE_PATH
|
||||||
|
|
||||||
|
|
||||||
|
echo $RANK_TABLE_FILE
|
||||||
|
export RANK_SIZE=8
|
||||||
|
export DEVICE_NUM=8
|
||||||
|
|
||||||
|
|
||||||
|
for((i=0;i<=7;i++));
|
||||||
|
do
|
||||||
|
rm -rf ${current_exec_path}/device$i
|
||||||
|
mkdir ${current_exec_path}/device$i
|
||||||
|
cd ${current_exec_path}/device$i
|
||||||
|
cp ../../*.py ./
|
||||||
|
cp -r ../../src ./
|
||||||
|
cp -r ../*.sh ./
|
||||||
|
export RANK_ID=$i
|
||||||
|
export DEVICE_ID=$i
|
||||||
|
echo "start training for rank $i, device $DEVICE_ID"
|
||||||
|
python ../../train.py --data_path $DATASET --data_name $DATANAME > log_fasttext.log 2>&1 &
|
||||||
|
cd ${current_exec_path}
|
||||||
|
done
|
||||||
|
cd ${current_exec_path}
|
@ -0,0 +1,54 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
echo "=============================================================================================================="
|
||||||
|
echo "Please run the scipt as: "
|
||||||
|
echo "sh run_eval.sh DATASET_PATH DATASET_NAME MODEL_CKPT"
|
||||||
|
echo "for example: sh run_eval.sh /home/workspace/ag/test*.mindrecord ag device0/ckpt0/fasttext-5-118.ckpt"
|
||||||
|
echo "It is better to use absolute path."
|
||||||
|
echo "=============================================================================================================="
|
||||||
|
|
||||||
|
get_real_path(){
|
||||||
|
if [ "${1:0:1}" == "/" ]; then
|
||||||
|
echo "$1"
|
||||||
|
else
|
||||||
|
echo "$(realpath -m $PWD/$1)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
DATASET=$(get_real_path $1)
|
||||||
|
echo $DATASET
|
||||||
|
DATANAME=$2
|
||||||
|
MODEL_CKPT=$(get_real_path $3)
|
||||||
|
ulimit -u unlimited
|
||||||
|
export DEVICE_NUM=1
|
||||||
|
export DEVICE_ID=5
|
||||||
|
export RANK_ID=0
|
||||||
|
export RANK_SIZE=1
|
||||||
|
|
||||||
|
|
||||||
|
if [ -d "eval" ];
|
||||||
|
then
|
||||||
|
rm -rf ./eval
|
||||||
|
fi
|
||||||
|
mkdir ./eval
|
||||||
|
cp ../*.py ./eval
|
||||||
|
cp -r ../src ./eval
|
||||||
|
cp -r ../scripts/*.sh ./eval
|
||||||
|
cd ./eval || exit
|
||||||
|
echo "start training for device $DEVICE_ID"
|
||||||
|
env > env.log
|
||||||
|
python ../../eval.py --data_path $DATASET --data_name $DATANAME --model_ckpt $MODEL_CKPT> log_fasttext.log 2>&1 &
|
||||||
|
cd ..
|
@ -0,0 +1,55 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
echo "=============================================================================================================="
|
||||||
|
echo "Please run the scipt as: "
|
||||||
|
echo "sh run_standalone_train.sh DATASET_PATH"
|
||||||
|
echo "for example: sh run_standalone_train.sh /home/workspace/ag"
|
||||||
|
echo "It is better to use absolute path."
|
||||||
|
echo "=============================================================================================================="
|
||||||
|
|
||||||
|
get_real_path(){
|
||||||
|
if [ "${1:0:1}" == "/" ]; then
|
||||||
|
echo "$1"
|
||||||
|
else
|
||||||
|
echo "$(realpath -m $PWD/$1)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
DATASET=$(get_real_path $1)
|
||||||
|
echo $DATASET
|
||||||
|
DATANAME=$(basename $DATASET)
|
||||||
|
echo $DATANAME
|
||||||
|
|
||||||
|
ulimit -u unlimited
|
||||||
|
export DEVICE_NUM=1
|
||||||
|
export DEVICE_ID=0
|
||||||
|
export RANK_ID=0
|
||||||
|
export RANK_SIZE=1
|
||||||
|
|
||||||
|
|
||||||
|
if [ -d "train" ];
|
||||||
|
then
|
||||||
|
rm -rf ./train
|
||||||
|
fi
|
||||||
|
mkdir ./train
|
||||||
|
cp ../*.py ./train
|
||||||
|
cp -r ../src ./train
|
||||||
|
cp -r ../scripts/*.sh ./train
|
||||||
|
cd ./train || exit
|
||||||
|
echo "start training for device $DEVICE_ID"
|
||||||
|
env > env.log
|
||||||
|
python train.py --data_path $DATASET --data_name $DATANAME > log_fasttext.log 2>&1 &
|
||||||
|
cd ..
|
@ -0,0 +1,72 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#" :===========================================================================
|
||||||
|
"""
|
||||||
|
network config setting, will be used in train.py and eval.py
|
||||||
|
"""
|
||||||
|
from easydict import EasyDict as ed
|
||||||
|
|
||||||
|
config_yelpp = ed({
|
||||||
|
'vocab_size': 6414979,
|
||||||
|
'buckets': [64, 128, 256, 512, 2955],
|
||||||
|
'batch_size': 128,
|
||||||
|
'embedding_dims': 16,
|
||||||
|
'num_class': 2,
|
||||||
|
'epoch': 5,
|
||||||
|
'lr': 0.02,
|
||||||
|
'min_lr': 1e-6,
|
||||||
|
'decay_steps': 549,
|
||||||
|
'warmup_steps': 400000,
|
||||||
|
'poly_lr_scheduler_power': 0.5,
|
||||||
|
'epoch_count': 1,
|
||||||
|
'pretrain_ckpt_dir': None,
|
||||||
|
'save_ckpt_steps': 549,
|
||||||
|
'keep_ckpt_max': 10,
|
||||||
|
})
|
||||||
|
|
||||||
|
config_db = ed({
|
||||||
|
'vocab_size': 6596536,
|
||||||
|
'buckets': [128, 512, 3013],
|
||||||
|
'batch_size': 128,
|
||||||
|
'embedding_dims': 16,
|
||||||
|
'num_class': 14,
|
||||||
|
'epoch': 5,
|
||||||
|
'lr': 0.05,
|
||||||
|
'min_lr': 1e-6,
|
||||||
|
'decay_steps': 549,
|
||||||
|
'warmup_steps': 400000,
|
||||||
|
'poly_lr_scheduler_power': 0.5,
|
||||||
|
'epoch_count': 1,
|
||||||
|
'pretrain_ckpt_dir': None,
|
||||||
|
'save_ckpt_steps': 548,
|
||||||
|
'keep_ckpt_max': 10,
|
||||||
|
})
|
||||||
|
|
||||||
|
config_ag = ed({
|
||||||
|
'vocab_size': 1383812,
|
||||||
|
'buckets': [64, 128, 467],
|
||||||
|
'batch_size': 128,
|
||||||
|
'embedding_dims': 16,
|
||||||
|
'num_class': 4,
|
||||||
|
'epoch': 5,
|
||||||
|
'lr': 0.05,
|
||||||
|
'min_lr': 1e-6,
|
||||||
|
'decay_steps': 115,
|
||||||
|
'warmup_steps': 400000,
|
||||||
|
'poly_lr_scheduler_power': 0.5,
|
||||||
|
'epoch_count': 1,
|
||||||
|
'pretrain_ckpt_dir': None,
|
||||||
|
'save_ckpt_steps': 116,
|
||||||
|
'keep_ckpt_max': 10,
|
||||||
|
})
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,70 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""FastText model."""
|
||||||
|
from mindspore import nn
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.common.initializer import XavierUniform
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
|
||||||
|
class FastText(nn.Cell):
|
||||||
|
"""
|
||||||
|
FastText model
|
||||||
|
Args:
|
||||||
|
|
||||||
|
vocab_size: vocabulary size
|
||||||
|
embedding_dims: The size of each embedding vector
|
||||||
|
num_class: number of labels
|
||||||
|
"""
|
||||||
|
def __init__(self, vocab_size, embedding_dims, num_class):
|
||||||
|
super(FastText, self).__init__()
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.embeding_dims = embedding_dims
|
||||||
|
self.num_class = num_class
|
||||||
|
self.embeding_func = nn.Embedding(vocab_size=self.vocab_size,
|
||||||
|
embedding_size=self.embeding_dims,
|
||||||
|
padding_idx=0, embedding_table='Zeros')
|
||||||
|
self.fc = nn.Dense(self.embeding_dims, out_channels=self.num_class,
|
||||||
|
weight_init=XavierUniform(1)).to_float(mstype.float16)
|
||||||
|
self.reducesum = P.ReduceSum()
|
||||||
|
self.expand_dims = P.ExpandDims()
|
||||||
|
self.squeeze = P.Squeeze(axis=1)
|
||||||
|
self.cast = P.Cast()
|
||||||
|
self.tile = P.Tile()
|
||||||
|
self.realdiv = P.RealDiv()
|
||||||
|
self.fill = P.Fill()
|
||||||
|
self.log_softmax = nn.LogSoftmax(axis=1)
|
||||||
|
def construct(self, src_tokens, src_token_length):
|
||||||
|
"""
|
||||||
|
construct network
|
||||||
|
Args:
|
||||||
|
|
||||||
|
src_tokens: source sentences
|
||||||
|
src_token_length: source sentences length
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[Tensor], network outputs
|
||||||
|
"""
|
||||||
|
src_tokens = self.embeding_func(src_tokens)
|
||||||
|
embeding = self.reducesum(src_tokens, 1)
|
||||||
|
|
||||||
|
length_tiled = self.tile(src_token_length, (1, self.embeding_dims))
|
||||||
|
|
||||||
|
embeding = self.realdiv(embeding, length_tiled)
|
||||||
|
|
||||||
|
embeding = self.cast(embeding, mstype.float16)
|
||||||
|
classifer = self.fc(embeding)
|
||||||
|
classifer = self.cast(classifer, mstype.float32)
|
||||||
|
|
||||||
|
return classifer
|
@ -0,0 +1,142 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""FastText for train"""
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import functional as F
|
||||||
|
from mindspore.ops import composite as C
|
||||||
|
from mindspore.common.parameter import ParameterTuple
|
||||||
|
from mindspore.context import ParallelMode
|
||||||
|
from mindspore import nn
|
||||||
|
from mindspore.communication.management import get_group_size
|
||||||
|
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||||
|
from mindspore import context
|
||||||
|
from src.fasttext_model import FastText
|
||||||
|
|
||||||
|
|
||||||
|
GRADIENT_CLIP_TYPE = 1
|
||||||
|
GRADIENT_CLIP_VALUE = 1.0
|
||||||
|
|
||||||
|
clip_grad = C.MultitypeFuncGraph("clip_grad")
|
||||||
|
|
||||||
|
|
||||||
|
@clip_grad.register("Number", "Number", "Tensor")
|
||||||
|
def _clip_grad(clip_type, clip_value, grad):
|
||||||
|
"""
|
||||||
|
Clip gradients.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
|
||||||
|
clip_value (float): Specifies how much to clip.
|
||||||
|
grad (tuple[Tensor]): Gradients.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
tuple[Tensor], clipped gradients.
|
||||||
|
"""
|
||||||
|
if clip_type not in (0, 1):
|
||||||
|
return grad
|
||||||
|
dt = F.dtype(grad)
|
||||||
|
if clip_type == 0:
|
||||||
|
new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
|
||||||
|
F.cast(F.tuple_to_array((clip_value,)), dt))
|
||||||
|
else:
|
||||||
|
new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
|
||||||
|
return new_grad
|
||||||
|
|
||||||
|
class FastTextNetWithLoss(nn.Cell):
|
||||||
|
"""
|
||||||
|
Provide FastText training loss
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab_size: vocabulary size
|
||||||
|
embedding_dims: The size of each embedding vector
|
||||||
|
num_class: number of labels
|
||||||
|
"""
|
||||||
|
def __init__(self, vocab_size, embedding_dims, num_class):
|
||||||
|
super(FastTextNetWithLoss, self).__init__()
|
||||||
|
self.fasttext = FastText(vocab_size, embedding_dims, num_class)
|
||||||
|
self.loss_func = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||||
|
self.squeeze = P.Squeeze(axis=1)
|
||||||
|
self.print = P.Print()
|
||||||
|
|
||||||
|
def construct(self, src_tokens, src_tokens_lengths, label_idx):
|
||||||
|
"""
|
||||||
|
FastText network with loss.
|
||||||
|
"""
|
||||||
|
predict_score = self.fasttext(src_tokens, src_tokens_lengths)
|
||||||
|
label_idx = self.squeeze(label_idx)
|
||||||
|
predict_score = self.loss_func(predict_score, label_idx)
|
||||||
|
|
||||||
|
return predict_score
|
||||||
|
|
||||||
|
|
||||||
|
class FastTextTrainOneStepCell(nn.Cell):
|
||||||
|
"""
|
||||||
|
Encapsulation class of fasttext network training.
|
||||||
|
|
||||||
|
Append an optimizer to the training network after that the construct
|
||||||
|
function can be called to create the backward graph.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
network (Cell): The training network. Note that loss function should have been added.
|
||||||
|
optimizer (Optimizer): Optimizer for updating the weights.
|
||||||
|
sens (Number): The adjust parameter. Default: 1.0.
|
||||||
|
"""
|
||||||
|
def __init__(self, network, optimizer, sens=1.0):
|
||||||
|
super(FastTextTrainOneStepCell, self).__init__(auto_prefix=False)
|
||||||
|
self.network = network
|
||||||
|
self.weights = ParameterTuple(network.trainable_params())
|
||||||
|
self.optimizer = optimizer
|
||||||
|
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
|
||||||
|
self.sens = sens
|
||||||
|
self.reducer_flag = False
|
||||||
|
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||||
|
if self.parallel_mode not in ParallelMode.MODE_LIST:
|
||||||
|
raise ValueError("Parallel mode does not support: ", self.parallel_mode)
|
||||||
|
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
|
||||||
|
self.reducer_flag = True
|
||||||
|
self.grad_reducer = None
|
||||||
|
if self.reducer_flag:
|
||||||
|
mean = context.get_auto_parallel_context("gradients_mean")
|
||||||
|
degree = get_group_size()
|
||||||
|
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||||
|
|
||||||
|
self.hyper_map = C.HyperMap()
|
||||||
|
self.cast = P.Cast()
|
||||||
|
|
||||||
|
def set_sens(self, value):
|
||||||
|
self.sens = value
|
||||||
|
|
||||||
|
def construct(self,
|
||||||
|
src_token_text,
|
||||||
|
src_tokens_text_length,
|
||||||
|
label_idx_tag):
|
||||||
|
"""Defines the computation performed."""
|
||||||
|
weights = self.weights
|
||||||
|
loss = self.network(src_token_text,
|
||||||
|
src_tokens_text_length,
|
||||||
|
label_idx_tag)
|
||||||
|
grads = self.grad(self.network, weights)(src_token_text,
|
||||||
|
src_tokens_text_length,
|
||||||
|
label_idx_tag,
|
||||||
|
self.cast(F.tuple_to_array((self.sens,)),
|
||||||
|
mstype.float32))
|
||||||
|
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||||
|
if self.reducer_flag:
|
||||||
|
# apply grad reducer on grads
|
||||||
|
grads = self.grad_reducer(grads)
|
||||||
|
|
||||||
|
succ = self.optimizer(grads)
|
||||||
|
return F.depend(loss, succ)
|
@ -0,0 +1,62 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""FastText data loader"""
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
import mindspore.dataset.engine as de
|
||||||
|
import mindspore.dataset.transforms.c_transforms as deC
|
||||||
|
|
||||||
|
def load_dataset(dataset_path,
|
||||||
|
batch_size,
|
||||||
|
epoch_count=1,
|
||||||
|
rank_size=1,
|
||||||
|
rank_id=0,
|
||||||
|
bucket=None,
|
||||||
|
shuffle=True):
|
||||||
|
"""dataset loader"""
|
||||||
|
def batch_per_bucket(bucket_length, input_file):
|
||||||
|
input_file = input_file +'/train_dataset_bs_' + str(bucket_length) + '.mindrecord'
|
||||||
|
if not input_file:
|
||||||
|
raise FileNotFoundError("input file parameter must not be empty.")
|
||||||
|
|
||||||
|
ds = de.MindDataset(input_file,
|
||||||
|
columns_list=['src_tokens', 'src_tokens_length', 'label_idx'],
|
||||||
|
shuffle=shuffle,
|
||||||
|
num_shards=rank_size,
|
||||||
|
shard_id=rank_id,
|
||||||
|
num_parallel_workers=8)
|
||||||
|
ori_dataset_size = ds.get_dataset_size()
|
||||||
|
print(f"Dataset size: {ori_dataset_size}")
|
||||||
|
repeat_count = epoch_count
|
||||||
|
type_cast_op = deC.TypeCast(mstype.int32)
|
||||||
|
ds = ds.map(operations=type_cast_op, input_columns="src_tokens")
|
||||||
|
ds = ds.map(operations=type_cast_op, input_columns="src_tokens_length")
|
||||||
|
ds = ds.map(operations=type_cast_op, input_columns="label_idx")
|
||||||
|
|
||||||
|
ds = ds.rename(input_columns=['src_tokens', 'src_tokens_length', 'label_idx'],
|
||||||
|
output_columns=['src_token_text', 'src_tokens_text_length', 'label_idx_tag'])
|
||||||
|
ds = ds.batch(batch_size, drop_remainder=False)
|
||||||
|
ds = ds.repeat(repeat_count)
|
||||||
|
return ds
|
||||||
|
for i, _ in enumerate(bucket):
|
||||||
|
bucket_len = bucket[i]
|
||||||
|
ds_per = batch_per_bucket(bucket_len, dataset_path)
|
||||||
|
if i == 0:
|
||||||
|
ds = ds_per
|
||||||
|
else:
|
||||||
|
ds = ds + ds_per
|
||||||
|
ds = ds.shuffle(ds.get_dataset_size())
|
||||||
|
ds.channel_name = 'fasttext'
|
||||||
|
|
||||||
|
return ds
|
@ -0,0 +1,54 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Learning rate utilities."""
|
||||||
|
from math import ceil
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def polynomial_decay_scheduler(lr, min_lr, decay_steps, total_update_num, warmup_steps=1000, power=1.0):
|
||||||
|
"""
|
||||||
|
Implements of polynomial decay learning rate scheduler which cycles by default.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lr (float): Initial learning rate.
|
||||||
|
warmup_steps (int): Warmup steps.
|
||||||
|
decay_steps (int): Decay steps.
|
||||||
|
total_update_num (int): Total update steps.
|
||||||
|
min_lr (float): Min learning.
|
||||||
|
power (float): Power factor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray, learning rate of each step.
|
||||||
|
"""
|
||||||
|
lrs = np.zeros(shape=total_update_num, dtype=np.float32)
|
||||||
|
|
||||||
|
if decay_steps <= 0:
|
||||||
|
raise ValueError("`decay_steps` must larger than 1.")
|
||||||
|
|
||||||
|
_start_step = 0
|
||||||
|
if 0 < warmup_steps < total_update_num:
|
||||||
|
warmup_end_lr = lr
|
||||||
|
warmup_init_lr = 0 if warmup_steps > 0 else warmup_end_lr
|
||||||
|
lrs[:warmup_steps] = np.linspace(warmup_init_lr, warmup_end_lr, warmup_steps)
|
||||||
|
_start_step = warmup_steps
|
||||||
|
|
||||||
|
decay_steps = decay_steps
|
||||||
|
for step in range(_start_step, total_update_num):
|
||||||
|
_step = step - _start_step
|
||||||
|
ratio = ceil(_step / decay_steps)
|
||||||
|
ratio = 1 if ratio < 1 else ratio
|
||||||
|
_decay_steps = decay_steps * ratio
|
||||||
|
lrs[step] = (lr - min_lr) * pow(1 - _step / _decay_steps, power) + min_lr
|
||||||
|
|
||||||
|
return lrs
|
@ -0,0 +1,202 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""FastText for train"""
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import argparse
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore.nn.optim import Adam
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
from mindspore.train.model import Model
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
from mindspore.common.tensor import Tensor
|
||||||
|
from mindspore.context import ParallelMode
|
||||||
|
from mindspore.train.callback import Callback, TimeMonitor
|
||||||
|
from mindspore.communication import management as MultiAscend
|
||||||
|
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
|
||||||
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
from src.load_dataset import load_dataset
|
||||||
|
from src.lr_schedule import polynomial_decay_scheduler
|
||||||
|
from src.fasttext_train import FastTextTrainOneStepCell, FastTextNetWithLoss
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--data_path', type=str, required=True, help='FastText input data file path.')
|
||||||
|
parser.add_argument('--data_name', type=str, required=True, default='ag', help='dataset name. eg. ag, dbpedia')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.data_name == "ag":
|
||||||
|
from src.config import config_ag as config
|
||||||
|
elif args.data_name == 'dbpedia':
|
||||||
|
from src.config import config_db as config
|
||||||
|
elif args.data_name == 'yelp_p':
|
||||||
|
from src.config import config_yelpp as config
|
||||||
|
|
||||||
|
def get_ms_timestamp():
|
||||||
|
t = time.time()
|
||||||
|
return int(round(t * 1000))
|
||||||
|
set_seed(5)
|
||||||
|
time_stamp_init = False
|
||||||
|
time_stamp_first = 0
|
||||||
|
rank_id = os.getenv('DEVICE_ID')
|
||||||
|
context.set_context(
|
||||||
|
mode=context.GRAPH_MODE,
|
||||||
|
save_graphs=False,
|
||||||
|
device_target="Ascend")
|
||||||
|
|
||||||
|
class LossCallBack(Callback):
|
||||||
|
"""
|
||||||
|
Monitor the loss in training.
|
||||||
|
|
||||||
|
If the loss is NAN or INF terminating training.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
If per_print_times is 0 do not print loss.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
per_print_times (int): Print loss every times. Default: 1.
|
||||||
|
"""
|
||||||
|
def __init__(self, per_print_times=1, rank_ids=0):
|
||||||
|
super(LossCallBack, self).__init__()
|
||||||
|
if not isinstance(per_print_times, int) or per_print_times < 0:
|
||||||
|
raise ValueError("print_step must be int and >= 0.")
|
||||||
|
self._per_print_times = per_print_times
|
||||||
|
self.rank_id = rank_ids
|
||||||
|
global time_stamp_init, time_stamp_first
|
||||||
|
if not time_stamp_init:
|
||||||
|
time_stamp_first = get_ms_timestamp()
|
||||||
|
time_stamp_init = True
|
||||||
|
|
||||||
|
def step_end(self, run_context):
|
||||||
|
"""Monitor the loss in training."""
|
||||||
|
global time_stamp_first
|
||||||
|
time_stamp_current = get_ms_timestamp()
|
||||||
|
cb_params = run_context.original_args()
|
||||||
|
print("time: {}, epoch: {}, step: {}, outputs are {}".format(time_stamp_current - time_stamp_first,
|
||||||
|
cb_params.cur_epoch_num,
|
||||||
|
cb_params.cur_step_num,
|
||||||
|
str(cb_params.net_outputs)))
|
||||||
|
with open("./loss_{}.log".format(self.rank_id), "a+") as f:
|
||||||
|
f.write("time: {}, epoch: {}, step: {}, loss: {}".format(
|
||||||
|
time_stamp_current - time_stamp_first,
|
||||||
|
cb_params.cur_epoch_num,
|
||||||
|
cb_params.cur_step_num,
|
||||||
|
str(cb_params.net_outputs.asnumpy())))
|
||||||
|
f.write('\n')
|
||||||
|
|
||||||
|
|
||||||
|
def _build_training_pipeline(pre_dataset):
|
||||||
|
"""
|
||||||
|
Build training pipeline
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pre_dataset: preprocessed dataset
|
||||||
|
"""
|
||||||
|
net_with_loss = FastTextNetWithLoss(config.vocab_size, config.embedding_dims, config.num_class)
|
||||||
|
net_with_loss.init_parameters_data()
|
||||||
|
if config.pretrain_ckpt_dir:
|
||||||
|
parameter_dict = load_checkpoint(config.pretrain_ckpt_dir)
|
||||||
|
load_param_into_net(net_with_loss, parameter_dict)
|
||||||
|
if pre_dataset is None:
|
||||||
|
raise ValueError("pre-process dataset must be provided")
|
||||||
|
|
||||||
|
#get learning rate
|
||||||
|
update_steps = config.epoch * pre_dataset.get_dataset_size()
|
||||||
|
decay_steps = pre_dataset.get_dataset_size()
|
||||||
|
rank_size = os.getenv("RANK_SIZE")
|
||||||
|
if isinstance(rank_size, int):
|
||||||
|
raise ValueError("RANK_SIZE must be integer")
|
||||||
|
if rank_size is not None and int(rank_size) > 1:
|
||||||
|
base_lr = config.lr
|
||||||
|
else:
|
||||||
|
base_lr = config.lr / 10
|
||||||
|
print("+++++++++++Total update steps ", update_steps)
|
||||||
|
lr = Tensor(polynomial_decay_scheduler(lr=base_lr,
|
||||||
|
min_lr=config.min_lr,
|
||||||
|
decay_steps=decay_steps,
|
||||||
|
total_update_num=update_steps,
|
||||||
|
warmup_steps=config.warmup_steps,
|
||||||
|
power=config.poly_lr_scheduler_power), dtype=mstype.float32)
|
||||||
|
optimizer = Adam(net_with_loss.trainable_params(), lr, beta1=0.9, beta2=0.999)
|
||||||
|
|
||||||
|
net_with_grads = FastTextTrainOneStepCell(net_with_loss, optimizer=optimizer)
|
||||||
|
net_with_grads.set_train(True)
|
||||||
|
model = Model(net_with_grads)
|
||||||
|
loss_monitor = LossCallBack(rank_ids=rank_id)
|
||||||
|
dataset_size = pre_dataset.get_dataset_size()
|
||||||
|
time_monitor = TimeMonitor(data_size=dataset_size)
|
||||||
|
ckpt_config = CheckpointConfig(save_checkpoint_steps=decay_steps,
|
||||||
|
keep_checkpoint_max=config.keep_ckpt_max)
|
||||||
|
callbacks = [time_monitor, loss_monitor]
|
||||||
|
if rank_size is None or int(rank_size) == 1:
|
||||||
|
ckpt_callback = ModelCheckpoint(prefix='fasttext',
|
||||||
|
directory=os.path.join('./', 'ckpe_{}'.format(os.getenv("DEVICE_ID"))),
|
||||||
|
config=ckpt_config)
|
||||||
|
callbacks.append(ckpt_callback)
|
||||||
|
if rank_size is not None and int(rank_size) > 1 and MultiAscend.get_rank() % 8 == 0:
|
||||||
|
ckpt_callback = ModelCheckpoint(prefix='fasttext',
|
||||||
|
directory=os.path.join('./', 'ckpe_{}'.format(os.getenv("DEVICE_ID"))),
|
||||||
|
config=ckpt_config)
|
||||||
|
callbacks.append(ckpt_callback)
|
||||||
|
print("Prepare to Training....")
|
||||||
|
epoch_size = pre_dataset.get_repeat_count()
|
||||||
|
print("Epoch size ", epoch_size)
|
||||||
|
if os.getenv("RANK_SIZE") is not None and int(os.getenv("RANK_SIZE")) > 1:
|
||||||
|
print(f" | Rank {MultiAscend.get_rank()} Call model train.")
|
||||||
|
model.train(epoch=config.epoch, train_dataset=pre_dataset, callbacks=callbacks, dataset_sink_mode=False)
|
||||||
|
|
||||||
|
|
||||||
|
def train_single(input_file_path):
|
||||||
|
"""
|
||||||
|
Train model on single device
|
||||||
|
Args:
|
||||||
|
input_file_path: preprocessed dataset path
|
||||||
|
"""
|
||||||
|
print("Staring training on single device.")
|
||||||
|
preprocessed_data = load_dataset(dataset_path=input_file_path,
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
epoch_count=config.epoch_count,
|
||||||
|
bucket=config.buckets)
|
||||||
|
_build_training_pipeline(preprocessed_data)
|
||||||
|
|
||||||
|
|
||||||
|
def set_parallel_env():
|
||||||
|
context.reset_auto_parallel_context()
|
||||||
|
MultiAscend.init()
|
||||||
|
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||||
|
device_num=MultiAscend.get_group_size(),
|
||||||
|
gradients_mean=True)
|
||||||
|
def train_paralle(input_file_path):
|
||||||
|
"""
|
||||||
|
Train model on multi device
|
||||||
|
Args:
|
||||||
|
input_file_path: preprocessed dataset path
|
||||||
|
"""
|
||||||
|
set_parallel_env()
|
||||||
|
print("Starting traning on mutiple devices. |~ _ ~| |~ _ ~| |~ _ ~| |~ _ ~|")
|
||||||
|
preprocessed_data = load_dataset(dataset_path=input_file_path,
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
epoch_count=config.epoch_count,
|
||||||
|
rank_size=MultiAscend.get_group_size(),
|
||||||
|
rank_id=MultiAscend.get_rank(),
|
||||||
|
bucket=config.buckets,
|
||||||
|
shuffle=False)
|
||||||
|
_build_training_pipeline(preprocessed_data)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
_rank_size = os.getenv("RANK_SIZE")
|
||||||
|
if _rank_size is not None and int(_rank_size) > 1:
|
||||||
|
train_paralle(args.data_path)
|
||||||
|
else:
|
||||||
|
train_single(args.data_path)
|
Loading…
Reference in new issue