@@ -71,21 +71,35 @@ class FastTextInferCell(nn.Cell):
         return predicted_idx
 
-def load_infer_dataset(batch_size, datafile):
+def load_infer_dataset(batch_size, datafile, bucket):
     """data loader for infer"""
-    data_set = ds.MindDataset(datafile, columns_list=['src_tokens', 'src_tokens_length', 'label_idx'])
-
-    type_cast_op = deC.TypeCast(mstype.int32)
-    data_set = data_set.map(operations=type_cast_op, input_columns="src_tokens")
-    data_set = data_set.map(operations=type_cast_op, input_columns="src_tokens_length")
-    data_set = data_set.map(operations=type_cast_op, input_columns="label_idx")
-    data_set = data_set.batch(batch_size=batch_size, drop_remainder=True)
+    def batch_per_bucket(bucket_length, input_file):
+        input_file = input_file + '/test_dataset_bs_' + str(bucket_length) + '.mindrecord'
+        if not input_file:
+            raise FileNotFoundError("input file parameter must not be empty.")
+
+        data_set = ds.MindDataset(input_file,
+                                  columns_list=['src_tokens', 'src_tokens_length', 'label_idx'])
+        type_cast_op = deC.TypeCast(mstype.int32)
+        data_set = data_set.map(operations=type_cast_op, input_columns="src_tokens")
+        data_set = data_set.map(operations=type_cast_op, input_columns="src_tokens_length")
+        data_set = data_set.map(operations=type_cast_op, input_columns="label_idx")
+
+        data_set = data_set.batch(batch_size, drop_remainder=False)
+        return data_set
+    for i, _ in enumerate(bucket):
+        bucket_len = bucket[i]
+        ds_per = batch_per_bucket(bucket_len, datafile)
+        if i == 0:
+            data_set = ds_per
+        else:
+            data_set = data_set + ds_per
 
     return data_set
 
 
 def run_fasttext_infer():
     """run infer with FastText"""
-    dataset = load_infer_dataset(batch_size=config.batch_size, datafile=args.data_path)
+    dataset = load_infer_dataset(batch_size=config.batch_size, datafile=args.data_path, bucket=config.test_buckets)
     fasttext_model = FastText(config.vocab_size, config.embedding_dims, config.num_class)
 
     parameter_dict = load_checkpoint(args.model_ckpt)
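The first hunk replaces the single-MindRecord loader with a bucketed one: load_infer_dataset now builds one MindDataset per bucket length (each backed by its own test_dataset_bs_<len>.mindrecord file), batches it with drop_remainder=False, and concatenates the per-bucket datasets with the + operator. Below is a minimal, self-contained sketch of that concatenation pattern, not the patch's code: it uses in-memory NumpySlicesDataset toy data instead of MindRecord files, and load_bucketed_dataset and its arrays are illustrative names.

import numpy as np
import mindspore.dataset as ds


def load_bucketed_dataset(batch_size, buckets):
    """Build one dataset per bucket and concatenate them, mirroring load_infer_dataset."""
    data_set = None
    for bucket_len in buckets:
        # Stand-in for ds.MindDataset('.../test_dataset_bs_%d.mindrecord' % bucket_len)
        tokens = np.zeros((6, bucket_len), dtype=np.int32)
        per_bucket = ds.NumpySlicesDataset({"src_tokens": tokens}, shuffle=False)
        per_bucket = per_bucket.batch(batch_size, drop_remainder=False)
        data_set = per_bucket if data_set is None else data_set + per_bucket
    return data_set


for row in load_bucketed_dataset(batch_size=4, buckets=[64, 128]).create_dict_iterator(output_numpy=True, num_epochs=1):
    print(row["src_tokens"].shape)  # (4, 64) then (2, 64) batches, then the 128-length bucket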
@@ -107,7 +121,15 @@ def run_fasttext_infer():
 
     from sklearn.metrics import accuracy_score, classification_report
     target_sens = np.array(target_sens).flatten()
+    merge_target_sens = []
+    for target_sen in target_sens:
+        merge_target_sens.extend(target_sen)
+    target_sens = merge_target_sens
     predictions = np.array(predictions).flatten()
+    merge_predictions = []
+    for prediction in predictions:
+        merge_predictions.extend(prediction)
+    predictions = merge_predictions
     acc = accuracy_score(target_sens, predictions)
 
     result_report = classification_report(target_sens, predictions, target_names=target_label1)
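The second hunk flattens the collected results before scoring. With per-bucket datasets and drop_remainder=False, the per-batch label and prediction arrays no longer share a single shape, so np.array(...).flatten() yields an object array whose elements are still arrays; the added extend loops merge them into the flat lists that accuracy_score and classification_report expect. A small sketch of that merge on toy data follows; target_batches, pred_batches, and merge are illustrative names, and the ragged shapes are assumed for the example rather than taken from the patch.

import numpy as np
from sklearn.metrics import accuracy_score, classification_report

# Ragged per-batch results, as collected when the last batch of a bucket is smaller.
target_batches = [np.array([0, 1, 1, 0]), np.array([1, 0])]
pred_batches = [np.array([0, 1, 1, 1]), np.array([1, 0])]


def merge(batches):
    """Flatten a ragged list of per-batch arrays into one flat list of labels."""
    flat = []
    for batch in np.array(batches, dtype=object).flatten():
        flat.extend(batch)
    return flat


targets, preds = merge(target_batches), merge(pred_batches)
print("accuracy:", accuracy_score(targets, preds))
print(classification_report(targets, preds, target_names=["class_0", "class_1"]))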