From d3e561c5a79556e9ffbe4dc4c69697dca5dab75d Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Fri, 25 Dec 2020 11:05:47 +0800 Subject: [PATCH] add export file --- model_zoo/official/nlp/fasttext/export.py | 100 ++++++++++++++++++ .../official/nlp/fasttext/src/dataset.py | 28 ++--- 2 files changed, 114 insertions(+), 14 deletions(-) create mode 100644 model_zoo/official/nlp/fasttext/export.py diff --git a/model_zoo/official/nlp/fasttext/export.py b/model_zoo/official/nlp/fasttext/export.py new file mode 100644 index 0000000000..532558d6b4 --- /dev/null +++ b/model_zoo/official/nlp/fasttext/export.py @@ -0,0 +1,100 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""export checkpoint file into models""" + +import argparse +import numpy as np +import mindspore.nn as nn +from mindspore.common.tensor import Tensor +import mindspore.ops.operations as P +from mindspore import context +from mindspore.train.serialization import load_checkpoint, export, load_param_into_net +from src.fasttext_model import FastText + +parser = argparse.ArgumentParser(description='fasttexts') +parser.add_argument('--device_target', type=str, choices=["Ascend", "GPU", "CPU"], + default='Ascend', help='Device target') +parser.add_argument('--device_id', type=int, default=0, help='Device id') +parser.add_argument('--ckpt_file', type=str, required=True, help='Checkpoint file path') +parser.add_argument('--file_name', type=str, default='fasttexts', help='Output file name') +parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', + help='Output file format') +parser.add_argument('--data_name', type=str, required=True, default='ag', + help='Dataset name. eg. ag, dbpedia, yelp_p') +args = parser.parse_args() + +if args.data_name == "ag": + from src.config import config_ag as config + target_label1 = ['0', '1', '2', '3'] +elif args.data_name == 'dbpedia': + from src.config import config_db as config + target_label1 = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13'] +elif args.data_name == 'yelp_p': + from src.config import config_yelpp as config + target_label1 = ['0', '1'] + +context.set_context( + mode=context.GRAPH_MODE, + save_graphs=False, + device_target="Ascend") + +class FastTextInferExportCell(nn.Cell): + """ + Encapsulation class of FastText network infer. + + Args: + network (nn.Cell): FastText model. + + Returns: + Tuple[Tensor, Tensor], predicted_ids + """ + def __init__(self, network): + super(FastTextInferExportCell, self).__init__(auto_prefix=False) + self.network = network + self.argmax = P.ArgMaxWithValue(axis=1, keep_dims=True) + self.log_softmax = nn.LogSoftmax(axis=1) + + def construct(self, src_tokens, src_tokens_lengths): + """construct fasttext infer cell""" + prediction = self.network(src_tokens, src_tokens_lengths) + predicted_idx = self.log_softmax(prediction) + predicted_idx, _ = self.argmax(predicted_idx) + + return predicted_idx + +def run_fasttext_export(): + """export function""" + fasttext_model = FastText(config.vocab_size, config.embedding_dims, config.num_class) + parameter_dict = load_checkpoint(args.ckpt_file) + load_param_into_net(fasttext_model, parameter_dict) + ft_infer = FastTextInferExportCell(fasttext_model) + + if args.data_name == "ag": + src_tokens_shape = [config.batch_size, 467] + src_tokens_length_shape = [config.batch_size, 1] + elif args.data_name == 'dbpedia': + src_tokens_shape = [config.batch_size, 1120] + src_tokens_length_shape = [config.batch_size, 1] + elif args.data_name == 'yelp_p': + src_tokens_shape = [config.batch_size, 2955] + src_tokens_length_shape = [config.batch_size, 1] + + file_name = args.file_name + '_' + args.data_name + src_tokens = Tensor(np.ones((src_tokens_shape)).astype(np.int32)) + src_tokens_length = Tensor(np.ones((src_tokens_length_shape)).astype(np.int32)) + export(ft_infer, src_tokens, src_tokens_length, file_name=file_name, file_format=args.file_format) + +if __name__ == '__main__': + run_fasttext_export() diff --git a/model_zoo/official/nlp/fasttext/src/dataset.py b/model_zoo/official/nlp/fasttext/src/dataset.py index c3e0b3c0fa..7135bd4590 100644 --- a/model_zoo/official/nlp/fasttext/src/dataset.py +++ b/model_zoo/official/nlp/fasttext/src/dataset.py @@ -53,14 +53,14 @@ class FastTextDataPreProcess(): self.ngram = ngram self.text_greater = '>' self.text_less = '<' - self.ngram2idx = dict() - self.idx2gram = dict() + self.word2vec = dict() + self.vec2words = dict() self.non_str = '\\' self.end_string = ['.', '?', '!'] - self.ngram2idx['PAD'] = 0 - self.idx2gram[0] = 'PAD' - self.ngram2idx['UNK'] = 1 - self.idx2gram[1] = 'UNK' + self.word2vec['PAD'] = 0 + self.vec2words[0] = 'PAD' + self.word2vec['UNK'] = 1 + self.vec2words[1] = 'UNK' self.str_html = re.compile(r'<[^>]+>') def load(self): @@ -144,13 +144,13 @@ class FastTextDataPreProcess(): for l in range(train_dataset_list_length): bucket_length = self._get_bucket_length(train_dataset_list[l][0], self.buckets) while len(train_dataset_list[l][0]) < bucket_length: - train_dataset_list[l][0].append(self.ngram2idx['PAD']) + train_dataset_list[l][0].append(self.word2vec['PAD']) train_dataset_list[l][1] = len(train_dataset_list[l][0]) # pad test dataset for j in range(test_dataset_list_length): test_bucket_length = self._get_bucket_length(test_dataset_list[j][0], self.test_bucket) while len(test_dataset_list[j][0]) < test_bucket_length: - test_dataset_list[j][0].append(self.ngram2idx['PAD']) + test_dataset_list[j][0].append(self.word2vec['PAD']) test_dataset_list[j][1] = len(test_dataset_list[j][0]) train_example_data = [] @@ -173,7 +173,7 @@ class FastTextDataPreProcess(): for key in self.test_feature_dict: if key == test_example_data[h]['src_tokens_length']: self.test_feature_dict[key].append(test_example_data[h]) - print("train vocab size is ", len(self.ngram2idx)) + print("train vocab size is ", len(self.word2vec)) return self.train_feature_dict, self.test_feature_dict @@ -210,13 +210,13 @@ class FastTextDataPreProcess(): if train_mode is True: for ngms in bo_ngrams: - idx = self.ngram2idx.get(ngms) + idx = self.word2vec.get(ngms) if idx is None: - idx = len(self.ngram2idx) - self.ngram2idx[ngms] = idx - self.idx2gram[idx] = ngms + idx = len(self.word2vec) + self.word2vec[ngms] = idx + self.vec2words[idx] = ngms - processed_out = [self.ngram2idx[ng] if ng in self.ngram2idx else self.ngram2idx['UNK'] for ng in bo_ngrams] + processed_out = [self.word2vec[ng] if ng in self.word2vec else self.word2vec['UNK'] for ng in bo_ngrams] return processed_out