upload textcnn

pull/9235/head
xinyun Fan 5 years ago
parent dabb82ec7a
commit c3f3f6e67a

README.md
@@ -0,0 +1,177 @@
# Contents
- [TextCNN Description](#textcnn-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
    - [Evaluation Process](#evaluation-process)
- [Model Description](#model-description)
    - [Performance](#performance)
- [ModelZoo Homepage](#modelzoo-homepage)
# [TextCNN Description](#contents)
TextCNN is an algorithm that uses convolutional neural networks to classify text. It was proposed by Yoon Kim in the 2014 paper "Convolutional Neural Networks for Sentence Classification". It is widely used in text classification tasks such as sentiment analysis and has become a standard baseline for new text classification architectures. Its building blocks work independently of one another, which makes distributed configuration and parallel execution convenient. TextCNN is well suited to semantic analysis of short texts such as Weibo posts, news, e-commerce reviews, and video bullet screens.
[Paper](https://arxiv.org/abs/1408.5882): Kim Y. Convolutional neural networks for sentence classification[J]. arXiv preprint arXiv:1408.5882, 2014.
# [Model Architecture](#contents)
The basic network structure of TextCNN follows the paper "Convolutional Neural Networks for Sentence Classification". Take the sentence "I like this movie very much!" as an example. The sentence is first tokenized into 7 words, and each word is mapped to a 5-dimensional vector by the embedding layer. Convolution kernels of different sizes ([3,4,5]*5, i.e. spanning 3, 4, or 5 words across the full embedding width) are then applied to produce feature maps; two kernels of each size are used in this example. A max-pooling operation reduces each feature map to a single value, and the pooled results are concatenated into a one-dimensional feature vector. Finally, a softmax layer maps this vector to 2 categories, giving the positive/negative sentiment.
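The data flow of that example can be sketched in a few lines of NumPy (an illustrative toy with random weights, not the MindSpore implementation in `src/textcnn.py`; the sentence length, embedding size, and filter counts follow the example above):
```python
import numpy as np

np.random.seed(0)
sentence_len, emb_dim = 7, 5                         # "I like this movie very much !"
embedded = np.random.randn(sentence_len, emb_dim)    # embedding lookup result

pooled = []
for height in (3, 4, 5):                             # one group of filters per kernel height
    filters = np.random.randn(2, height, emb_dim)    # two filters of each size
    # valid convolution along the word axis: (sentence_len - height + 1) scores per filter
    feature_maps = np.array([
        [np.sum(embedded[i:i + height] * f) for i in range(sentence_len - height + 1)]
        for f in filters
    ])
    pooled.append(feature_maps.max(axis=1))          # max-over-time pooling -> 2 values

features = np.concatenate(pooled)                    # 3 sizes * 2 filters = 6-dim feature vector
logits = features @ np.random.randn(6, 2)            # fully connected layer -> 2 classes
probs = np.exp(logits) / np.exp(logits).sum()        # softmax over positive / negative
print(probs)
```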
# [Dataset](#contents)
Note that you can run the scripts with the dataset mentioned in the original paper or with other datasets widely used for this kind of network. In the following sections, we will introduce how to run the scripts using the dataset below.
Dataset used: [Movie Review Data](<http://www.cs.cornell.edu/people/pabo/movie-review-data/>)
- Dataset size: 1.18M; 5331 positive and 5331 negative processed sentences / snippets.
    - Train: 1.06M, 9596 sentences / snippets
    - Test: 0.12M, 1066 sentences / snippets
- Data format: text
- Please click [here](<http://www.cs.cornell.edu/people/pabo/movie-review-data/rt-polaritydata.tar.gz>) to download the data and convert the files to utf-8 (see the conversion sketch after this list). Then put them into the `data` directory.
- Note: data will be processed in src/dataset.py
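A minimal sketch of the conversion step, assuming the tarball was extracted next to the scripts and that the two raw files are named `rt-polarity.pos` and `rt-polarity.neg` (the raw files are not utf-8, so `latin-1` is used here as a permissive guess when reading):
```python
import codecs
import os

os.makedirs("data", exist_ok=True)
for name in ("rt-polarity.pos", "rt-polarity.neg"):
    # read the raw file with a permissive encoding, then rewrite it as utf-8 under ./data
    with codecs.open(os.path.join("rt-polaritydata", name), "r", "latin-1") as f:
        text = f.read()
    with codecs.open(os.path.join("data", name), "w", "utf-8") as f:
        f.write(text)
```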
# [Environment Requirements](#contents)
- Hardware (Ascend)
    - Prepare the hardware environment with an Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get access to the resources.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Quick Start](#contents)
After installing MindSpore via the official website, you can start training and evaluation as follows:
- running on Ascend
```bash
# run training example
python train.py > train.log 2>&1 &
OR
sh scripts/run_train.sh
# run evaluation example
python eval.py > eval.log 2>&1 &
OR
sh scripts/run_eval.sh ckpt_path
```
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```bash
├── model_zoo
├── README.md // descriptions about all the models
├── textcnn
├── README.md // descriptions about textcnn
├── scripts
│ ├── run_train.sh // shell script for training on Ascend
│ ├── run_eval.sh // shell script for evaluation on Ascend
├── src
│ ├── dataset.py // Processing dataset
│ ├── textcnn.py // textcnn architecture
│ ├── config.py // parameter configuration
├── train.py // training script
├── eval.py // evaluation script
```
## [Script Parameters](#contents)
Parameters for both training and evaluation can be set in config.py
- config for movie review dataset
```python
'pre_trained': False # whether to train starting from a pre-trained checkpoint
'num_classes': 2 # the number of classes in the dataset
'batch_size': 64 # training batch size
'epoch_size': 4 # total training epochs
'weight_decay': 3e-5 # weight decay value
'data_path': './data/' # path to the training and evaluation datasets
'device_target': 'Ascend' # device running the program
'device_id': 0 # device ID used to train or evaluate the dataset. Ignore it when you launch training through run_train.sh
'keep_checkpoint_max': 1 # only keep the last keep_checkpoint_max checkpoints
'checkpoint_path': './train_textcnn.ckpt' # path of the checkpoint file to save or load
'word_len': 51 # maximum sentence length (in words) after preprocessing
'vec_length': 40 # dimension of the word embedding vectors
```
For more configuration details, please refer to the script `config.py`.
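Since the configuration is a plain `EasyDict` (see `src/config.py` further below), individual fields can be inspected or overridden from Python before training. A minimal sketch, assuming you run it from the `textcnn` directory:
```python
from src.config import cfg

print(cfg.batch_size, cfg.word_len, cfg.vec_length)  # 64 51 40
cfg.epoch_size = 8                                    # e.g. override for a longer run
```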
## [Training Process](#contents)
- running on Ascend
```bash
python train.py > train.log 2>&1 &
OR
sh scripts/run_train.sh
```
The python command above will run in the background; you can view the results through the file `train.log`.
After training, you'll get some checkpoint files in `ckpt`. The loss values will look as follows:
```bash
# grep "loss is " train.log
epoch: 1 step 149, loss is 0.6194226145744324
epoch: 2 step 149, loss is 0.38729554414749146
...
```
The model checkpoint will be saved in the `ckpt` directory.
## [Evaluation Process](#contents)
- evaluation on movie review dataset when running on Ascend
Before running the command below, please check the checkpoint path used for evaluation, and set it to an absolute path, e.g., "/username/textcnn/ckpt/train_textcnn.ckpt".
```bash
python eval.py --checkpoint_path=ckpt_path > eval.log 2>&1 &
OR
sh scripts/run_eval.sh ckpt_path
```
The above python command will run in the background. You can view the results through the file "eval.log". The accuracy on the test dataset will be as follows:
```bash
# grep "accuracy: " eval.log
accuracy: {'acc': 0.7971428571428572}
```
# [Model Description](#contents)
## [Performance](#contents)
### TextCNN on Movie Review Dataset
| Parameters | Ascend |
| ------------------- | ----------------------------------------------------- |
| Model Version | TextCNN |
| Resource | Ascend 910; CPU 2.60GHz, 192 cores; memory 755 GB |
| Uploaded Date | 11/10/2020 (month/day/year) |
| MindSpore Version | 1.0.1 |
| Dataset | Movie Review Data |
| Training Parameters | epoch=4, steps=149, batch_size = 64 |
| Optimizer | Adam |
| Loss Function | Softmax Cross Entropy |
| outputs | probability |
| Loss | 0.1724 |
| Speed | 1pc: 12.069 ms/step |
| Total time | 1pc: 13s |
| Scripts | [textcnn script](https://gitee.com/xinyunfan/textcnn) |
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

eval.py
@@ -0,0 +1,60 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############test textcnn example on movie review#################
python eval.py
"""
import argparse
import mindspore.nn as nn
from mindspore.nn.metrics import Accuracy
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import cfg
from src.textcnn import TextCNN
from src.dataset import MovieReview
parser = argparse.ArgumentParser(description='TextCNN')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
args_opt = parser.parse_args()
if __name__ == '__main__':
device_target = cfg.device_target
context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
if device_target == "Ascend":
context.set_context(device_id=cfg.device_id)
instance = MovieReview(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9)
dataset = instance.create_test_dataset(batch_size=cfg.batch_size)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
net = TextCNN(vocab_len=instance.get_dict_len(), word_len=cfg.word_len,
num_classes=cfg.num_classes, vec_length=cfg.vec_length)
opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate=0.001,
weight_decay=cfg.weight_decay)
if args_opt.checkpoint_path is not None:
param_dict = load_checkpoint(args_opt.checkpoint_path)
print("load checkpoint from [{}].".format(args_opt.checkpoint_path))
else:
param_dict = load_checkpoint(cfg.checkpoint_path)
print("load checkpoint from [{}].".format(cfg.checkpoint_path))
load_param_into_net(net, param_dict)
net.set_train(False)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc': Accuracy()})
acc = model.eval(dataset)
print("accuracy: ", acc)

scripts/run_eval.sh
@@ -0,0 +1,17 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
python eval.py --checkpoint_path="$1" > eval.log 2>&1 &

scripts/run_train.sh
@@ -0,0 +1,17 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
python train.py > train.log 2>&1 &

src/config.py
@@ -0,0 +1,34 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as edict
cfg = edict({
'name': 'movie review',
'pre_trained': False,
'num_classes': 2,
'batch_size': 64,
'epoch_size': 4,
'weight_decay': 3e-5,
'data_path': './data/',
'device_target': 'Ascend',
'device_id': 7,
'keep_checkpoint_max': 1,
'checkpoint_path': './ckpt/train_textcnn-4_149.ckpt',
'word_len': 51,
'vec_length': 40
})

src/dataset.py
@@ -0,0 +1,212 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in train.py and eval.py
"""
import os
import math
import random
import codecs
from pathlib import Path
import numpy as np
import mindspore.dataset as ds
class Generator():
def __init__(self, input_list):
self.input_list = input_list
def __getitem__(self, item):
return np.array(self.input_list[item][0], dtype=np.int32), np.array(self.input_list[item][1], dtype=np.int32)
def __len__(self):
return len(self.input_list)
class MovieReview:
"""
preprocess MR dataset
"""
def __init__(self, root_dir, maxlen, split):
"""
input:
root_dir: the root directory path of the MR dataset
maxlen: set the max length of the sentence
split: set the ratio of training set to testing set
rank: the logic order of the worker
size: the worker num
"""
self.path = root_dir
self.feelMap = {
'neg': 0,
'pos': 1
}
self.files = []
self.doConvert = False
mypath = Path(self.path)
if not mypath.exists() or not mypath.is_dir():
print("please check the root_dir!")
raise ValueError
# walk through the root_dir
for root, _, filename in os.walk(self.path):
for each in filename:
self.files.append(os.path.join(root, each))
break
# check whether get two files
if len(self.files) != 2:
print("There are {} files in the root_dir".format(len(self.files)))
raise ValueError
# begin to read data
        self.word_num = 0
        self.minlen = float("inf")
        self.maxlen = float("-inf")
self.Pos = []
self.Neg = []
        for filename in self.files:
            # re-encode the raw file as utf-8 before parsing it
            with codecs.open(filename, 'r') as f:
                ff = f.read()
            with codecs.open(filename, 'w', 'utf-8') as file_object:
                file_object.write(ff)
            self.read_data(filename)
self.PosNeg = self.Pos + self.Neg
self.text2vec(maxlen=maxlen)
self.split_dataset(split=split)
def read_data(self, filePath):
"""
read text into memory
input:
filePath: the path where the data is stored in
"""
with open(filePath, 'r') as f:
for sentence in f.readlines():
sentence = sentence.replace('\n', '')\
.replace('"', '')\
.replace('\'', '')\
.replace('.', '')\
.replace(',', '')\
.replace('[', '')\
.replace(']', '')\
.replace('(', '')\
.replace(')', '')\
.replace(':', '')\
.replace('--', '')\
.replace('-', '')\
.replace('\\', '')\
.replace('0', '')\
.replace('1', '')\
.replace('2', '')\
.replace('3', '')\
.replace('4', '')\
.replace('5', '')\
.replace('6', '')\
.replace('7', '')\
.replace('8', '')\
.replace('9', '')\
.replace('`', '')\
.replace('=', '')\
.replace('$', '')\
.replace('/', '')\
.replace('*', '')\
.replace(';', '')\
.replace('<b>', '')\
.replace('%', '')
sentence = sentence.split(' ')
sentence = list(filter(lambda x: x, sentence))
if sentence:
self.word_num += len(sentence)
self.maxlen = self.maxlen if self.maxlen >= len(sentence) else len(sentence)
self.minlen = self.minlen if self.minlen <= len(sentence) else len(sentence)
if 'pos' in filePath:
self.Pos.append([sentence, self.feelMap['pos']])
else:
self.Neg.append([sentence, self.feelMap['neg']])
def text2vec(self, maxlen):
"""
convert the sentence into a vector in an int type
input:
maxlen: max length of the sentence
"""
# Vocab = {word : index}
self.Vocab = dict()
# self.Vocab['None']
for SentenceLabel in self.Pos+self.Neg:
vector = [0]*maxlen
for index, word in enumerate(SentenceLabel[0]):
if index >= maxlen:
break
if word not in self.Vocab.keys():
self.Vocab[word] = len(self.Vocab)
vector[index] = len(self.Vocab) - 1
else:
vector[index] = self.Vocab[word]
SentenceLabel[0] = vector
self.doConvert = True
def split_dataset(self, split):
"""
split the dataset into training set and test set
input:
split: the ratio of training set to test set
rank: logic order
size: device num
"""
trunk_pos_size = math.ceil((1-split)*len(self.Pos))
trunk_neg_size = math.ceil((1-split)*len(self.Neg))
trunk_num = int(1/(1-split))
pos_temp = list()
neg_temp = list()
for index in range(trunk_num):
pos_temp.append(self.Pos[index*trunk_pos_size:(index+1)*trunk_pos_size])
neg_temp.append(self.Neg[index*trunk_neg_size:(index+1)*trunk_neg_size])
self.test = pos_temp.pop(2)+neg_temp.pop(2)
self.train = [i for item in pos_temp+neg_temp for i in item]
random.shuffle(self.train)
# random.shuffle(self.test)
def get_dict_len(self):
"""
get number of different words in the whole dataset
"""
if self.doConvert:
return len(self.Vocab)
return -1
#else:
# print("Haven't finished Text2Vec")
# return -1
def create_train_dataset(self, epoch_size, batch_size):
dataset = ds.GeneratorDataset(source=Generator(input_list=self.train),
column_names=["data", "label"], shuffle=False)
dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
dataset = dataset.repeat(epoch_size)
return dataset
def create_test_dataset(self, batch_size):
dataset = ds.GeneratorDataset(source=Generator(input_list=self.test),
column_names=["data", "label"], shuffle=False)
dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
return dataset

src/textcnn.py
@@ -0,0 +1,155 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TextCNN"""
import numpy as np
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import Tensor
from mindspore.nn.cell import Cell
import mindspore.ops.functional as F
import mindspore
class SoftmaxCrossEntropyExpand(Cell):
r"""
Computes softmax cross entropy between logits and labels. Implemented by expanded formula.
This is a wrapper of several functions.
.. math::
\ell(x_i, t_i) = -log\left(\frac{\exp(x_{t_i})}{\sum_j \exp(x_j)}\right),
where :math:`x_i` is a 1D score Tensor, :math:`t_i` is the target class.
Note:
When argument sparse is set to True, the format of label is the index
range from :math:`0` to :math:`C - 1` instead of one-hot vectors.
Args:
sparse(bool): Specifies whether labels use sparse format or not. Default: False.
Inputs:
- **input_data** (Tensor) - Tensor of shape :math:`(x_1, x_2, ..., x_R)`.
- **label** (Tensor) - Tensor of shape :math:`(y_1, y_2, ..., y_S)`.
Outputs:
Tensor, a scalar tensor including the mean loss.
Examples:
>>> loss = nn.SoftmaxCrossEntropyExpand(sparse=True)
>>> input_data = Tensor(np.ones([64, 512]), dtype=mindspore.float32)
>>> label = Tensor(np.ones([64]), dtype=mindspore.int32)
>>> loss(input_data, label)
"""
def __init__(self, sparse=False):
super(SoftmaxCrossEntropyExpand, self).__init__()
self.exp = P.Exp()
self.reduce_sum = P.ReduceSum(keep_dims=True)
self.onehot = P.OneHot()
self.on_value = Tensor(1.0, mindspore.float32)
self.off_value = Tensor(0.0, mindspore.float32)
self.div = P.Div()
self.log = P.Log()
self.sum_cross_entropy = P.ReduceSum(keep_dims=False)
self.mul = P.Mul()
self.mul2 = P.Mul()
self.cast = P.Cast()
self.reduce_mean = P.ReduceMean(keep_dims=False)
self.sparse = sparse
self.reduce_max = P.ReduceMax(keep_dims=True)
self.sub = P.Sub()
def construct(self, logit, label):
"""
construct
"""
logit_max = self.reduce_max(logit, -1)
exp = self.exp(self.sub(logit, logit_max))
exp_sum = self.reduce_sum(exp, -1)
softmax_result = self.div(exp, exp_sum)
if self.sparse:
label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
softmax_result_log = self.log(softmax_result)
loss = self.sum_cross_entropy((self.mul(softmax_result_log, label)), -1)
loss = self.mul2(F.scalar_to_array(-1.0), loss)
loss = self.reduce_mean(loss, -1)
return loss
def _weight_variable(shape, factor=0.01):
init_value = np.random.randn(*shape).astype(np.float32) * factor
return Tensor(init_value)
def make_conv_layer(kernel_size):
weight_shape = (96, 1, *kernel_size)
weight = _weight_variable(weight_shape)
return nn.Conv2d(in_channels=1, out_channels=96, kernel_size=kernel_size, padding=1,
pad_mode="pad", weight_init=weight, has_bias=True)
class TextCNN(nn.Cell):
"""
TextCNN architecture
"""
def __init__(self, vocab_len, word_len, num_classes, vec_length):
super(TextCNN, self).__init__()
self.vec_length = vec_length
self.word_len = word_len
self.num_classes = num_classes
self.unsqueeze = P.ExpandDims()
self.embedding = nn.Embedding(vocab_len, self.vec_length, embedding_table='normal')
self.slice = P.Slice()
self.layer1 = self.make_layer(kernel_height=3)
self.layer2 = self.make_layer(kernel_height=4)
self.layer3 = self.make_layer(kernel_height=5)
self.concat = P.Concat(1)
self.fc = nn.Dense(96*3, self.num_classes)
self.drop = nn.Dropout(keep_prob=0.5)
self.print = P.Print()
        # note: despite the name, this applies a max reduction over the remaining
        # spatial dims, i.e. max-over-time pooling as in the original TextCNN paper
        self.reducemean = P.ReduceMax(keep_dims=False)
def make_layer(self, kernel_height):
return nn.SequentialCell(
[
make_conv_layer((kernel_height, self.vec_length)), nn.ReLU(),
nn.MaxPool2d(kernel_size=(self.word_len-kernel_height+1, 1)),
]
)
def construct(self, x):
"""
construct
"""
x = self.unsqueeze(x, 1)
x = self.embedding(x)
x1 = self.layer1(x)
x2 = self.layer2(x)
x3 = self.layer3(x)
x1 = self.reducemean(x1, (2, 3))
x2 = self.reducemean(x2, (2, 3))
x3 = self.reducemean(x3, (2, 3))
x = self.concat((x1, x2, x3))
x = self.drop(x)
x = self.fc(x)
return x

train.py
@@ -0,0 +1,77 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################train textcnn example on movie review########################
python train.py
"""
import argparse
import math
import mindspore.nn as nn
from mindspore.nn.metrics import Accuracy
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import cfg
from src.textcnn import TextCNN
from src.textcnn import SoftmaxCrossEntropyExpand
from src.dataset import MovieReview
parser = argparse.ArgumentParser(description='TextCNN')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--device_id', type=int, default=5, help='device id of GPU or Ascend.')
args_opt = parser.parse_args()
if __name__ == '__main__':
rank = 0
# set context
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
context.set_context(device_id=args_opt.device_id)
instance = MovieReview(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9)
dataset = instance.create_train_dataset(batch_size=cfg.batch_size, epoch_size=cfg.epoch_size)
batch_num = dataset.get_dataset_size()
    # per-step learning-rate schedule: a linear warm-up segment, a constant
    # 1e-3 plateau, then a decaying 1e-3 / (16 * (i + 1)) tail
    learning_rate = []
warm_up = [1e-3 / math.floor(cfg.epoch_size / 5) * (i + 1) for _ in range(batch_num) for i in
range(math.floor(cfg.epoch_size / 5))]
shrink = [1e-3 / (16 * (i + 1)) for _ in range(batch_num) for i in range(math.floor(cfg.epoch_size * 3 / 5))]
normal_run = [1e-3 for _ in range(batch_num) for i in
range(cfg.epoch_size - math.floor(cfg.epoch_size / 5) - math.floor(cfg.epoch_size * 2 / 5))]
learning_rate = learning_rate + warm_up + normal_run + shrink
net = TextCNN(vocab_len=instance.get_dict_len(), word_len=cfg.word_len,
num_classes=cfg.num_classes, vec_length=cfg.vec_length)
# Continue training if set pre_trained to be True
if cfg.pre_trained:
param_dict = load_checkpoint(cfg.checkpoint_path)
load_param_into_net(net, param_dict)
opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate=learning_rate, weight_decay=cfg.weight_decay)
loss = SoftmaxCrossEntropyExpand(sparse=True)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc': Accuracy()})
config_ck = CheckpointConfig(save_checkpoint_steps=int(cfg.epoch_size*batch_num/2),
keep_checkpoint_max=cfg.keep_checkpoint_max)
time_cb = TimeMonitor(data_size=batch_num)
ckpt_save_dir = "./ckpt_" + str(rank) + "/"
ckpoint_cb = ModelCheckpoint(prefix="train_textcnn", directory=ckpt_save_dir, config=config_ck)
loss_cb = LossMonitor()
model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
print("train success")