commit c3f3f6e67a (parent dabb82ec7a)
@@ -0,0 +1,60 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############test textcnn example on movie review#################
python eval.py
"""
import argparse

import mindspore.nn as nn
from mindspore.nn.metrics import Accuracy
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.config import cfg
from src.textcnn import TextCNN
from src.dataset import MovieReview

parser = argparse.ArgumentParser(description='TextCNN')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
args_opt = parser.parse_args()

if __name__ == '__main__':
    device_target = cfg.device_target
    context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
    if device_target == "Ascend":
        context.set_context(device_id=cfg.device_id)
    instance = MovieReview(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9)
    dataset = instance.create_test_dataset(batch_size=cfg.batch_size)
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
    net = TextCNN(vocab_len=instance.get_dict_len(), word_len=cfg.word_len,
                  num_classes=cfg.num_classes, vec_length=cfg.vec_length)
    opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate=0.001,
                  weight_decay=cfg.weight_decay)

    if args_opt.checkpoint_path is not None:
        param_dict = load_checkpoint(args_opt.checkpoint_path)
        print("load checkpoint from [{}].".format(args_opt.checkpoint_path))
    else:
        param_dict = load_checkpoint(cfg.checkpoint_path)
        print("load checkpoint from [{}].".format(cfg.checkpoint_path))

    load_param_into_net(net, param_dict)
    net.set_train(False)
    model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc': Accuracy()})

    acc = model.eval(dataset)
    print("accuracy: ", acc)
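    # Reference note: model.eval returns the metrics dict configured above, so the
    # value printed here has the form {'acc': <float>}. A typical invocation (the
    # checkpoint path below is only an example) is:
    #     python eval.py --checkpoint_path=./ckpt_0/train_textcnn-4_149.ckpt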
@@ -0,0 +1,17 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

python eval.py --checkpoint_path="$1" > eval.log 2>&1 &
@@ -0,0 +1,17 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

python train.py > train.log 2>&1 &
@ -0,0 +1,34 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in main.py
|
||||
"""
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
cfg = edict({
|
||||
'name': 'movie review',
|
||||
'pre_trained': False,
|
||||
'num_classes': 2,
|
||||
'batch_size': 64,
|
||||
'epoch_size': 4,
|
||||
'weight_decay': 3e-5,
|
||||
'data_path': './data/',
|
||||
'device_target': 'Ascend',
|
||||
'device_id': 7,
|
||||
'keep_checkpoint_max': 1,
|
||||
'checkpoint_path': './ckpt/train_textcnn-4_149.ckpt',
|
||||
'word_len': 51,
|
||||
'vec_length': 40
|
||||
})
|
||||
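# cfg is an EasyDict, so entries support attribute access as well as item access
# and can be overridden before use; a small illustrative example:
#     cfg.device_target = 'GPU'
#     print(cfg.word_len)   # 51, the padded sentence length shared by dataset and network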
@@ -0,0 +1,212 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in train.py and eval.py
"""
import os
import math
import random
import codecs
from pathlib import Path
import numpy as np
import mindspore.dataset as ds


class Generator():
    def __init__(self, input_list):
        self.input_list = input_list

    def __getitem__(self, item):
        return np.array(self.input_list[item][0], dtype=np.int32), np.array(self.input_list[item][1], dtype=np.int32)

    def __len__(self):
        return len(self.input_list)
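    # Note: Generator is the random-accessible source consumed by ds.GeneratorDataset
    # below; __getitem__ returns one (sentence_vector, label) pair of int32 arrays,
    # which map to the "data" and "label" column_names.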


class MovieReview:
    """
    preprocess MR dataset
    """
    def __init__(self, root_dir, maxlen, split):
        """
        input:
            root_dir: the root directory path of the MR dataset
            maxlen: set the max length of the sentence
            split: set the ratio of training set to testing set
        """
        self.path = root_dir
        self.feelMap = {
            'neg': 0,
            'pos': 1
        }
        self.files = []
        self.doConvert = False
        mypath = Path(self.path)

        if not mypath.exists() or not mypath.is_dir():
            raise ValueError("please check the root_dir [{}]".format(root_dir))

        # walk through the top level of root_dir only
        for root, _, filename in os.walk(self.path):
            for each in filename:
                self.files.append(os.path.join(root, each))
            break

        # check that exactly two files (pos and neg) were found
        if len(self.files) != 2:
            raise ValueError("There are {} files in the root_dir, expected 2".format(len(self.files)))

        # begin to read data
        self.word_num = 0
        self.minlen = float("inf")
        self.maxlen = float("-inf")
        self.Pos = []
        self.Neg = []
        for filename in self.files:
            # re-encode the raw file as utf-8 before parsing it
            with codecs.open(filename, 'r') as f:
                ff = f.read()
            with codecs.open(filename, 'w', 'utf-8') as file_object:
                file_object.write(ff)
            self.read_data(filename)
        self.PosNeg = self.Pos + self.Neg
        self.text2vec(maxlen=maxlen)
        self.split_dataset(split=split)

    def read_data(self, filePath):
        """
        read text into memory

        input:
            filePath: the path where the data is stored
        """
        with open(filePath, 'r') as f:
            for sentence in f.readlines():
                # strip newlines, punctuation, digits and markup before tokenizing
                for token in ('\n', '"', '\'', '.', ',', '[', ']', '(', ')', ':',
                              '--', '-', '\\', '0', '1', '2', '3', '4', '5', '6',
                              '7', '8', '9', '`', '=', '$', '/', '*', ';', '<b>', '%'):
                    sentence = sentence.replace(token, '')
                sentence = sentence.split(' ')
                sentence = list(filter(lambda x: x, sentence))
                if sentence:
                    self.word_num += len(sentence)
                    self.maxlen = max(self.maxlen, len(sentence))
                    self.minlen = min(self.minlen, len(sentence))
                    if 'pos' in filePath:
                        self.Pos.append([sentence, self.feelMap['pos']])
                    else:
                        self.Neg.append([sentence, self.feelMap['neg']])

    def text2vec(self, maxlen):
        """
        convert each sentence into a vector of int word indices

        input:
            maxlen: max length of the sentence
        """
        # Vocab = {word : index}
        self.Vocab = dict()

        for SentenceLabel in self.Pos+self.Neg:
            vector = [0]*maxlen
            for index, word in enumerate(SentenceLabel[0]):
                if index >= maxlen:
                    break
                if word not in self.Vocab:
                    self.Vocab[word] = len(self.Vocab)
                    vector[index] = len(self.Vocab) - 1
                else:
                    vector[index] = self.Vocab[word]
            SentenceLabel[0] = vector
        self.doConvert = True

    def split_dataset(self, split):
        """
        split the dataset into training set and test set

        input:
            split: the ratio of training set to test set
        """
        trunk_pos_size = math.ceil((1-split)*len(self.Pos))
        trunk_neg_size = math.ceil((1-split)*len(self.Neg))
        trunk_num = int(1/(1-split))
        pos_temp = list()
        neg_temp = list()
        for index in range(trunk_num):
            pos_temp.append(self.Pos[index*trunk_pos_size:(index+1)*trunk_pos_size])
            neg_temp.append(self.Neg[index*trunk_neg_size:(index+1)*trunk_neg_size])
        self.test = pos_temp.pop(2)+neg_temp.pop(2)
        self.train = [i for item in pos_temp+neg_temp for i in item]

        random.shuffle(self.train)
        # random.shuffle(self.test)
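        # Worked example of the split above (illustrative, with split=0.9):
        # trunk_num = int(1 / 0.1) = 10, so Pos and Neg are each cut into 10 chunks
        # of ceil(0.1 * len) sentences; the chunk at index 2 of each class becomes
        # the test set, and the remaining chunks are flattened into the training
        # set, which is then shuffled.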

    def get_dict_len(self):
        """
        get number of different words in the whole dataset
        """
        if self.doConvert:
            return len(self.Vocab)
        return -1

    def create_train_dataset(self, epoch_size, batch_size):
        dataset = ds.GeneratorDataset(source=Generator(input_list=self.train),
                                      column_names=["data", "label"], shuffle=False)
        dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
        dataset = dataset.repeat(epoch_size)
        return dataset

    def create_test_dataset(self, batch_size):
        dataset = ds.GeneratorDataset(source=Generator(input_list=self.test),
                                      column_names=["data", "label"], shuffle=False)
        dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
        return dataset
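    # Minimal usage sketch of this module (mirrors train.py and eval.py; the
    # root_dir value is only an example):
    #     instance = MovieReview(root_dir='./data/', maxlen=51, split=0.9)
    #     train_ds = instance.create_train_dataset(epoch_size=4, batch_size=64)
    #     test_ds = instance.create_test_dataset(batch_size=64)
    #     vocab_size = instance.get_dict_len()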
@@ -0,0 +1,155 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TextCNN"""

import numpy as np
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import Tensor
from mindspore.nn.cell import Cell
import mindspore.ops.functional as F
import mindspore


class SoftmaxCrossEntropyExpand(Cell):
    r"""
    Computes softmax cross entropy between logits and labels. Implemented by expanded formula.

    This is a wrapper of several functions.

    .. math::
        \ell(x_i, t_i) = -\log\left(\frac{\exp(x_{t_i})}{\sum_j \exp(x_j)}\right),
    where :math:`x_i` is a 1D score Tensor, :math:`t_i` is the target class.

    Note:
        When the argument sparse is set to True, the label is the class index
        ranging from :math:`0` to :math:`C - 1` instead of a one-hot vector.

    Args:
        sparse(bool): Specifies whether labels use sparse format or not. Default: False.

    Inputs:
        - **input_data** (Tensor) - Tensor of shape :math:`(x_1, x_2, ..., x_R)`.
        - **label** (Tensor) - Tensor of shape :math:`(y_1, y_2, ..., y_S)`.

    Outputs:
        Tensor, a scalar tensor including the mean loss.

    Examples:
        >>> loss = SoftmaxCrossEntropyExpand(sparse=True)
        >>> input_data = Tensor(np.ones([64, 512]), dtype=mindspore.float32)
        >>> label = Tensor(np.ones([64]), dtype=mindspore.int32)
        >>> loss(input_data, label)
    """
    def __init__(self, sparse=False):
        super(SoftmaxCrossEntropyExpand, self).__init__()
        self.exp = P.Exp()
        self.reduce_sum = P.ReduceSum(keep_dims=True)
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0, mindspore.float32)
        self.off_value = Tensor(0.0, mindspore.float32)
        self.div = P.Div()
        self.log = P.Log()
        self.sum_cross_entropy = P.ReduceSum(keep_dims=False)
        self.mul = P.Mul()
        self.mul2 = P.Mul()
        self.cast = P.Cast()
        self.reduce_mean = P.ReduceMean(keep_dims=False)
        self.sparse = sparse
        self.reduce_max = P.ReduceMax(keep_dims=True)
        self.sub = P.Sub()

    def construct(self, logit, label):
        """Compute the mean softmax cross-entropy loss."""
        logit_max = self.reduce_max(logit, -1)
        exp = self.exp(self.sub(logit, logit_max))
        exp_sum = self.reduce_sum(exp, -1)
        softmax_result = self.div(exp, exp_sum)
        if self.sparse:
            label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)

        softmax_result_log = self.log(softmax_result)
        loss = self.sum_cross_entropy((self.mul(softmax_result_log, label)), -1)
        loss = self.mul2(F.scalar_to_array(-1.0), loss)
        loss = self.reduce_mean(loss, -1)

        return loss
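    # In construct above the loss is assembled from primitive ops rather than a
    # fused kernel: logits are shifted by their row-wise max for numerical
    # stability, normalized into softmax probabilities, and the result is
    #     loss = mean_over_batch(-sum_over_classes(label * log(softmax(logit))))
    # with sparse labels first expanded to one-hot vectors.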


def _weight_variable(shape, factor=0.01):
    init_value = np.random.randn(*shape).astype(np.float32) * factor
    return Tensor(init_value)


def make_conv_layer(kernel_size):
    weight_shape = (96, 1, *kernel_size)
    weight = _weight_variable(weight_shape)
    return nn.Conv2d(in_channels=1, out_channels=96, kernel_size=kernel_size, padding=1,
                     pad_mode="pad", weight_init=weight, has_bias=True)


class TextCNN(nn.Cell):
    """
    TextCNN architecture
    """
    def __init__(self, vocab_len, word_len, num_classes, vec_length):
        super(TextCNN, self).__init__()
        self.vec_length = vec_length
        self.word_len = word_len
        self.num_classes = num_classes

        self.unsqueeze = P.ExpandDims()
        self.embedding = nn.Embedding(vocab_len, self.vec_length, embedding_table='normal')

        self.slice = P.Slice()
        self.layer1 = self.make_layer(kernel_height=3)
        self.layer2 = self.make_layer(kernel_height=4)
        self.layer3 = self.make_layer(kernel_height=5)

        self.concat = P.Concat(1)

        self.fc = nn.Dense(96*3, self.num_classes)
        self.drop = nn.Dropout(keep_prob=0.5)
        self.print = P.Print()
        # global max reduction over the conv feature maps (a max, not a mean)
        self.reduce_max = P.ReduceMax(keep_dims=False)

    def make_layer(self, kernel_height):
        return nn.SequentialCell(
            [
                make_conv_layer((kernel_height, self.vec_length)), nn.ReLU(),
                nn.MaxPool2d(kernel_size=(self.word_len-kernel_height+1, 1)),
            ]
        )

    def construct(self, x):
        """Forward pass: embed, convolve, pool, concatenate, classify."""
        x = self.unsqueeze(x, 1)
        x = self.embedding(x)
        x1 = self.layer1(x)
        x2 = self.layer2(x)
        x3 = self.layer3(x)

        x1 = self.reduce_max(x1, (2, 3))
        x2 = self.reduce_max(x2, (2, 3))
        x3 = self.reduce_max(x3, (2, 3))

        x = self.concat((x1, x2, x3))
        x = self.drop(x)
        x = self.fc(x)
        return x
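    # Shape walk-through of construct (illustrative, for the default config with
    # word_len=51, vec_length=40, batch_size=64): the int32 input of shape (64, 51)
    # is expanded and embedded to (64, 1, 51, 40), passed through three parallel
    # conv + ReLU + max-pool branches with kernel heights 3/4/5 (96 channels each),
    # reduced by a global max over the spatial dims to three (64, 96) tensors,
    # concatenated to (64, 288), and projected by the Dense layer to num_classes logits.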
@@ -0,0 +1,77 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################train textcnn example on movie review########################
python train.py
"""
import argparse
import math

import mindspore.nn as nn
from mindspore.nn.metrics import Accuracy
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.config import cfg
from src.textcnn import TextCNN
from src.textcnn import SoftmaxCrossEntropyExpand
from src.dataset import MovieReview

parser = argparse.ArgumentParser(description='TextCNN')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
                    help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--device_id', type=int, default=5, help='device id of GPU or Ascend.')
args_opt = parser.parse_args()

if __name__ == '__main__':
    rank = 0
    # set context
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
    context.set_context(device_id=args_opt.device_id)

    instance = MovieReview(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9)
    dataset = instance.create_train_dataset(batch_size=cfg.batch_size, epoch_size=cfg.epoch_size)
    batch_num = dataset.get_dataset_size()

    learning_rate = []
    warm_up = [1e-3 / math.floor(cfg.epoch_size / 5) * (i + 1) for _ in range(batch_num) for i in
               range(math.floor(cfg.epoch_size / 5))]
    shrink = [1e-3 / (16 * (i + 1)) for _ in range(batch_num) for i in range(math.floor(cfg.epoch_size * 3 / 5))]
    normal_run = [1e-3 for _ in range(batch_num) for i in
                  range(cfg.epoch_size - math.floor(cfg.epoch_size / 5) - math.floor(cfg.epoch_size * 2 / 5))]
    learning_rate = learning_rate + warm_up + normal_run + shrink
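    # Schedule sketch (illustrative, for the default cfg.epoch_size = 4):
    # math.floor(4 / 5) == 0, so warm_up stays empty; normal_run supplies the
    # constant 1e-3 phase and shrink supplies the decayed 1e-3 / (16 * (i + 1))
    # values. The concatenated list is consumed one entry per step as a dynamic
    # learning rate by nn.Adam below.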

    net = TextCNN(vocab_len=instance.get_dict_len(), word_len=cfg.word_len,
                  num_classes=cfg.num_classes, vec_length=cfg.vec_length)
    # Continue training from cfg.checkpoint_path if pre_trained is set to True
    if cfg.pre_trained:
        param_dict = load_checkpoint(cfg.checkpoint_path)
        load_param_into_net(net, param_dict)

    opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()),
                  learning_rate=learning_rate, weight_decay=cfg.weight_decay)
    loss = SoftmaxCrossEntropyExpand(sparse=True)

    model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc': Accuracy()})

    config_ck = CheckpointConfig(save_checkpoint_steps=int(cfg.epoch_size*batch_num/2),
                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
    time_cb = TimeMonitor(data_size=batch_num)
    ckpt_save_dir = "./ckpt_" + str(rank) + "/"
    ckpoint_cb = ModelCheckpoint(prefix="train_textcnn", directory=ckpt_save_dir, config=config_ck)
    loss_cb = LossMonitor()
    model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
    print("train success")