!10425 reformat code and remove useless code in textrcnn

From: @chenmai1102
Reviewed-by: @oacjiewen, @guoqi1024
Signed-off-by: @guoqi1024
pull/10425/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit a525846718

@@ -16,6 +16,7 @@
 import argparse
 import os
 import numpy as np
 parser = argparse.ArgumentParser(description='textrcnn')
 parser.add_argument('--task', type=str, help='the data preprocess task, including dataset_split.')
 parser.add_argument('--data_dir', type=str, help='the source dataset directory.', default='./data_src')
@@ -24,18 +25,18 @@ parser.add_argument('--out_dir', type=str, help='the target dataset directory.',
 args = parser.parse_args()
 np.random.seed(2)
 def dataset_split(label):
     """dataset_split api"""
     # label can be 'pos' or 'neg'
     pos_samples = []
-    pos_file = os.path.join(args.data_dir, "rt-polaritydata", "rt-polarity."+label)
+    pos_file = os.path.join(args.data_dir, "rt-polaritydata", "rt-polarity." + label)
     pfhand = open(pos_file, encoding='utf-8')
     pos_samples += pfhand.readlines()
     pfhand.close()
     perm = np.random.permutation(len(pos_samples))
-    # print(perm[0:int(len(pos_samples)*0.8)])
-    perm_train = perm[0:int(len(pos_samples)*0.9)]
-    perm_test = perm[int(len(pos_samples)*0.9):]
+    perm_train = perm[0:int(len(pos_samples) * 0.9)]
+    perm_test = perm[int(len(pos_samples) * 0.9):]
     pos_samples_train = []
     pos_samples_test = []
     for pt in perm_train:
@@ -51,10 +52,7 @@ def dataset_split(label):
     f.close()
 if __name__ == '__main__':
     if args.task == "dataset_split":
         dataset_split('pos')
         dataset_split('neg')
-        # search(args.q)
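For reference, the split above is a plain 90/10 cut over a shuffled index permutation. A minimal, self-contained sketch of that pattern (the sample strings and names below are illustrative, not from the repository):

```python
# Minimal sketch of the 90/10 permutation split performed by dataset_split.
# The sample list is a stand-in for the rt-polarity lines.
import numpy as np

np.random.seed(2)
samples = ["sample %d" % i for i in range(10)]
perm = np.random.permutation(len(samples))      # shuffled indices
cut = int(len(samples) * 0.9)
train = [samples[i] for i in perm[:cut]]        # 90% of the lines
test = [samples[i] for i in perm[cut:]]         # remaining 10%
print(len(train), len(test))                    # -> 9 1
```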

@@ -32,7 +32,6 @@ from src.textrcnn import textrcnn
 set_seed(1)
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='textrcnn')
     parser.add_argument('--ckpt_path', type=str)
@@ -46,7 +45,7 @@ if __name__ == '__main__':
     context.set_context(device_id=device_id)
     embedding_table = np.loadtxt(os.path.join(cfg.preprocess_path, "weight.txt")).astype(np.float32)
-    network = textrcnn(weight=Tensor(embedding_table), vocab_size=embedding_table.shape[0], \
+    network = textrcnn(weight=Tensor(embedding_table), vocab_size=embedding_table.shape[0],
                        cell=cfg.cell, batch_size=cfg.batch_size)
     loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
     opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
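The trailing backslashes dropped here (and in train.py below) are redundant because the arguments already sit inside an open parenthesis, where Python continues the line implicitly. A small illustration with a hypothetical function:

```python
# Inside parentheses Python joins lines implicitly; no "\" is needed.
def make_pair(first, second):
    return (first, second)

pair = make_pair(1,
                 2)    # same as make_pair(1, 2)
print(pair)            # (1, 2)
```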

@@ -74,7 +74,7 @@ DEVICE_ID=7 python train.py
 bash scripts/run_train.sh
 # run evaluating
-DEVICE_ID=7 python eval.py --ckpt_path ./ckpt/lstm-10_149.ckpt
+DEVICE_ID=7 python eval.py --ckpt_path {checkpoint path}
 # or you can use the shell script to evaluate in background
 bash scripts/run_eval.sh
 ```

@@ -21,6 +21,7 @@ import numpy as np
 from mindspore.mindrecord import FileWriter
 import mindspore.dataset as ds
 # preprocess part
 def encode_samples(tokenized_samples, word_to_idx):
     """ encode word to index """
@@ -78,7 +79,8 @@ def collect_weight(glove_path, vocab, word_to_idx, embed_size):
     # wvmodel = gensim.models.KeyedVectors.load_word2vec_format(os.path.join(glove_path, 'glove.6B.300d.txt'),
     #                                                            binary=False, encoding='utf-8')
     wvmodel = gensim.models.KeyedVectors.load_word2vec_format(os.path.join(glove_path, \
-                                                                            'GoogleNews-vectors-negative300.bin'), binary=True)
+                                                                            'GoogleNews-vectors-negative300.bin'),
+                                                               binary=True)
     weight_np = np.zeros((vocab_size + 1, embed_size)).astype(np.float32)
     idx_to_word = {i + 1: word for i, word in enumerate(vocab)}
@@ -140,7 +142,7 @@ def convert_to_mindrecord(embed_size, data_path, proprocess_path, glove_path):
     preprocess(data_path, glove_path, embed_size)
     np.savetxt(os.path.join(proprocess_path, 'weight.txt'), weight_np)
-    print("train_features.shape:", train_features.shape, "train_labels.shape:", train_labels.shape, "weight_np.shape:",\
+    print("train_features.shape:", train_features.shape, "train_labels.shape:", train_labels.shape, "weight_np.shape:",
           weight_np.shape, "type:", train_labels.dtype)
     # write mindrecord
     schema_json = {"id": {"type": "int32"},
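For context on collect_weight: the lines not shown in this hunk copy pretrained word2vec vectors into weight_np row by row. A rough sketch of that pattern under stated assumptions (the vocabulary, file path, and loop body are illustrative, not the repository code):

```python
# Rough sketch of filling an embedding table from pretrained word2vec vectors.
# The vocabulary, file path and loop are illustrative assumptions.
import numpy as np
import gensim

embed_size = 300
vocab = ["good", "bad", "movie"]
wvmodel = gensim.models.KeyedVectors.load_word2vec_format(
    "GoogleNews-vectors-negative300.bin", binary=True)

weight_np = np.zeros((len(vocab) + 1, embed_size)).astype(np.float32)  # row 0 kept for padding
for i, word in enumerate(vocab):
    if word in wvmodel:                   # leave zeros for out-of-vocabulary words
        weight_np[i + 1] = wvmodel[word]
```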

@@ -22,8 +22,10 @@ from mindspore.common.parameter import Parameter
 from mindspore import Tensor
 from mindspore.common import dtype as mstype
 class textrcnn(nn.Cell):
     """class textrcnn"""
     def __init__(self, weight, vocab_size, cell, batch_size):
         super(textrcnn, self).__init__()
         self.num_hiddens = 512
@@ -89,7 +91,6 @@ class textrcnn(nn.Cell):
         self.tanh = P.Tanh()
         self.sigmoid = P.Sigmoid()
         self.slice = P.Slice()
-        # self.lstm = nn.LSTM(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers,has_bias=has_bias, batch_first=batch_first, bidirectional=bidirectional, dropout=0.0)
     def construct(self, x):
         """class construction"""

@@ -31,7 +31,6 @@ from src.dataset import convert_to_mindrecord
 from src.textrcnn import textrcnn
 from src.utils import get_lr
 set_seed(2)
 if __name__ == '__main__':
@@ -56,7 +55,7 @@ if __name__ == '__main__':
     embedding_table = np.loadtxt(os.path.join(cfg.preprocess_path, "weight.txt")).astype(np.float32)
-    network = textrcnn(weight=Tensor(embedding_table), vocab_size=embedding_table.shape[0], \
+    network = textrcnn(weight=Tensor(embedding_table), vocab_size=embedding_table.shape[0],
                        cell=cfg.cell, batch_size=cfg.batch_size)
     ds_train = create_dataset(cfg.preprocess_path, cfg.batch_size, cfg.num_epochs, True)
@@ -74,7 +73,7 @@ if __name__ == '__main__':
     model = Model(network, loss, opt, {'acc': Accuracy()}, amp_level="O3")
     print("============== Starting Training ==============")
-    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, \
+    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                                  keep_checkpoint_max=cfg.keep_checkpoint_max)
     ckpoint_cb = ModelCheckpoint(prefix=cfg.cell, directory=cfg.ckpt_folder_path, config=config_ck)
     model.train(num_epochs, ds_train, callbacks=[ckpoint_cb, loss_cb])
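The loss_cb in the last line is defined outside this hunk; it is presumably a LossMonitor. A hypothetical sketch of the checkpoint/callback wiring, with placeholder numbers and prefix instead of the cfg values:

```python
# Hypothetical sketch of the checkpoint/callback wiring used above.
# loss_cb is assumed to be a LossMonitor; the numbers and prefix are placeholders.
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor

config_ck = CheckpointConfig(save_checkpoint_steps=149, keep_checkpoint_max=10)
ckpoint_cb = ModelCheckpoint(prefix="lstm", directory="./ckpt", config=config_ck)
loss_cb = LossMonitor()
# model.train(num_epochs, ds_train, callbacks=[ckpoint_cb, loss_cb])
```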
