!4789 Add EditDistance op for GE.

Merge pull request !4789 from liuxiao93/Add-EditDistance-op-for-GE
pull/4789/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 2ac410e90d

@ -190,6 +190,7 @@ constexpr const char kNameSquareSumAll[] = "SquareSumAll";
constexpr const char kNameAscendQuant[] = "Quant";
constexpr const char kNameAscendDequant[] = "Dequant";
constexpr const char kNameReverseSequence[] = "ReverseSequence";
constexpr const char kNameEditDistance[] = "EditDistance";
constexpr const char kNameCase[] = "Case";
class OpAdapterMap {

@ -87,4 +87,12 @@ ATTR_MAP(ReverseSequence) = {{"seq_dim", ATTR_DESC(seq_dim, AnyTraits<int>())},
{"batch_dim", ATTR_DESC(batch_dim, AnyTraits<int>())}};
OUTPUT_MAP(ReverseSequence) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(ReverseSequence, kNameReverseSequence, ADPT_DESC(ReverseSequence))
// EditDistance
INPUT_MAP(EditDistance) = {{1, INPUT_DESC(hypothesis_indices)}, {2, INPUT_DESC(hypothesis_values)},
{3, INPUT_DESC(hypothesis_shape)}, {4, INPUT_DESC(truth_indices)},
{5, INPUT_DESC(truth_values)}, {6, INPUT_DESC(truth_shape)}};
ATTR_MAP(EditDistance) = {{"normalize", ATTR_DESC(normalize, AnyTraits<bool>())}};
OUTPUT_MAP(EditDistance) = {{0, OUTPUT_DESC(output)}};
REG_ADPT_DESC(EditDistance, kNameEditDistance, ADPT_DESC(EditDistance))
} // namespace mindspore::transform

@ -54,5 +54,8 @@ DECLARE_OP_ADAPTER(Data)
DECLARE_OP_ADAPTER(ReverseSequence)
DECLARE_OP_USE_OUTPUT(ReverseSequence)
DECLARE_OP_ADAPTER(EditDistance)
DECLARE_OP_USE_OUTPUT(EditDistance)
} // namespace mindspore::transform
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_ARRAY_OPS_DECLARE_H_

@ -23,7 +23,6 @@ from .acos_grad import _acos_grad_tbe
from .acosh import _acosh_tbe
from .acosh_grad import _acosh_grad_tbe
from .adam_apply_one_with_decay import _adam_apply_one_with_decay_tbe
from .add import _add_tbe
from .apply_centered_rms_prop import _apply_centered_rms_prop_tbe
from .add_n import _add_n_tbe
from .accumulate_n_v2 import _accumulate_n_v2_tbe

@ -1,37 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Add op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
add_op_info = TBERegOp("Add") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("add.so") \
.compute_cost(10) \
.kernel_name("add") \
.partial_flag(True) \
.op_pattern("dynamicFormat") \
.input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \
.get_op_info()
@op_info_register(add_op_info)
def _add_tbe():
"""Add TBE register"""
return

@ -29,7 +29,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
Shape, Size, Slice, Split, TransShape, ParallelConcat, Padding,
ScatterNdAdd, ScatterNdSub, ScatterNonAliasingAdd, ReverseV2, Rint,
Squeeze, StridedSlice, Tile, TensorScatterUpdate,
Squeeze, StridedSlice, Tile, TensorScatterUpdate, EditDistance,
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd,
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup,
@ -92,6 +92,7 @@ from .sparse_ops import SparseToDense
__all__ = [
'ReverseSequence',
'EditDistance',
'CropAndResize',
'TensorAdd',
'Argmax',

@ -3470,6 +3470,93 @@ class ReverseSequence(PrimitiveWithInfer):
return x
class EditDistance(PrimitiveWithInfer):
"""
Computes the Levebshtein Edit Distance. It is used to measure the similarity of two sequences.
Args:
normalize (bool): If True, edit distances are normalized by length of truth. Default: True.
Inputs:
- **hypothesis_indices** (Tensor) - The indices of the hypothesis list SparseTensor. With int64 data type.
The shape of tensor is :math:`(N, R)`.
- **hypothesis_values** (Tensor) - The values of the hypothesis list SparseTensor.
Must be 1-D vector with length of N.
- **hypothesis_shape** (Tensor) - The values of the hypothesis list SparseTensor.
Must be R-length vector with int64 data type. Only constant value is allowed.
- **truth_indices** (Tensor) - The indices of the truth list SparseTensor. With int64 data type.
The shape of tensor is :math:`(M, R)`.
- **truth_values** (Tensor) - The values of the truth list SparseTensor. Must be 1-D vector with length of M.
- **truth_shape** (Tensor) - The values of the truth list SparseTensor.
Must be R-length vector with int64 data type. Only constant value is allowed.
Outputs:
Tensor, a dense tensor with rank `R-1` and float32 data type.
Examples:
>>> class EditDistance(nn.Cell):
>>> def __init__(self, hypothesis_shape, truth_shape, normalize=True):
>>> super(EditDistance, self).__init__()
>>> self.edit_distance = P.EditDistance(normalize)
>>> self.hypothesis_shape = hypothesis_shape
>>> self.truth_shape = truth_shape
>>>
>>> def construct(self, hypothesis_indices, hypothesis_values, truth_indices, truth_values):
>>> return self.edit_distance(hypothesis_indices, hypothesis_values, self.hypothesis_shape,
>>> truth_indices, truth_values, self.truth_shape)
>>>
>>> hypothesis_indices = Tensor(np.array([[0, 0, 0], [1, 0, 1], [1, 1, 1]]).astype(np.int64))
>>> hypothesis_values = Tensor(np.array([1, 2, 3]).astype(np.float32))
>>> hypothesis_shape = Tensor(np.array([1, 1, 2]).astype(np.int64))
>>> truth_indices = Tensor(np.array([[0, 1, 0], [0, 0, 1], [1, 1, 0], [1, 0, 1]]).astype(np.int64))
>>> truth_values = Tensor(np.array([1, 3, 2, 1]).astype(np.float32))
>>> truth_shape = Tensor(np.array([2, 2, 2]).astype(np.int64))
>>> edit_distance = EditDistance(hypothesis_shape, truth_shape)
>>> out = edit_distance(hypothesis_indices, hypothesis_values, truth_indices, truth_values)
>>> [[1.0, 1.0], [1.0, 1.0]]
"""
@prim_attr_register
def __init__(self, normalize=True):
"""init EditDistance"""
self.normalize = validator.check_value_type("normalize", normalize, [bool], self.name)
def __infer__(self, h_indices, h_values, h_shape, truth_indices, truth_values, truth_shape):
validator.check_const_input('hypothesis_shape', h_shape['value'], self.name)
validator.check_const_input('truth_shape', truth_shape['value'], self.name)
args_int = {"hypothesis_indices": h_indices['dtype'], "hypothesis_shape": h_shape['dtype'],
"truth_indices": truth_indices['dtype'], "truth_shape": truth_shape['dtype']}
validator.check_tensor_type_same(args_int, [mstype.int64], self.name)
args = {"hypothesis_values": h_values['dtype'], "truth_values": truth_values['dtype']}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
hypothesis_indices_shp, truth_indices_shp = h_indices['shape'], truth_indices['shape']
validator.check("hypothesis_indices rank", len(hypothesis_indices_shp), "expected", 2, Rel.EQ, self.name)
validator.check("truth_indices rank", len(truth_indices_shp), "expected", 2, Rel.EQ, self.name)
validator.check("hypothesis_values rank", len(h_values['shape']), "expected", 1, Rel.EQ, self.name)
validator.check("hypothesis_shape rank", len(h_shape['shape']), "expected", 1, Rel.EQ, self.name)
validator.check("truth_values rank", len(truth_values['shape']), "expected", 1, Rel.EQ, self.name)
validator.check("truth_shape rank", len(truth_shape['shape']), "expected", 1, Rel.EQ, self.name)
validator.check("hypothesis_values shape", h_values['shape'][0],
"hypothesis_indices shape[0]", hypothesis_indices_shp[0], Rel.EQ, self.name)
validator.check("hypothesis_shape", h_shape['shape'][0],
"hypothesis_indices shape[1]", hypothesis_indices_shp[1], Rel.EQ, self.name)
validator.check("truth_values shape", truth_values['shape'][0],
"truth_indices shape[0]", truth_indices_shp[0], Rel.EQ, self.name)
validator.check("hypothesis_shape", h_shape['shape'][0],
"truth_shape", truth_shape['shape'][0], Rel.EQ, self.name)
hypothesis_shape_v = h_shape['value'].asnumpy()
truth_shape_v = truth_shape['value'].asnumpy()
out_shape_rank = len(hypothesis_shape_v) - 1
out_shape = []
for i in range(out_shape_rank):
out_shape.append(max(hypothesis_shape_v[i], truth_shape_v[i]))
return {'shape': tuple(out_shape),
'dtype': mstype.tensor_type(mstype.float32),
'value': None}
class TransShape(PrimitiveWithInfer):
"""
Transform the shape of input tensor to target shape.

@ -684,6 +684,18 @@ class ParallelConcatNet(nn.Cell):
return self.parallel_concat((x1, x2))
class EditDistance(nn.Cell):
def __init__(self, hypothesis_shape, truth_shape, normalize=True):
super(EditDistance, self).__init__()
self.edit_distance = P.EditDistance(normalize)
self.hypothesis_shape = hypothesis_shape
self.truth_shape =truth_shape
def construct(self, hypothesis_indices, hypothesis_values, truth_indices, truth_values):
return self.edit_distance(hypothesis_indices, hypothesis_values, self.hypothesis_shape,
truth_indices, truth_values, self.truth_shape)
test_case_math_ops = [
('BitwiseAnd', {
'block': P.BitwiseAnd(),
@ -1978,6 +1990,15 @@ test_case_array_ops = [
'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.float32)),
Tensor(np.array([1, 2, 3]).astype(np.int32))],
'desc_bprop': [[3, 3]]}),
('EditDistance', {
'block': EditDistance(Tensor(np.array([1, 1, 2]).astype(np.int64)),
Tensor(np.array([2, 2, 2]).astype(np.int64))),
'desc_inputs': [Tensor(np.array([[0, 0, 0], [1, 0, 1], [1, 1, 1]]).astype(np.int64)),
Tensor(np.array([1, 2, 3]).astype(np.float32)),
Tensor(np.array([[0, 1, 0], [0, 0, 1], [1, 1, 0], [1, 0, 1]]).astype(np.int64)),
Tensor(np.array([1, 3, 2, 1]).astype(np.float32))],
'skip': ['backward'],
}),
('LinSpace', {
'block': inner.LinSpace(),
'desc_inputs': [Tensor([5, 5.5], mstype.float32),

Loading…
Cancel
Save