From: @ttudu
pull/10222/MERGE
commit 18ca7eaeb0

(File diff suppressed because it is too large)

(File diff suppressed because it is too large)

@@ -20,8 +20,9 @@ import os

import numpy as np
from src.config import lstm_cfg, lstm_cfg_ascend
from src.dataset import lstm_create_dataset, convert_to_mindrecord
from src.lr_schedule import get_lr
from src.lstm import SentimentNet
from mindspore import Tensor, nn, Model, context
from mindspore.nn import Accuracy
@@ -40,8 +41,8 @@ if __name__ == '__main__':
                        help='path where the pre-process data is stored.')
    parser.add_argument('--ckpt_path', type=str, default=None,
                        help='the checkpoint file path used to evaluate model.')
    parser.add_argument('--device_target', type=str, default="Ascend", choices=['GPU', 'CPU', 'Ascend'],
                        help='the target device to run, support "GPU", "CPU", "Ascend". Default: "Ascend".')
    args = parser.parse_args()

    context.set_context(
@@ -49,11 +50,24 @@ if __name__ == '__main__':
        save_graphs=False,
        device_target=args.device_target)

    if args.device_target == 'Ascend':
        cfg = lstm_cfg_ascend
    else:
        cfg = lstm_cfg

    if args.preprocess == "true":
        print("============== Starting Data Pre-processing ==============")
        convert_to_mindrecord(cfg.embed_size, args.aclimdb_path, args.preprocess_path, args.glove_path)

    embedding_table = np.loadtxt(os.path.join(args.preprocess_path, "weight.txt")).astype(np.float32)
    # DynamicRNN on the Ascend platform currently only supports input_size and
    # hidden_size values that are multiples of 16; this limitation will be lifted later.
    if args.device_target == 'Ascend':
        pad_num = int(np.ceil(cfg.embed_size / 16) * 16 - cfg.embed_size)
        if pad_num > 0:
            embedding_table = np.pad(embedding_table, [(0, 0), (0, pad_num)], 'constant')
        cfg.embed_size = int(np.ceil(cfg.embed_size / 16) * 16)
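For reference, a minimal standalone sketch of this padding arithmetic (the 10-row table is a made-up stand-in for weight.txt):

import numpy as np

embed_size = 300                              # GloVe dimension from lstm_cfg
padded = int(np.ceil(embed_size / 16) * 16)   # 304, the next multiple of 16
pad_num = padded - embed_size                 # 4 zero columns to append
table = np.random.rand(10, embed_size).astype(np.float32)
table = np.pad(table, [(0, 0), (0, pad_num)], 'constant')
assert table.shape == (10, padded)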
    network = SentimentNet(vocab_size=embedding_table.shape[0],
                           embed_size=cfg.embed_size,
                           num_hiddens=cfg.num_hiddens,
@@ -64,13 +78,23 @@ if __name__ == '__main__':
                           batch_size=cfg.batch_size)

    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    ds_eval = lstm_create_dataset(args.preprocess_path, cfg.batch_size, training=False)
    if cfg.dynamic_lr:
        lr = Tensor(get_lr(global_step=cfg.global_step,
                           lr_init=cfg.lr_init, lr_end=cfg.lr_end, lr_max=cfg.lr_max,
                           warmup_epochs=cfg.warmup_epochs,
                           total_epochs=cfg.num_epochs,
                           steps_per_epoch=ds_eval.get_dataset_size(),
                           lr_adjust_epoch=cfg.lr_adjust_epoch))
    else:
        lr = cfg.learning_rate
    opt = nn.Momentum(network.trainable_params(), lr, cfg.momentum)
    loss_cb = LossMonitor()
    model = Model(network, loss, opt, {'acc': Accuracy()})

    print("============== Starting Testing ==============")
    param_dict = load_checkpoint(args.ckpt_path)
    load_param_into_net(network, param_dict)
    if args.device_target == "CPU":

@@ -0,0 +1,39 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_eval_ascend.sh DEVICE_ID PREPROCESS_DIR CKPT_FILE"
echo "for example: bash run_eval_ascend.sh 0 ./preprocess lstm-20_390.ckpt"
echo "=============================================================================================================="
DEVICE_ID=$1
PREPROCESS_DIR=$2
CKPT_FILE=$3
rm -rf eval
mkdir -p eval
cd eval
mkdir -p ms_log
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
export DEVICE_ID=$DEVICE_ID
python ../../eval.py \
--device_target="Ascend" \
--preprocess=false \
--preprocess_path=$PREPROCESS_DIR \
--ckpt_path=$CKPT_FILE > log.txt 2>&1 &
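Note that the evaluation job is launched in the background (the trailing &) with stdout and stderr redirected, so results land in eval/log.txt (and GLOG output in eval/ms_log) rather than on the console; run_train_ascend.sh below follows the same pattern with train/log.txt.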

@@ -15,7 +15,7 @@
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_eval_cpu.sh ACLIMDB_DIR GLOVE_DIR CKPT_FILE"
echo "for example: bash run_eval_cpu.sh ./aclimdb ./glove_dir lstm-20_390.ckpt"
echo "=============================================================================================================="

@@ -15,7 +15,7 @@
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_train_gpu.sh DEVICE_ID ACLIMDB_DIR GLOVE_DIR CKPT_FILE"
echo "for example: bash run_train_gpu.sh 0 ./aclimdb ./glove_dir lstm-20_390.ckpt"
echo "=============================================================================================================="

@@ -0,0 +1,39 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_train_ascend.sh DEVICE_ID ACLIMDB_DIR GLOVE_DIR"
echo "for example: bash run_train_ascend.sh 0 ./aclimdb ./glove_dir"
echo "=============================================================================================================="
DEVICE_ID=$1
ACLIMDB_DIR=$2
GLOVE_DIR=$3
mkdir -p train
cd train
mkdir -p ms_log
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
export DEVICE_ID=$DEVICE_ID
python ../../train.py \
--device_target="Ascend" \
--aclimdb_path=$ACLIMDB_DIR \
--glove_path=$GLOVE_DIR \
--preprocess=true \
--preprocess_path=./preprocess > log.txt 2>&1 &

@@ -15,7 +15,7 @@
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_train_cpu.sh ACLIMDB_DIR GLOVE_DIR"
echo "for example: bash run_train_cpu.sh ./aclimdb ./glove_dir"
echo "=============================================================================================================="

@@ -15,7 +15,7 @@
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_train_gpu.sh DEVICE_ID ACLIMDB_DIR GLOVE_DIR"
echo "for example: bash run_train_gpu.sh 0 ./aclimdb ./glove_dir"
echo "=============================================================================================================="

@@ -20,6 +20,7 @@ from easydict import EasyDict as edict

# LSTM CONFIG
lstm_cfg = edict({
    'num_classes': 2,
    'dynamic_lr': False,
    'learning_rate': 0.1,
    'momentum': 0.9,
    'num_epochs': 20,
@@ -31,3 +32,24 @@ lstm_cfg = edict({
    'save_checkpoint_steps': 390,
    'keep_checkpoint_max': 10
})

# LSTM CONFIG IN ASCEND
lstm_cfg_ascend = edict({
    'num_classes': 2,
    'momentum': 0.9,
    'num_epochs': 20,
    'batch_size': 64,
    'embed_size': 300,
    'num_hiddens': 128,
    'num_layers': 2,
    'bidirectional': True,
    'save_checkpoint_steps': 7800,
    'keep_checkpoint_max': 10,
    'dynamic_lr': True,
    'lr_init': 0.05,
    'lr_end': 0.01,
    'lr_max': 0.1,
    'lr_adjust_epoch': 6,
    'warmup_epochs': 1,
    'global_step': 0
})
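A note on these values (the diff itself doesn't spell this out): the IMDB training split has 25000 reviews, so batch_size 64 gives 390 steps per epoch, and save_checkpoint_steps 7800 equals 390 * 20, i.e. one checkpoint at the end of the 20-epoch run. num_hiddens 128 already satisfies the Ascend multiple-of-16 constraint; embed_size 300 does not, which is why train.py and eval.py pad it to 304 at runtime.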

@@ -0,0 +1,60 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Learning rate schedule"""
import math
import numpy as np

def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_adjust_epoch):
    """
    generate learning rate array

    Args:
        global_step(int): the step to resume from; earlier entries of the array are dropped
        lr_init(float): initial learning rate
        lr_end(float): final learning rate
        lr_max(float): maximum learning rate
        warmup_epochs(float): number of warmup epochs
        total_epochs(int): total epochs of training
        steps_per_epoch(int): steps of one epoch
        lr_adjust_epoch(int): the lr decays over the first lr_adjust_epoch epochs; after that it stays at lr_end

    Returns:
        np.array, learning rate array
    """
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs
    adjust_steps = lr_adjust_epoch * steps_per_epoch
    for i in range(total_steps):
        if i < warmup_steps:
            lr = lr_init + (lr_max - lr_init) * i / warmup_steps
        elif i < adjust_steps:
            lr = lr_end + \
                 (lr_max - lr_end) * \
                 (1. + math.cos(math.pi * (i - warmup_steps) / (adjust_steps - warmup_steps))) / 2.
        else:
            lr = lr_end
        if lr < 0.0:
            lr = 0.0
        lr_each_step.append(lr)

    lr_each_step = np.array(lr_each_step).astype(np.float32)
    learning_rate = lr_each_step[global_step:]

    return learning_rate
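A small usage sketch (values from lstm_cfg_ascend above; the 390 steps/epoch figure assumes the IMDB train split of 25000 samples at batch size 64):

from src.lr_schedule import get_lr

lr = get_lr(global_step=0, lr_init=0.05, lr_end=0.01, lr_max=0.1,
            warmup_epochs=1, total_epochs=20, steps_per_epoch=390,
            lr_adjust_epoch=6)
print(lr.shape)          # (7800,): one value per step, 390 * 20
print(lr[0], lr[390])    # 0.05 at the first step, 0.1 right after warmup
print(lr[-1])            # 0.01 once the cosine phase has finished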

@@ -20,6 +20,8 @@ import numpy as np

from mindspore import Tensor, nn, context, Parameter, ParameterTuple
from mindspore.common.initializer import initializer
from mindspore.ops import operations as P
import mindspore.ops.functional as F
import mindspore.common.dtype as mstype

STACK_LSTM_DEVICE = ["CPU"]
@@ -44,6 +46,28 @@ def stack_lstm_default_state(batch_size, hidden_size, num_layers, bidirectional)
    h, c = tuple(h_list), tuple(c_list)
    return h, c
def stack_lstm_default_state_ascend(batch_size, hidden_size, num_layers, bidirectional):
    """init default input."""
    h_list, c_list = [], []   # separate lists; aliasing them would interleave h and c states
    for _ in range(num_layers):
        h_fw = Tensor(np.zeros((1, batch_size, hidden_size)).astype(np.float16))
        c_fw = Tensor(np.zeros((1, batch_size, hidden_size)).astype(np.float16))
        h_i = [h_fw]
        c_i = [c_fw]

        if bidirectional:
            h_bw = Tensor(np.zeros((1, batch_size, hidden_size)).astype(np.float16))
            c_bw = Tensor(np.zeros((1, batch_size, hidden_size)).astype(np.float16))
            h_i.append(h_bw)
            c_i.append(c_bw)

        h_list.append(h_i)
        c_list.append(c_i)
    h, c = tuple(h_list), tuple(c_list)
    return h, c
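A quick shape check of the initializer above (a sketch, assuming the sizes used by the Ascend config):

h, c = stack_lstm_default_state_ascend(batch_size=64, hidden_size=128,
                                       num_layers=2, bidirectional=True)
# h[layer][direction] is a (1, 64, 128) float16 Tensor
print(len(h), len(h[0]), h[0][0].shape)   # 2 2 (1, 64, 128)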
class StackLSTM(nn.Cell):
    """

@@ -114,6 +138,128 @@ class StackLSTM(nn.Cell):
        x = self.transpose(x, (1, 0, 2))
        return x, (hn, cn)
class LSTM_Ascend(nn.Cell):
    """ LSTM in Ascend. """

    def __init__(self, bidirectional=False):
        super(LSTM_Ascend, self).__init__()
        self.bidirectional = bidirectional
        self.dynamic_rnn = P.DynamicRNN(forget_bias=0.0)
        self.reverseV2 = P.ReverseV2(axis=[0])
        self.concat = P.Concat(2)

    def construct(self, x, h, c, w_f, b_f, w_b=None, b_b=None):
        """construct"""
        x = F.cast(x, mstype.float16)
        if self.bidirectional:
            y1, h1, c1, _, _, _, _, _ = self.dynamic_rnn(x, w_f, b_f, None, h[0], c[0])
            r_x = self.reverseV2(x)
            y2, h2, c2, _, _, _, _, _ = self.dynamic_rnn(r_x, w_b, b_b, None, h[1], c[1])
            y2 = self.reverseV2(y2)

            output = self.concat((y1, y2))
            hn = self.concat((h1, h2))
            cn = self.concat((c1, c2))
            return output, (hn, cn)

        y1, h1, c1, _, _, _, _, _ = self.dynamic_rnn(x, w_f, b_f, None, h[0], c[0])
        return y1, (h1, c1)
class StackLSTMAscend(nn.Cell):
    """ Stack multi-layers LSTM together. """

    def __init__(self,
                 input_size,
                 hidden_size,
                 num_layers=1,
                 has_bias=True,
                 batch_first=False,
                 dropout=0.0,
                 bidirectional=False):
        super(StackLSTMAscend, self).__init__()
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bidirectional = bidirectional
        self.transpose = P.Transpose()

        # input_size list
        input_size_list = [input_size]
        for i in range(num_layers - 1):
            input_size_list.append(hidden_size * 2)

        # weights, bias and layers init
        weights_fw = []
        weights_bw = []
        bias_fw = []
        bias_bw = []

        stdv = 1 / math.sqrt(hidden_size)
        for i in range(num_layers):
            # forward weight init
            w_np_fw = np.random.uniform(-stdv,
                                        stdv,
                                        (input_size_list[i] + hidden_size, hidden_size * 4)).astype(np.float16)
            w_fw = Parameter(initializer(Tensor(w_np_fw), w_np_fw.shape), name="w_fw_layer" + str(i))
            weights_fw.append(w_fw)
            # forward bias init
            if has_bias:
                b_fw = np.random.uniform(-stdv, stdv, (hidden_size * 4)).astype(np.float16)
                b_fw = Parameter(initializer(Tensor(b_fw), b_fw.shape), name="b_fw_layer" + str(i))
            else:
                b_fw = np.zeros((hidden_size * 4)).astype(np.float16)
                b_fw = Parameter(initializer(Tensor(b_fw), b_fw.shape), name="b_fw_layer" + str(i))
            bias_fw.append(b_fw)

            if bidirectional:
                # backward weight init
                w_np_bw = np.random.uniform(-stdv,
                                            stdv,
                                            (input_size_list[i] + hidden_size, hidden_size * 4)).astype(np.float16)
                w_bw = Parameter(initializer(Tensor(w_np_bw), w_np_bw.shape), name="w_bw_layer" + str(i))
                weights_bw.append(w_bw)
                # backward bias init
                if has_bias:
                    b_bw = np.random.uniform(-stdv, stdv, (hidden_size * 4)).astype(np.float16)
                    b_bw = Parameter(initializer(Tensor(b_bw), b_bw.shape), name="b_bw_layer" + str(i))
                else:
                    b_bw = np.zeros((hidden_size * 4)).astype(np.float16)
                    b_bw = Parameter(initializer(Tensor(b_bw), b_bw.shape), name="b_bw_layer" + str(i))
                bias_bw.append(b_bw)

        # layer init
        self.lstm = LSTM_Ascend(bidirectional=bidirectional)

        self.weight_fw = ParameterTuple(tuple(weights_fw))
        self.weight_bw = ParameterTuple(tuple(weights_bw))
        self.bias_fw = ParameterTuple(tuple(bias_fw))
        self.bias_bw = ParameterTuple(tuple(bias_bw))
    def construct(self, x, hx):
        """construct"""
        x = F.cast(x, mstype.float16)
        if self.batch_first:
            x = self.transpose(x, (1, 0, 2))
        # stack lstm
        h, c = hx
        hn = cn = None
        for i in range(self.num_layers):
            if self.bidirectional:
                x, (hn, cn) = self.lstm(x,
                                        h[i],
                                        c[i],
                                        self.weight_fw[i],
                                        self.bias_fw[i],
                                        self.weight_bw[i],
                                        self.bias_bw[i])
            else:
                x, (hn, cn) = self.lstm(x, h[i], c[i], self.weight_fw[i], self.bias_fw[i])
        if self.batch_first:
            x = self.transpose(x, (1, 0, 2))
        x = F.cast(x, mstype.float32)
        hn = F.cast(hn, mstype.float32)   # cast the states themselves, not x again
        cn = F.cast(cn, mstype.float32)
        return x, (hn, cn)
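Putting the pieces together, a minimal illustrative sketch (not from the diff) of driving the cell with the padded IMDB shapes, 500 time steps of 304-dim embeddings:

net = StackLSTMAscend(input_size=304, hidden_size=128, num_layers=2,
                      has_bias=True, bidirectional=True)
h, c = stack_lstm_default_state_ascend(batch_size=64, hidden_size=128,
                                       num_layers=2, bidirectional=True)
x = Tensor(np.zeros((500, 64, 304)).astype(np.float32))  # (seq_len, batch, embed)
y, (hn, cn) = net(x, (h, c))
# y: (500, 64, 256) float32, forward and backward outputs concatenated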
class SentimentNet(nn.Cell):
    """Sentiment network structure."""

@@ -145,7 +291,7 @@ class SentimentNet(nn.Cell):
                                         bidirectional=bidirectional,
                                         dropout=0.0)
            self.h, self.c = stack_lstm_default_state(batch_size, num_hiddens, num_layers, bidirectional)
        elif context.get_context("device_target") == "GPU":
            # standard lstm
            self.encoder = nn.LSTM(input_size=embed_size,
                                   hidden_size=num_hiddens,
@@ -154,8 +300,16 @@ class SentimentNet(nn.Cell):
                                   bidirectional=bidirectional,
                                   dropout=0.0)
            self.h, self.c = lstm_default_state(batch_size, num_hiddens, num_layers, bidirectional)
        else:
            self.encoder = StackLSTMAscend(input_size=embed_size,
                                           hidden_size=num_hiddens,
                                           num_layers=num_layers,
                                           has_bias=True,
                                           bidirectional=bidirectional)
            self.h, self.c = stack_lstm_default_state_ascend(batch_size, num_hiddens, num_layers, bidirectional)

        self.concat = P.Concat(1)
        self.squeeze = P.Squeeze(axis=0)
        if bidirectional:
            self.decoder = nn.Dense(num_hiddens * 4, num_classes)
        else:
@@ -167,6 +321,6 @@ class SentimentNet(nn.Cell):
        embeddings = self.trans(embeddings, self.perm)
        output, _ = self.encoder(embeddings, (self.h, self.c))
        # states[i] size(64,200) -> encoding.size(64,400)
        encoding = self.concat((self.squeeze(output[0:1:1]), self.squeeze(output[499:500:1])))
        outputs = self.decoder(encoding)
        return outputs
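A note on the last change: output[0:1:1] selects the same data as output[0] but keeps a leading axis of length 1, and the explicit P.Squeeze(axis=0) then removes it, so concat receives exactly the values it did before; the indexing is simply rewritten in terms of operators available on the Ascend backend.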

@@ -20,9 +20,10 @@ import os

import numpy as np
from src.config import lstm_cfg, lstm_cfg_ascend
from src.dataset import convert_to_mindrecord
from src.dataset import lstm_create_dataset
from src.lr_schedule import get_lr
from src.lstm import SentimentNet
from mindspore import Tensor, nn, Model, context
from mindspore.nn import Accuracy
@@ -43,8 +44,8 @@ if __name__ == '__main__':
                        help='the path to save the checkpoint file.')
    parser.add_argument('--pre_trained', type=str, default=None,
                        help='the pretrained checkpoint file path.')
    parser.add_argument('--device_target', type=str, default="Ascend", choices=['GPU', 'CPU', 'Ascend'],
                        help='the target device to run, support "GPU", "CPU", "Ascend". Default: "Ascend".')
    args = parser.parse_args()

    context.set_context(
@@ -52,11 +53,23 @@ if __name__ == '__main__':
        save_graphs=False,
        device_target=args.device_target)

    if args.device_target == 'Ascend':
        cfg = lstm_cfg_ascend
    else:
        cfg = lstm_cfg

    if args.preprocess == "true":
        print("============== Starting Data Pre-processing ==============")
        convert_to_mindrecord(cfg.embed_size, args.aclimdb_path, args.preprocess_path, args.glove_path)

    embedding_table = np.loadtxt(os.path.join(args.preprocess_path, "weight.txt")).astype(np.float32)
    # DynamicRNN on the Ascend platform currently only supports input_size and
    # hidden_size values that are multiples of 16; this limitation will be lifted later.
    if args.device_target == 'Ascend':
        pad_num = int(np.ceil(cfg.embed_size / 16) * 16 - cfg.embed_size)
        if pad_num > 0:
            embedding_table = np.pad(embedding_table, [(0, 0), (0, pad_num)], 'constant')
        cfg.embed_size = int(np.ceil(cfg.embed_size / 16) * 16)
    network = SentimentNet(vocab_size=embedding_table.shape[0],
                           embed_size=cfg.embed_size,
                           num_hiddens=cfg.num_hiddens,
@@ -69,14 +82,25 @@ if __name__ == '__main__':
    if args.pre_trained:
        load_param_into_net(network, load_checkpoint(args.pre_trained))

    ds_train = lstm_create_dataset(args.preprocess_path, cfg.batch_size, 1)

    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    if cfg.dynamic_lr:
        lr = Tensor(get_lr(global_step=cfg.global_step,
                           lr_init=cfg.lr_init, lr_end=cfg.lr_end, lr_max=cfg.lr_max,
                           warmup_epochs=cfg.warmup_epochs,
                           total_epochs=cfg.num_epochs,
                           steps_per_epoch=ds_train.get_dataset_size(),
                           lr_adjust_epoch=cfg.lr_adjust_epoch))
    else:
        lr = cfg.learning_rate
    opt = nn.Momentum(network.trainable_params(), lr, cfg.momentum)
    loss_cb = LossMonitor()
    model = Model(network, loss, opt, {'acc': Accuracy()})

    print("============== Starting Training ==============")
    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix="lstm", directory=args.ckpt_path, config=config_ck)
