!10290 add fasttext model to model zoo
From: @zhaojichen Reviewed-by: Signed-off-by:pull/10290/MERGE
commit
1150ae3376
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,118 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""FastText for Evaluation"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.ops.operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.dataset.transforms.c_transforms as deC
|
||||
from mindspore import context
|
||||
from src.fasttext_model import FastText
|
||||
parser = argparse.ArgumentParser(description='fasttext')
|
||||
parser.add_argument('--data_path', type=str, help='infer dataset path..')
|
||||
parser.add_argument('--data_name', type=str, required=True, default='ag',
|
||||
help='dataset name. eg. ag, dbpedia')
|
||||
parser.add_argument("--model_ckpt", type=str, required=True,
|
||||
help="existed checkpoint address.")
|
||||
args = parser.parse_args()
|
||||
if args.data_name == "ag":
|
||||
from src.config import config_ag as config
|
||||
target_label1 = ['0', '1', '2', '3']
|
||||
elif args.data_name == 'dbpedia':
|
||||
from src.config import config_db as config
|
||||
target_label1 = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13']
|
||||
elif args.data_name == 'yelp_p':
|
||||
from src.config import config_yelpp as config
|
||||
target_label1 = ['0', '1']
|
||||
context.set_context(
|
||||
mode=context.GRAPH_MODE,
|
||||
save_graphs=False,
|
||||
device_target="Ascend")
|
||||
|
||||
class FastTextInferCell(nn.Cell):
|
||||
"""
|
||||
Encapsulation class of FastText network infer.
|
||||
|
||||
Args:
|
||||
network (nn.Cell): FastText model.
|
||||
|
||||
Returns:
|
||||
Tuple[Tensor, Tensor], predicted_ids
|
||||
"""
|
||||
def __init__(self, network):
|
||||
super(FastTextInferCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.argmax = P.ArgMaxWithValue(axis=1, keep_dims=True)
|
||||
self.log_softmax = nn.LogSoftmax(axis=1)
|
||||
|
||||
def construct(self, src_tokens, src_tokens_lengths):
|
||||
"""construct fasttext infer cell"""
|
||||
prediction = self.network(src_tokens, src_tokens_lengths)
|
||||
predicted_idx = self.log_softmax(prediction)
|
||||
predicted_idx, _ = self.argmax(predicted_idx)
|
||||
|
||||
return predicted_idx
|
||||
|
||||
def load_infer_dataset(batch_size, datafile):
|
||||
"""data loader for infer"""
|
||||
ds = de.MindDataset(datafile, columns_list=['src_tokens', 'src_tokens_length', 'label_idx'])
|
||||
|
||||
type_cast_op = deC.TypeCast(mstype.int32)
|
||||
ds = ds.map(operations=type_cast_op, input_columns="src_tokens")
|
||||
ds = ds.map(operations=type_cast_op, input_columns="src_tokens_length")
|
||||
ds = ds.map(operations=type_cast_op, input_columns="label_idx")
|
||||
ds = ds.batch(batch_size=batch_size, drop_remainder=True)
|
||||
|
||||
return ds
|
||||
|
||||
def run_fasttext_infer():
|
||||
"""run infer with FastText"""
|
||||
dataset = load_infer_dataset(batch_size=config.batch_size, datafile=args.data_path)
|
||||
fasttext_model = FastText(config.vocab_size, config.embedding_dims, config.num_class)
|
||||
|
||||
parameter_dict = load_checkpoint(args.model_ckpt)
|
||||
load_param_into_net(fasttext_model, parameter_dict=parameter_dict)
|
||||
|
||||
ft_infer = FastTextInferCell(fasttext_model)
|
||||
|
||||
model = Model(ft_infer)
|
||||
|
||||
predictions = []
|
||||
target_sens = []
|
||||
|
||||
for batch in dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
|
||||
target_sens.append(batch['label_idx'])
|
||||
src_tokens = Tensor(batch['src_tokens'], mstype.int32)
|
||||
src_tokens_length = Tensor(batch['src_tokens_length'], mstype.int32)
|
||||
predicted_idx = model.predict(src_tokens, src_tokens_length)
|
||||
predictions.append(predicted_idx.asnumpy())
|
||||
|
||||
from sklearn.metrics import accuracy_score, classification_report
|
||||
target_sens = np.array(target_sens).flatten()
|
||||
predictions = np.array(predictions).flatten()
|
||||
acc = accuracy_score(target_sens, predictions)
|
||||
|
||||
result_report = classification_report(target_sens, predictions, target_names=target_label1)
|
||||
print("********Accuracy: ", acc)
|
||||
print(result_report)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_fasttext_infer()
|
@ -0,0 +1,3 @@
|
||||
spacy
|
||||
sklearn
|
||||
en_core_web_lg
|
@ -0,0 +1,83 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the scipt as: "
|
||||
echo "sh create_dataset.sh SOURCE_DATASET_PATH DATASET_NAME"
|
||||
echo "for example: sh create_dataset.sh /home/workspace/ag_news_csv ag"
|
||||
echo "DATASET_NAME including ag, dbpedia, and yelp_p"
|
||||
echo "It is better to use absolute path."
|
||||
echo "=============================================================================================================="
|
||||
ulimit -u unlimited
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
SOURCE_DATASET_PATH=$(get_real_path $1)
|
||||
DATASET_NAME=$2
|
||||
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=5
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
|
||||
if [ $DATASET_NAME == 'ag' ];
|
||||
then
|
||||
echo "Begin to process ag news data"
|
||||
if [ -d "ag" ];
|
||||
then
|
||||
rm -rf ./ag
|
||||
fi
|
||||
mkdir ./ag
|
||||
cd ./ag || exit
|
||||
echo "start data preprocess for device $DEVICE_ID"
|
||||
python ../../src/dataset.py --train_file $SOURCE_DATASET_PATH/train.csv --test_file $SOURCE_DATASET_PATH/test.csv --class_num 4 --max_len 467 --bucket [64,128,467] --test_bucket [467]
|
||||
cd ..
|
||||
fi
|
||||
|
||||
if [ $DATASET_NAME == 'dbpedia' ];
|
||||
then
|
||||
echo "Begin to process dbpedia data"
|
||||
if [ -d "dbpedia" ];
|
||||
then
|
||||
rm -rf ./dbpedia
|
||||
fi
|
||||
mkdir ./dbpedia
|
||||
cd ./dbpedia || exit
|
||||
echo "start data preprocess for device $DEVICE_ID"
|
||||
python ../../src/dataset.py --train_file $SOURCE_DATASET_PATH/train.csv --test_file $SOURCE_DATASET_PATH/test.csv --class_num 14 --max_len 3013 --bucket [128,512,3013] --test_bucket [1120]
|
||||
cd ..
|
||||
fi
|
||||
|
||||
if [ $DATASET_NAME == 'yelp_p' ];
|
||||
then
|
||||
echo "Begin to process ag news data"
|
||||
if [ -d "yelp_p" ];
|
||||
then
|
||||
rm -rf ./yelp_p
|
||||
fi
|
||||
mkdir ./yelp_p
|
||||
cd ./yelp_p || exit
|
||||
echo "start data preprocess for device $DEVICE_ID"
|
||||
python ../../src/dataset.py --train_file $SOURCE_DATASET_PATH/train.csv --test_file $SOURCE_DATASET_PATH/test.csv --class_num 2 --max_len 2955 --bucket [64,128,256,512,2955] --test_bucket [2955]
|
||||
cd ..
|
||||
fi
|
||||
|
||||
|
||||
|
||||
|
@ -0,0 +1,66 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the scipt as: "
|
||||
echo "sh run_distributed_train.sh DATASET_PATH RANK_TABLE_PATH"
|
||||
echo "for example: sh run_distributed_train.sh /home/workspace/ag /home/workspace/rank_table_file.json"
|
||||
echo "It is better to use absolute path."
|
||||
echo "=============================================================================================================="
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
DATASET=$(get_real_path $1)
|
||||
echo $DATASET
|
||||
DATANAME=$(basename $DATASET)
|
||||
RANK_TABLE_PATH=$(get_real_path $2)
|
||||
echo $DATANAME
|
||||
if [ ! -d $DATASET ]
|
||||
then
|
||||
echo "Error: DATA_PATH=$DATASET is not a file"
|
||||
exit 1
|
||||
fi
|
||||
current_exec_path=$(pwd)
|
||||
echo ${current_exec_path}
|
||||
|
||||
export RANK_TABLE_FILE=$RANK_TABLE_PATH
|
||||
|
||||
|
||||
echo $RANK_TABLE_FILE
|
||||
export RANK_SIZE=8
|
||||
export DEVICE_NUM=8
|
||||
|
||||
|
||||
for((i=0;i<=7;i++));
|
||||
do
|
||||
rm -rf ${current_exec_path}/device$i
|
||||
mkdir ${current_exec_path}/device$i
|
||||
cd ${current_exec_path}/device$i
|
||||
cp ../../*.py ./
|
||||
cp -r ../../src ./
|
||||
cp -r ../*.sh ./
|
||||
export RANK_ID=$i
|
||||
export DEVICE_ID=$i
|
||||
echo "start training for rank $i, device $DEVICE_ID"
|
||||
python ../../train.py --data_path $DATASET --data_name $DATANAME > log_fasttext.log 2>&1 &
|
||||
cd ${current_exec_path}
|
||||
done
|
||||
cd ${current_exec_path}
|
@ -0,0 +1,54 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the scipt as: "
|
||||
echo "sh run_eval.sh DATASET_PATH DATASET_NAME MODEL_CKPT"
|
||||
echo "for example: sh run_eval.sh /home/workspace/ag/test*.mindrecord ag device0/ckpt0/fasttext-5-118.ckpt"
|
||||
echo "It is better to use absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
DATASET=$(get_real_path $1)
|
||||
echo $DATASET
|
||||
DATANAME=$2
|
||||
MODEL_CKPT=$(get_real_path $3)
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=5
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
|
||||
|
||||
if [ -d "eval" ];
|
||||
then
|
||||
rm -rf ./eval
|
||||
fi
|
||||
mkdir ./eval
|
||||
cp ../*.py ./eval
|
||||
cp -r ../src ./eval
|
||||
cp -r ../scripts/*.sh ./eval
|
||||
cd ./eval || exit
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env > env.log
|
||||
python ../../eval.py --data_path $DATASET --data_name $DATANAME --model_ckpt $MODEL_CKPT> log_fasttext.log 2>&1 &
|
||||
cd ..
|
@ -0,0 +1,55 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the scipt as: "
|
||||
echo "sh run_standalone_train.sh DATASET_PATH"
|
||||
echo "for example: sh run_standalone_train.sh /home/workspace/ag"
|
||||
echo "It is better to use absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
DATASET=$(get_real_path $1)
|
||||
echo $DATASET
|
||||
DATANAME=$(basename $DATASET)
|
||||
echo $DATANAME
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=0
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
|
||||
|
||||
if [ -d "train" ];
|
||||
then
|
||||
rm -rf ./train
|
||||
fi
|
||||
mkdir ./train
|
||||
cp ../*.py ./train
|
||||
cp -r ../src ./train
|
||||
cp -r ../scripts/*.sh ./train
|
||||
cd ./train || exit
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env > env.log
|
||||
python train.py --data_path $DATASET --data_name $DATANAME > log_fasttext.log 2>&1 &
|
||||
cd ..
|
@ -0,0 +1,72 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#" :===========================================================================
|
||||
"""
|
||||
network config setting, will be used in train.py and eval.py
|
||||
"""
|
||||
from easydict import EasyDict as ed
|
||||
|
||||
config_yelpp = ed({
|
||||
'vocab_size': 6414979,
|
||||
'buckets': [64, 128, 256, 512, 2955],
|
||||
'batch_size': 128,
|
||||
'embedding_dims': 16,
|
||||
'num_class': 2,
|
||||
'epoch': 5,
|
||||
'lr': 0.02,
|
||||
'min_lr': 1e-6,
|
||||
'decay_steps': 549,
|
||||
'warmup_steps': 400000,
|
||||
'poly_lr_scheduler_power': 0.5,
|
||||
'epoch_count': 1,
|
||||
'pretrain_ckpt_dir': None,
|
||||
'save_ckpt_steps': 549,
|
||||
'keep_ckpt_max': 10,
|
||||
})
|
||||
|
||||
config_db = ed({
|
||||
'vocab_size': 6596536,
|
||||
'buckets': [128, 512, 3013],
|
||||
'batch_size': 128,
|
||||
'embedding_dims': 16,
|
||||
'num_class': 14,
|
||||
'epoch': 5,
|
||||
'lr': 0.05,
|
||||
'min_lr': 1e-6,
|
||||
'decay_steps': 549,
|
||||
'warmup_steps': 400000,
|
||||
'poly_lr_scheduler_power': 0.5,
|
||||
'epoch_count': 1,
|
||||
'pretrain_ckpt_dir': None,
|
||||
'save_ckpt_steps': 548,
|
||||
'keep_ckpt_max': 10,
|
||||
})
|
||||
|
||||
config_ag = ed({
|
||||
'vocab_size': 1383812,
|
||||
'buckets': [64, 128, 467],
|
||||
'batch_size': 128,
|
||||
'embedding_dims': 16,
|
||||
'num_class': 4,
|
||||
'epoch': 5,
|
||||
'lr': 0.05,
|
||||
'min_lr': 1e-6,
|
||||
'decay_steps': 115,
|
||||
'warmup_steps': 400000,
|
||||
'poly_lr_scheduler_power': 0.5,
|
||||
'epoch_count': 1,
|
||||
'pretrain_ckpt_dir': None,
|
||||
'save_ckpt_steps': 116,
|
||||
'keep_ckpt_max': 10,
|
||||
})
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,70 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""FastText model."""
|
||||
from mindspore import nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.initializer import XavierUniform
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
class FastText(nn.Cell):
|
||||
"""
|
||||
FastText model
|
||||
Args:
|
||||
|
||||
vocab_size: vocabulary size
|
||||
embedding_dims: The size of each embedding vector
|
||||
num_class: number of labels
|
||||
"""
|
||||
def __init__(self, vocab_size, embedding_dims, num_class):
|
||||
super(FastText, self).__init__()
|
||||
self.vocab_size = vocab_size
|
||||
self.embeding_dims = embedding_dims
|
||||
self.num_class = num_class
|
||||
self.embeding_func = nn.Embedding(vocab_size=self.vocab_size,
|
||||
embedding_size=self.embeding_dims,
|
||||
padding_idx=0, embedding_table='Zeros')
|
||||
self.fc = nn.Dense(self.embeding_dims, out_channels=self.num_class,
|
||||
weight_init=XavierUniform(1)).to_float(mstype.float16)
|
||||
self.reducesum = P.ReduceSum()
|
||||
self.expand_dims = P.ExpandDims()
|
||||
self.squeeze = P.Squeeze(axis=1)
|
||||
self.cast = P.Cast()
|
||||
self.tile = P.Tile()
|
||||
self.realdiv = P.RealDiv()
|
||||
self.fill = P.Fill()
|
||||
self.log_softmax = nn.LogSoftmax(axis=1)
|
||||
def construct(self, src_tokens, src_token_length):
|
||||
"""
|
||||
construct network
|
||||
Args:
|
||||
|
||||
src_tokens: source sentences
|
||||
src_token_length: source sentences length
|
||||
|
||||
Returns:
|
||||
Tuple[Tensor], network outputs
|
||||
"""
|
||||
src_tokens = self.embeding_func(src_tokens)
|
||||
embeding = self.reducesum(src_tokens, 1)
|
||||
|
||||
length_tiled = self.tile(src_token_length, (1, self.embeding_dims))
|
||||
|
||||
embeding = self.realdiv(embeding, length_tiled)
|
||||
|
||||
embeding = self.cast(embeding, mstype.float16)
|
||||
classifer = self.fc(embeding)
|
||||
classifer = self.cast(classifer, mstype.float32)
|
||||
|
||||
return classifer
|
@ -0,0 +1,142 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""FastText for train"""
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common.parameter import ParameterTuple
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore import nn
|
||||
from mindspore.communication.management import get_group_size
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
from mindspore import context
|
||||
from src.fasttext_model import FastText
|
||||
|
||||
|
||||
GRADIENT_CLIP_TYPE = 1
|
||||
GRADIENT_CLIP_VALUE = 1.0
|
||||
|
||||
clip_grad = C.MultitypeFuncGraph("clip_grad")
|
||||
|
||||
|
||||
@clip_grad.register("Number", "Number", "Tensor")
|
||||
def _clip_grad(clip_type, clip_value, grad):
|
||||
"""
|
||||
Clip gradients.
|
||||
|
||||
Inputs:
|
||||
clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
|
||||
clip_value (float): Specifies how much to clip.
|
||||
grad (tuple[Tensor]): Gradients.
|
||||
|
||||
Outputs:
|
||||
tuple[Tensor], clipped gradients.
|
||||
"""
|
||||
if clip_type not in (0, 1):
|
||||
return grad
|
||||
dt = F.dtype(grad)
|
||||
if clip_type == 0:
|
||||
new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
|
||||
F.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
else:
|
||||
new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
return new_grad
|
||||
|
||||
class FastTextNetWithLoss(nn.Cell):
|
||||
"""
|
||||
Provide FastText training loss
|
||||
|
||||
Args:
|
||||
vocab_size: vocabulary size
|
||||
embedding_dims: The size of each embedding vector
|
||||
num_class: number of labels
|
||||
"""
|
||||
def __init__(self, vocab_size, embedding_dims, num_class):
|
||||
super(FastTextNetWithLoss, self).__init__()
|
||||
self.fasttext = FastText(vocab_size, embedding_dims, num_class)
|
||||
self.loss_func = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||
self.squeeze = P.Squeeze(axis=1)
|
||||
self.print = P.Print()
|
||||
|
||||
def construct(self, src_tokens, src_tokens_lengths, label_idx):
|
||||
"""
|
||||
FastText network with loss.
|
||||
"""
|
||||
predict_score = self.fasttext(src_tokens, src_tokens_lengths)
|
||||
label_idx = self.squeeze(label_idx)
|
||||
predict_score = self.loss_func(predict_score, label_idx)
|
||||
|
||||
return predict_score
|
||||
|
||||
|
||||
class FastTextTrainOneStepCell(nn.Cell):
|
||||
"""
|
||||
Encapsulation class of fasttext network training.
|
||||
|
||||
Append an optimizer to the training network after that the construct
|
||||
function can be called to create the backward graph.
|
||||
|
||||
Args:
|
||||
network (Cell): The training network. Note that loss function should have been added.
|
||||
optimizer (Optimizer): Optimizer for updating the weights.
|
||||
sens (Number): The adjust parameter. Default: 1.0.
|
||||
"""
|
||||
def __init__(self, network, optimizer, sens=1.0):
|
||||
super(FastTextTrainOneStepCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.weights = ParameterTuple(network.trainable_params())
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
|
||||
self.sens = sens
|
||||
self.reducer_flag = False
|
||||
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
if self.parallel_mode not in ParallelMode.MODE_LIST:
|
||||
raise ValueError("Parallel mode does not support: ", self.parallel_mode)
|
||||
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
|
||||
self.reducer_flag = True
|
||||
self.grad_reducer = None
|
||||
if self.reducer_flag:
|
||||
mean = context.get_auto_parallel_context("gradients_mean")
|
||||
degree = get_group_size()
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.cast = P.Cast()
|
||||
|
||||
def set_sens(self, value):
|
||||
self.sens = value
|
||||
|
||||
def construct(self,
|
||||
src_token_text,
|
||||
src_tokens_text_length,
|
||||
label_idx_tag):
|
||||
"""Defines the computation performed."""
|
||||
weights = self.weights
|
||||
loss = self.network(src_token_text,
|
||||
src_tokens_text_length,
|
||||
label_idx_tag)
|
||||
grads = self.grad(self.network, weights)(src_token_text,
|
||||
src_tokens_text_length,
|
||||
label_idx_tag,
|
||||
self.cast(F.tuple_to_array((self.sens,)),
|
||||
mstype.float32))
|
||||
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||
if self.reducer_flag:
|
||||
# apply grad reducer on grads
|
||||
grads = self.grad_reducer(grads)
|
||||
|
||||
succ = self.optimizer(grads)
|
||||
return F.depend(loss, succ)
|
@ -0,0 +1,62 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""FastText data loader"""
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.dataset.transforms.c_transforms as deC
|
||||
|
||||
def load_dataset(dataset_path,
|
||||
batch_size,
|
||||
epoch_count=1,
|
||||
rank_size=1,
|
||||
rank_id=0,
|
||||
bucket=None,
|
||||
shuffle=True):
|
||||
"""dataset loader"""
|
||||
def batch_per_bucket(bucket_length, input_file):
|
||||
input_file = input_file +'/train_dataset_bs_' + str(bucket_length) + '.mindrecord'
|
||||
if not input_file:
|
||||
raise FileNotFoundError("input file parameter must not be empty.")
|
||||
|
||||
ds = de.MindDataset(input_file,
|
||||
columns_list=['src_tokens', 'src_tokens_length', 'label_idx'],
|
||||
shuffle=shuffle,
|
||||
num_shards=rank_size,
|
||||
shard_id=rank_id,
|
||||
num_parallel_workers=8)
|
||||
ori_dataset_size = ds.get_dataset_size()
|
||||
print(f"Dataset size: {ori_dataset_size}")
|
||||
repeat_count = epoch_count
|
||||
type_cast_op = deC.TypeCast(mstype.int32)
|
||||
ds = ds.map(operations=type_cast_op, input_columns="src_tokens")
|
||||
ds = ds.map(operations=type_cast_op, input_columns="src_tokens_length")
|
||||
ds = ds.map(operations=type_cast_op, input_columns="label_idx")
|
||||
|
||||
ds = ds.rename(input_columns=['src_tokens', 'src_tokens_length', 'label_idx'],
|
||||
output_columns=['src_token_text', 'src_tokens_text_length', 'label_idx_tag'])
|
||||
ds = ds.batch(batch_size, drop_remainder=False)
|
||||
ds = ds.repeat(repeat_count)
|
||||
return ds
|
||||
for i, _ in enumerate(bucket):
|
||||
bucket_len = bucket[i]
|
||||
ds_per = batch_per_bucket(bucket_len, dataset_path)
|
||||
if i == 0:
|
||||
ds = ds_per
|
||||
else:
|
||||
ds = ds + ds_per
|
||||
ds = ds.shuffle(ds.get_dataset_size())
|
||||
ds.channel_name = 'fasttext'
|
||||
|
||||
return ds
|
@ -0,0 +1,54 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Learning rate utilities."""
|
||||
from math import ceil
|
||||
import numpy as np
|
||||
|
||||
def polynomial_decay_scheduler(lr, min_lr, decay_steps, total_update_num, warmup_steps=1000, power=1.0):
|
||||
"""
|
||||
Implements of polynomial decay learning rate scheduler which cycles by default.
|
||||
|
||||
Args:
|
||||
lr (float): Initial learning rate.
|
||||
warmup_steps (int): Warmup steps.
|
||||
decay_steps (int): Decay steps.
|
||||
total_update_num (int): Total update steps.
|
||||
min_lr (float): Min learning.
|
||||
power (float): Power factor.
|
||||
|
||||
Returns:
|
||||
np.ndarray, learning rate of each step.
|
||||
"""
|
||||
lrs = np.zeros(shape=total_update_num, dtype=np.float32)
|
||||
|
||||
if decay_steps <= 0:
|
||||
raise ValueError("`decay_steps` must larger than 1.")
|
||||
|
||||
_start_step = 0
|
||||
if 0 < warmup_steps < total_update_num:
|
||||
warmup_end_lr = lr
|
||||
warmup_init_lr = 0 if warmup_steps > 0 else warmup_end_lr
|
||||
lrs[:warmup_steps] = np.linspace(warmup_init_lr, warmup_end_lr, warmup_steps)
|
||||
_start_step = warmup_steps
|
||||
|
||||
decay_steps = decay_steps
|
||||
for step in range(_start_step, total_update_num):
|
||||
_step = step - _start_step
|
||||
ratio = ceil(_step / decay_steps)
|
||||
ratio = 1 if ratio < 1 else ratio
|
||||
_decay_steps = decay_steps * ratio
|
||||
lrs[step] = (lr - min_lr) * pow(1 - _step / _decay_steps, power) + min_lr
|
||||
|
||||
return lrs
|
@ -0,0 +1,202 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""FastText for train"""
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
from mindspore import context
|
||||
from mindspore.nn.optim import Adam
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.train.model import Model
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.callback import Callback, TimeMonitor
|
||||
from mindspore.communication import management as MultiAscend
|
||||
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from src.load_dataset import load_dataset
|
||||
from src.lr_schedule import polynomial_decay_scheduler
|
||||
from src.fasttext_train import FastTextTrainOneStepCell, FastTextNetWithLoss
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--data_path', type=str, required=True, help='FastText input data file path.')
|
||||
parser.add_argument('--data_name', type=str, required=True, default='ag', help='dataset name. eg. ag, dbpedia')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.data_name == "ag":
|
||||
from src.config import config_ag as config
|
||||
elif args.data_name == 'dbpedia':
|
||||
from src.config import config_db as config
|
||||
elif args.data_name == 'yelp_p':
|
||||
from src.config import config_yelpp as config
|
||||
|
||||
def get_ms_timestamp():
|
||||
t = time.time()
|
||||
return int(round(t * 1000))
|
||||
set_seed(5)
|
||||
time_stamp_init = False
|
||||
time_stamp_first = 0
|
||||
rank_id = os.getenv('DEVICE_ID')
|
||||
context.set_context(
|
||||
mode=context.GRAPH_MODE,
|
||||
save_graphs=False,
|
||||
device_target="Ascend")
|
||||
|
||||
class LossCallBack(Callback):
|
||||
"""
|
||||
Monitor the loss in training.
|
||||
|
||||
If the loss is NAN or INF terminating training.
|
||||
|
||||
Note:
|
||||
If per_print_times is 0 do not print loss.
|
||||
|
||||
Args:
|
||||
per_print_times (int): Print loss every times. Default: 1.
|
||||
"""
|
||||
def __init__(self, per_print_times=1, rank_ids=0):
|
||||
super(LossCallBack, self).__init__()
|
||||
if not isinstance(per_print_times, int) or per_print_times < 0:
|
||||
raise ValueError("print_step must be int and >= 0.")
|
||||
self._per_print_times = per_print_times
|
||||
self.rank_id = rank_ids
|
||||
global time_stamp_init, time_stamp_first
|
||||
if not time_stamp_init:
|
||||
time_stamp_first = get_ms_timestamp()
|
||||
time_stamp_init = True
|
||||
|
||||
def step_end(self, run_context):
|
||||
"""Monitor the loss in training."""
|
||||
global time_stamp_first
|
||||
time_stamp_current = get_ms_timestamp()
|
||||
cb_params = run_context.original_args()
|
||||
print("time: {}, epoch: {}, step: {}, outputs are {}".format(time_stamp_current - time_stamp_first,
|
||||
cb_params.cur_epoch_num,
|
||||
cb_params.cur_step_num,
|
||||
str(cb_params.net_outputs)))
|
||||
with open("./loss_{}.log".format(self.rank_id), "a+") as f:
|
||||
f.write("time: {}, epoch: {}, step: {}, loss: {}".format(
|
||||
time_stamp_current - time_stamp_first,
|
||||
cb_params.cur_epoch_num,
|
||||
cb_params.cur_step_num,
|
||||
str(cb_params.net_outputs.asnumpy())))
|
||||
f.write('\n')
|
||||
|
||||
|
||||
def _build_training_pipeline(pre_dataset):
|
||||
"""
|
||||
Build training pipeline
|
||||
|
||||
Args:
|
||||
pre_dataset: preprocessed dataset
|
||||
"""
|
||||
net_with_loss = FastTextNetWithLoss(config.vocab_size, config.embedding_dims, config.num_class)
|
||||
net_with_loss.init_parameters_data()
|
||||
if config.pretrain_ckpt_dir:
|
||||
parameter_dict = load_checkpoint(config.pretrain_ckpt_dir)
|
||||
load_param_into_net(net_with_loss, parameter_dict)
|
||||
if pre_dataset is None:
|
||||
raise ValueError("pre-process dataset must be provided")
|
||||
|
||||
#get learning rate
|
||||
update_steps = config.epoch * pre_dataset.get_dataset_size()
|
||||
decay_steps = pre_dataset.get_dataset_size()
|
||||
rank_size = os.getenv("RANK_SIZE")
|
||||
if isinstance(rank_size, int):
|
||||
raise ValueError("RANK_SIZE must be integer")
|
||||
if rank_size is not None and int(rank_size) > 1:
|
||||
base_lr = config.lr
|
||||
else:
|
||||
base_lr = config.lr / 10
|
||||
print("+++++++++++Total update steps ", update_steps)
|
||||
lr = Tensor(polynomial_decay_scheduler(lr=base_lr,
|
||||
min_lr=config.min_lr,
|
||||
decay_steps=decay_steps,
|
||||
total_update_num=update_steps,
|
||||
warmup_steps=config.warmup_steps,
|
||||
power=config.poly_lr_scheduler_power), dtype=mstype.float32)
|
||||
optimizer = Adam(net_with_loss.trainable_params(), lr, beta1=0.9, beta2=0.999)
|
||||
|
||||
net_with_grads = FastTextTrainOneStepCell(net_with_loss, optimizer=optimizer)
|
||||
net_with_grads.set_train(True)
|
||||
model = Model(net_with_grads)
|
||||
loss_monitor = LossCallBack(rank_ids=rank_id)
|
||||
dataset_size = pre_dataset.get_dataset_size()
|
||||
time_monitor = TimeMonitor(data_size=dataset_size)
|
||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=decay_steps,
|
||||
keep_checkpoint_max=config.keep_ckpt_max)
|
||||
callbacks = [time_monitor, loss_monitor]
|
||||
if rank_size is None or int(rank_size) == 1:
|
||||
ckpt_callback = ModelCheckpoint(prefix='fasttext',
|
||||
directory=os.path.join('./', 'ckpe_{}'.format(os.getenv("DEVICE_ID"))),
|
||||
config=ckpt_config)
|
||||
callbacks.append(ckpt_callback)
|
||||
if rank_size is not None and int(rank_size) > 1 and MultiAscend.get_rank() % 8 == 0:
|
||||
ckpt_callback = ModelCheckpoint(prefix='fasttext',
|
||||
directory=os.path.join('./', 'ckpe_{}'.format(os.getenv("DEVICE_ID"))),
|
||||
config=ckpt_config)
|
||||
callbacks.append(ckpt_callback)
|
||||
print("Prepare to Training....")
|
||||
epoch_size = pre_dataset.get_repeat_count()
|
||||
print("Epoch size ", epoch_size)
|
||||
if os.getenv("RANK_SIZE") is not None and int(os.getenv("RANK_SIZE")) > 1:
|
||||
print(f" | Rank {MultiAscend.get_rank()} Call model train.")
|
||||
model.train(epoch=config.epoch, train_dataset=pre_dataset, callbacks=callbacks, dataset_sink_mode=False)
|
||||
|
||||
|
||||
def train_single(input_file_path):
|
||||
"""
|
||||
Train model on single device
|
||||
Args:
|
||||
input_file_path: preprocessed dataset path
|
||||
"""
|
||||
print("Staring training on single device.")
|
||||
preprocessed_data = load_dataset(dataset_path=input_file_path,
|
||||
batch_size=config.batch_size,
|
||||
epoch_count=config.epoch_count,
|
||||
bucket=config.buckets)
|
||||
_build_training_pipeline(preprocessed_data)
|
||||
|
||||
|
||||
def set_parallel_env():
|
||||
context.reset_auto_parallel_context()
|
||||
MultiAscend.init()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
device_num=MultiAscend.get_group_size(),
|
||||
gradients_mean=True)
|
||||
def train_paralle(input_file_path):
|
||||
"""
|
||||
Train model on multi device
|
||||
Args:
|
||||
input_file_path: preprocessed dataset path
|
||||
"""
|
||||
set_parallel_env()
|
||||
print("Starting traning on mutiple devices. |~ _ ~| |~ _ ~| |~ _ ~| |~ _ ~|")
|
||||
preprocessed_data = load_dataset(dataset_path=input_file_path,
|
||||
batch_size=config.batch_size,
|
||||
epoch_count=config.epoch_count,
|
||||
rank_size=MultiAscend.get_group_size(),
|
||||
rank_id=MultiAscend.get_rank(),
|
||||
bucket=config.buckets,
|
||||
shuffle=False)
|
||||
_build_training_pipeline(preprocessed_data)
|
||||
|
||||
if __name__ == "__main__":
|
||||
_rank_size = os.getenv("RANK_SIZE")
|
||||
if _rank_size is not None and int(_rank_size) > 1:
|
||||
train_paralle(args.data_path)
|
||||
else:
|
||||
train_single(args.data_path)
|
Loading…
Reference in new issue