commit 4b436d61a8
@@ -0,0 +1,60 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dataset helpers api"""
import argparse
import os
import numpy as np

parser = argparse.ArgumentParser(description='textrcnn')
parser.add_argument('--task', type=str, help='the data preprocess task, including dataset_split.')
parser.add_argument('--data_dir', type=str, help='the source dataset directory.', default='./data_src')
parser.add_argument('--out_dir', type=str, help='the target dataset directory.', default='./data')

args = parser.parse_args()


def dataset_split(label):
    """dataset_split api"""
    # label can be 'pos' or 'neg'
    pos_samples = []
    pos_file = os.path.join(args.data_dir, "rt-polaritydata", "rt-polarity." + label)
    pfhand = open(pos_file, encoding='utf-8')
    pos_samples += pfhand.readlines()
    pfhand.close()
    perm = np.random.permutation(len(pos_samples))
    # print(perm[0:int(len(pos_samples)*0.8)])
    perm_train = perm[0:int(len(pos_samples) * 0.9)]
    perm_test = perm[int(len(pos_samples) * 0.9):]
    pos_samples_train = []
    pos_samples_test = []
    for pt in perm_train:
        pos_samples_train.append(pos_samples[pt])
    for pt in perm_test:
        pos_samples_test.append(pos_samples[pt])
    f = open(os.path.join(args.out_dir, 'train', label), "w")
    f.write(''.join(pos_samples_train))
    f.close()

    f = open(os.path.join(args.out_dir, 'test', label), "w")
    f.write(''.join(pos_samples_test))
    f.close()


if __name__ == '__main__':
    if args.task == "dataset_split":
        dataset_split('pos')
        dataset_split('neg')

    # search(args.q)
@@ -0,0 +1,61 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""model evaluation script"""
import os
import argparse
import numpy as np

import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import LossMonitor
from mindspore.common import set_seed

from src.config import textrcnn_cfg as cfg
from src.dataset import create_dataset
from src.textrcnn import textrcnn

set_seed(1)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='textrcnn')
    parser.add_argument('--ckpt_path', type=str)
    args = parser.parse_args()
    context.set_context(
        mode=context.GRAPH_MODE,
        save_graphs=False,
        device_target="Ascend")

    device_id = int(os.getenv('DEVICE_ID'))
    context.set_context(device_id=device_id)

    embedding_table = np.loadtxt(os.path.join(cfg.preprocess_path, "weight.txt")).astype(np.float32)
    network = textrcnn(weight=Tensor(embedding_table), vocab_size=embedding_table.shape[0], \
                       cell=cfg.cell, batch_size=cfg.batch_size)
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
    opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
    loss_cb = LossMonitor()
    print("============== Starting Testing ==============")
    ds_eval = create_dataset(cfg.preprocess_path, cfg.batch_size, 1, False)
    param_dict = load_checkpoint(args.ckpt_path)
    load_param_into_net(network, param_dict)
    network.set_train(False)
    model = Model(network, loss, opt, metrics={'acc': Accuracy()}, amp_level='O3')
    acc = model.eval(ds_eval, dataset_sink_mode=False)
    print("============== Accuracy:{} ==============".format(acc))
@@ -0,0 +1,2 @@
DEVICE_ID=7 python train.py
DEVICE_ID=7 python eval.py --ckpt_path ./ckpt/lstm-1_149.ckpt
@@ -0,0 +1,20 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
ulimit -u unlimited

BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
python ${BASEPATH}/../eval.py --ckpt_path $1 > ./eval.log 2>&1 &
@@ -0,0 +1,21 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

ulimit -u unlimited

BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
python ${BASEPATH}/../train.py > ./train.log 2>&1 &
@@ -0,0 +1,38 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config
"""
from easydict import EasyDict as edict

# LSTM CONFIG
textrcnn_cfg = edict({
    'pos_dir': 'data/rt-polaritydata/rt-polarity.pos',
    'neg_dir': 'data/rt-polaritydata/rt-polarity.neg',
    'num_epochs': 10,
    'batch_size': 64,
    'cell': 'lstm',
    'opt': 'adam',
    'ckpt_folder_path': './ckpt',
    'preprocess_path': './preprocess',
    'preprocess': 'false',
    'data_path': './data/',
    'lr': 1e-3,
    'emb_path': './word2vec',
    'embed_size': 300,
    'save_checkpoint_steps': 149,
    'keep_checkpoint_max': 10,
    'momentum': 0.9
})
@@ -0,0 +1,179 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dataset api"""
import os
from itertools import chain
import gensim
import numpy as np

from mindspore.mindrecord import FileWriter
import mindspore.dataset as ds


# preprocess part
def encode_samples(tokenized_samples, word_to_idx):
    """ encode word to index """
    features = []
    for sample in tokenized_samples:
        feature = []
        for token in sample:
            if token in word_to_idx:
                feature.append(word_to_idx[token])
            else:
                feature.append(0)
        features.append(feature)
    return features


def pad_samples(features, maxlen=50, pad=0):
    """ pad all features to the same length """
    padded_features = []
    for feature in features:
        if len(feature) >= maxlen:
            padded_feature = feature[:maxlen]
        else:
            padded_feature = feature
            while len(padded_feature) < maxlen:
                padded_feature.append(pad)
        padded_features.append(padded_feature)
    return padded_features


def read_imdb(path, seg='train'):
    """ read imdb dataset """
    pos_or_neg = ['pos', 'neg']
    data = []
    for label in pos_or_neg:
        f = os.path.join(path, seg, label)
        rf = open(f, 'r')
        for line in rf:
            line = line.strip()
            if label == 'pos':
                data.append([line, 1])
            elif label == 'neg':
                data.append([line, 0])
        rf.close()
    return data


def tokenizer(text):
    """ lowercase and split text on spaces """
    return [tok.lower() for tok in text.split(' ')]


def collect_weight(glove_path, vocab, word_to_idx, embed_size):
    """ collect weight """
    vocab_size = len(vocab)
    # wvmodel = gensim.models.KeyedVectors.load_word2vec_format(os.path.join(glove_path, 'glove.6B.300d.txt'),
    #                                                           binary=False, encoding='utf-8')
    wvmodel = gensim.models.KeyedVectors.load_word2vec_format(os.path.join(glove_path, \
                                                              'GoogleNews-vectors-negative300.bin'), binary=True)
    weight_np = np.zeros((vocab_size + 1, embed_size)).astype(np.float32)

    idx_to_word = {i + 1: word for i, word in enumerate(vocab)}
    idx_to_word[0] = '<unk>'

    for i in range(len(wvmodel.index2word)):
        try:
            index = word_to_idx[wvmodel.index2word[i]]
        except KeyError:
            continue
        weight_np[index, :] = wvmodel.get_vector(
            idx_to_word[word_to_idx[wvmodel.index2word[i]]])
    return weight_np


def preprocess(data_path, glove_path, embed_size):
    """ preprocess the train and test data """
    train_data = read_imdb(data_path, 'train')
    test_data = read_imdb(data_path, 'test')

    train_tokenized = []
    test_tokenized = []
    for review, _ in train_data:
        train_tokenized.append(tokenizer(review))
    for review, _ in test_data:
        test_tokenized.append(tokenizer(review))

    vocab = set(chain(*train_tokenized))
    vocab_size = len(vocab)
    print("vocab_size: ", vocab_size)

    word_to_idx = {word: i + 1 for i, word in enumerate(vocab)}
    word_to_idx['<unk>'] = 0

    train_features = np.array(pad_samples(encode_samples(train_tokenized, word_to_idx))).astype(np.int32)
    train_labels = np.array([score for _, score in train_data]).astype(np.int32)
    test_features = np.array(pad_samples(encode_samples(test_tokenized, word_to_idx))).astype(np.int32)
    test_labels = np.array([score for _, score in test_data]).astype(np.int32)

    weight_np = collect_weight(glove_path, vocab, word_to_idx, embed_size)
    return train_features, train_labels, test_features, test_labels, weight_np, vocab_size


def get_imdb_data(labels_data, features_data):
    """ pack label/feature pairs into mindrecord rows """
    data_list = []
    for i, (label, feature) in enumerate(zip(labels_data, features_data)):
        data_json = {"id": i,
                     "label": int(label),
                     "feature": feature.reshape(-1)}
        data_list.append(data_json)
    return data_list


def convert_to_mindrecord(embed_size, data_path, preprocess_path, glove_path):
    """ convert imdb dataset to mindrecord """
    num_shard = 4
    train_features, train_labels, test_features, test_labels, weight_np, _ = \
        preprocess(data_path, glove_path, embed_size)
    np.savetxt(os.path.join(preprocess_path, 'weight.txt'), weight_np)

    print("train_features.shape:", train_features.shape, "train_labels.shape:", train_labels.shape,
          "weight_np.shape:", weight_np.shape, "type:", train_labels.dtype)
    # write mindrecord
    schema_json = {"id": {"type": "int32"},
                   "label": {"type": "int32"},
                   "feature": {"type": "int32", "shape": [-1]}}

    writer = FileWriter(os.path.join(preprocess_path, 'aclImdb_train.mindrecord'), num_shard)
    data = get_imdb_data(train_labels, train_features)
    writer.add_schema(schema_json, "nlp_schema")
    writer.add_index(["id", "label"])
    writer.write_raw_data(data)
    writer.commit()

    writer = FileWriter(os.path.join(preprocess_path, 'aclImdb_test.mindrecord'), num_shard)
    data = get_imdb_data(test_labels, test_features)
    writer.add_schema(schema_json, "nlp_schema")
    writer.add_index(["id", "label"])
    writer.write_raw_data(data)
    writer.commit()


def create_dataset(base_path, batch_size, num_epochs, is_train):
    """Create dataset for training."""
    columns_list = ["feature", "label"]
    num_consumer = 4

    if is_train:
        path = os.path.join(base_path, 'aclImdb_train.mindrecord0')
    else:
        path = os.path.join(base_path, 'aclImdb_test.mindrecord0')

    data_set = ds.MindDataset(path, columns_list, num_consumer)
    ds.config.set_seed(1)
    data_set = data_set.shuffle(buffer_size=data_set.get_dataset_size())
    data_set = data_set.batch(batch_size, drop_remainder=True)
    return data_set
@@ -0,0 +1,196 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""model textrcnn"""
import numpy as np

import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore import Tensor
from mindspore.common import dtype as mstype


class textrcnn(nn.Cell):
    """class textrcnn"""
    def __init__(self, weight, vocab_size, cell, batch_size):
        super(textrcnn, self).__init__()
        self.num_hiddens = 512
        self.embed_size = 300
        self.num_classes = 2
        self.batch_size = batch_size
        k = (1 / self.num_hiddens) ** 0.5

        self.embedding = nn.Embedding(vocab_size, self.embed_size, embedding_table=weight)
        self.embedding.embedding_table.requires_grad = False
        self.cell = cell

        self.cast = P.Cast()

        self.h1 = Tensor(np.zeros(shape=(self.batch_size, self.num_hiddens)).astype(np.float16))
        self.c1 = Tensor(np.zeros(shape=(self.batch_size, self.num_hiddens)).astype(np.float16))

        if cell == "lstm":
            self.lstm = P.DynamicRNN(forget_bias=0.0)
            self.w1_fw = Parameter(
                np.random.uniform(-k, k, (self.embed_size + self.num_hiddens, 4 * self.num_hiddens)).astype(
                    np.float16), name="w1_fw")
            self.b1_fw = Parameter(np.random.uniform(-k, k, (4 * self.num_hiddens)).astype(np.float16),
                                   name="b1_fw")
            self.w1_bw = Parameter(
                np.random.uniform(-k, k, (self.embed_size + self.num_hiddens, 4 * self.num_hiddens)).astype(
                    np.float16), name="w1_bw")
            self.b1_bw = Parameter(np.random.uniform(-k, k, (4 * self.num_hiddens)).astype(np.float16),
                                   name="b1_bw")
            self.h1 = Tensor(np.zeros(shape=(1, self.batch_size, self.num_hiddens)).astype(np.float16))
            self.c1 = Tensor(np.zeros(shape=(1, self.batch_size, self.num_hiddens)).astype(np.float16))

        if cell == "vanilla":
            self.rnnW_fw = nn.Dense(self.num_hiddens, self.num_hiddens)
            self.rnnU_fw = nn.Dense(self.embed_size, self.num_hiddens)
            self.rnnW_bw = nn.Dense(self.num_hiddens, self.num_hiddens)
            self.rnnU_bw = nn.Dense(self.embed_size, self.num_hiddens)

        if cell == "gru":
            self.rnnWr_fw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
            self.rnnWz_fw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
            self.rnnWh_fw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
            self.rnnWr_bw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
            self.rnnWz_bw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
            self.rnnWh_bw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
            self.ones = Tensor(np.ones(shape=(self.batch_size, self.num_hiddens)).astype(np.float16))

        self.transpose = P.Transpose()
        self.reduce_max = P.ReduceMax()
        self.expand_dims = P.ExpandDims()
        self.concat = P.Concat()

        self.reshape = P.Reshape()
        self.left_pad_tensor = Tensor(np.zeros((1, self.batch_size, self.num_hiddens)).astype(np.float16))
        self.right_pad_tensor = Tensor(np.zeros((1, self.batch_size, self.num_hiddens)).astype(np.float16))
        self.output_dense = nn.Dense(self.num_hiddens * 1, 2)
        self.concat0 = P.Concat(0)
        self.concat2 = P.Concat(2)
        self.concat1 = P.Concat(1)
        self.text_rep_dense = nn.Dense(2 * self.num_hiddens + self.embed_size, self.num_hiddens)
        self.mydense = nn.Dense(self.num_hiddens, 2)
        self.drop_out = nn.Dropout(keep_prob=0.7)
        self.tanh = P.Tanh()
        self.sigmoid = P.Sigmoid()
        self.slice = P.Slice()
        # self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
        #                     has_bias=has_bias, batch_first=batch_first, bidirectional=bidirectional, dropout=0.0)

    def construct(self, x):
        """class construction"""
        # x: bs, sl
        output_fw = x
        output_bw = x

        if self.cell == "vanilla":
            x = self.embedding(x)  # bs, sl, emb_size
            x = self.cast(x, mstype.float16)
            x = self.transpose(x, (1, 0, 2))  # sl, bs, emb_size
            x = self.drop_out(x)  # sl, bs, emb_size

            h1_fw = self.cast(self.h1, mstype.float16)  # bs, num_hidden
            h1_fw = self.tanh(self.rnnW_fw(h1_fw) + self.rnnU_fw(x[0, :, :]))  # bs, num_hidden
            output_fw = self.expand_dims(h1_fw, 0)  # 1, bs, num_hidden

            for i in range(1, F.shape(x)[0]):
                h1_fw = self.tanh(self.rnnW_fw(h1_fw) + self.rnnU_fw(x[i, :, :]))  # 1, bs, num_hidden
                h1_after_expand_fw = self.expand_dims(h1_fw, 0)
                output_fw = self.concat((output_fw, h1_after_expand_fw))  # 2/3/4.., bs, num_hidden
            output_fw = self.cast(output_fw, mstype.float16)  # sl, bs, num_hidden

            h1_bw = self.cast(self.h1, mstype.float16)  # bs, num_hidden
            h1_bw = self.tanh(self.rnnW_bw(h1_bw) + self.rnnU_bw(x[F.shape(x)[0] - 1, :, :]))  # bs, num_hidden
            output_bw = self.expand_dims(h1_bw, 0)  # 1, bs, num_hidden

            for i in range(F.shape(x)[0] - 2, -1, -1):
                h1_bw = self.tanh(self.rnnW_bw(h1_bw) + self.rnnU_bw(x[i, :, :]))  # 1, bs, num_hidden
                h1_after_expand_bw = self.expand_dims(h1_bw, 0)
                output_bw = self.concat((h1_after_expand_bw, output_bw))  # 2/3/4.., bs, num_hidden
            output_bw = self.cast(output_bw, mstype.float16)  # sl, bs, num_hidden

        if self.cell == "gru":
            x = self.embedding(x)  # bs, sl, emb_size
            x = self.cast(x, mstype.float16)
            x = self.transpose(x, (1, 0, 2))  # sl, bs, emb_size
            x = self.drop_out(x)  # sl, bs, emb_size

            h_fw = self.cast(self.h1, mstype.float16)

            h_x_fw = self.concat1((h_fw, x[0, :, :]))
            r_fw = self.sigmoid(self.rnnWr_fw(h_x_fw))
            z_fw = self.sigmoid(self.rnnWz_fw(h_x_fw))
            h_tilde_fw = self.tanh(self.rnnWh_fw(self.concat1((r_fw * h_fw, x[0, :, :]))))
            h_fw = (self.ones - z_fw) * h_fw + z_fw * h_tilde_fw
            output_fw = self.expand_dims(h_fw, 0)

            for i in range(1, F.shape(x)[0]):
                h_x_fw = self.concat1((h_fw, x[i, :, :]))
                r_fw = self.sigmoid(self.rnnWr_fw(h_x_fw))
                z_fw = self.sigmoid(self.rnnWz_fw(h_x_fw))
                h_tilde_fw = self.tanh(self.rnnWh_fw(self.concat1((r_fw * h_fw, x[i, :, :]))))
                h_fw = (self.ones - z_fw) * h_fw + z_fw * h_tilde_fw
                h_after_expand_fw = self.expand_dims(h_fw, 0)
                output_fw = self.concat((output_fw, h_after_expand_fw))
            output_fw = self.cast(output_fw, mstype.float16)

            h_bw = self.cast(self.h1, mstype.float16)  # bs, num_hidden

            h_x_bw = self.concat1((h_bw, x[F.shape(x)[0] - 1, :, :]))
            r_bw = self.sigmoid(self.rnnWr_bw(h_x_bw))
            z_bw = self.sigmoid(self.rnnWz_bw(h_x_bw))
            h_tilde_bw = self.tanh(self.rnnWh_bw(self.concat1((r_bw * h_bw, x[F.shape(x)[0] - 1, :, :]))))
            h_bw = (self.ones - z_bw) * h_bw + z_bw * h_tilde_bw
            output_bw = self.expand_dims(h_bw, 0)
            for i in range(F.shape(x)[0] - 2, -1, -1):
                h_x_bw = self.concat1((h_bw, x[i, :, :]))
                r_bw = self.sigmoid(self.rnnWr_bw(h_x_bw))
                z_bw = self.sigmoid(self.rnnWz_bw(h_x_bw))
                h_tilde_bw = self.tanh(self.rnnWh_bw(self.concat1((r_bw * h_bw, x[i, :, :]))))
                h_bw = (self.ones - z_bw) * h_bw + z_bw * h_tilde_bw
                h_after_expand_bw = self.expand_dims(h_bw, 0)
                output_bw = self.concat((h_after_expand_bw, output_bw))
            output_bw = self.cast(output_bw, mstype.float16)

        if self.cell == 'lstm':
            x = self.embedding(x)  # bs, sl, emb_size
            x = self.cast(x, mstype.float16)
            x = self.transpose(x, (1, 0, 2))  # sl, bs, emb_size
            x = self.drop_out(x)  # sl, bs, emb_size

            h1_fw_init = self.h1  # bs, num_hidden
            c1_fw_init = self.c1  # bs, num_hidden

            _, output_fw, _, _, _, _, _, _ = self.lstm(x, self.w1_fw, self.b1_fw, None, h1_fw_init, c1_fw_init)
            output_fw = self.cast(output_fw, mstype.float16)  # sl, bs, num_hidden

            h1_bw_init = self.h1  # bs, num_hidden
            c1_bw_init = self.c1  # bs, num_hidden
            _, output_bw, _, _, _, _, _, _ = self.lstm(x, self.w1_bw, self.b1_bw, None, h1_bw_init, c1_bw_init)
            output_bw = self.cast(output_bw, mstype.float16)  # sl, bs, hidden

        c_left = self.concat0((self.left_pad_tensor, output_fw[:F.shape(x)[0] - 1]))  # sl, bs, num_hidden
        c_right = self.concat0((output_bw[1:], self.right_pad_tensor))  # sl, bs, num_hidden
        output = self.concat2((c_left, self.cast(x, mstype.float16), c_right))  # sl, bs, 2*num_hidden+emb_size
        output = self.cast(output, mstype.float16)

        output_flat = self.reshape(output, (F.shape(x)[0] * self.batch_size, 2 * self.num_hiddens + self.embed_size))
        output_dense = self.text_rep_dense(output_flat)  # sl*bs, num_hidden
        output_dense = self.tanh(output_dense)  # sl*bs, num_hidden
        output = self.reshape(output_dense, (F.shape(x)[0], self.batch_size, self.num_hiddens))  # sl, bs, num_hidden
        output = self.reduce_max(output, 0)  # bs, num_hidden
        outputs = self.cast(self.mydense(output), mstype.float16)  # bs, num_classes
        return outputs
@@ -0,0 +1,74 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""model train script"""
import os
import shutil
import numpy as np

import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.common import set_seed

from src.config import textrcnn_cfg as cfg
from src.dataset import create_dataset
from src.dataset import convert_to_mindrecord
from src.textrcnn import textrcnn


set_seed(1)

if __name__ == '__main__':

    context.set_context(
        mode=context.GRAPH_MODE,
        save_graphs=False,
        device_target="Ascend")

    device_id = int(os.getenv('DEVICE_ID'))
    context.set_context(device_id=device_id)

    if cfg.preprocess == 'true':
        print("============== Starting Data Pre-processing ==============")
        if os.path.exists(cfg.preprocess_path):
            shutil.rmtree(cfg.preprocess_path)
        os.mkdir(cfg.preprocess_path)
        convert_to_mindrecord(cfg.embed_size, cfg.data_path, cfg.preprocess_path, cfg.emb_path)

    embedding_table = np.loadtxt(os.path.join(cfg.preprocess_path, "weight.txt")).astype(np.float32)

    network = textrcnn(weight=Tensor(embedding_table), vocab_size=embedding_table.shape[0], \
                       cell=cfg.cell, batch_size=cfg.batch_size)

    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
    if cfg.opt == "adam":
        opt = nn.Adam(params=network.trainable_params(), learning_rate=cfg.lr)
    elif cfg.opt == "momentum":
        opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)

    loss_cb = LossMonitor()
    model = Model(network, loss, opt, {'acc': Accuracy()}, amp_level="O3")

    print("============== Starting Training ==============")
    ds_train = create_dataset(cfg.preprocess_path, cfg.batch_size, cfg.num_epochs, True)
    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, \
                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix=cfg.cell, directory=cfg.ckpt_folder_path, config=config_ck)
    model.train(cfg.num_epochs, ds_train, callbacks=[ckpoint_cb, loss_cb])
    print("train success")