Add ExtractImagePatches for new backend.

pull/6071/head
liuxiao93 4 years ago
parent a084196300
commit 7235dca0e3

@ -204,6 +204,8 @@ from .batch_to_space import _batch_to_space_tbe
from .space_to_batch import _space_to_batch_tbe
from .depth_to_space import _depth_to_space_tbe
from .space_to_depth import _space_to_depth_tbe
from .extract_image_patches import _extract_image_patches_tbe
from .sort import _sort_tbe
from .floor import _floor_tbe
from .ceil import _ceil_tbe
from .log1p import _log1p_tbe

@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ExtractImagePatches op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
extract_image_patches_op_info = TBERegOp("ExtractImagePatches") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("extract_image_patches.so") \
.compute_cost(10) \
.kernel_name("extract_image_patches") \
.partial_flag(True) \
.attr("ksizes", "required", "listInt", "all") \
.attr("strides", "required", "listInt", "all") \
.attr("rates", "required", "listInt", "all") \
.attr("padding", "required", "str", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_NHWC) \
.dtype_format(DataType.I8_5HD, DataType.I8_NHWC) \
.dtype_format(DataType.U8_5HD, DataType.U8_NHWC) \
.get_op_info()
@op_info_register(extract_image_patches_op_info)
def _extract_image_patches_tbe():
"""ExtractImagePatches TBE register"""
return

@ -0,0 +1,38 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Sort op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
sort_op_info = TBERegOp("Sort") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("sort.so") \
.compute_cost(10) \
.kernel_name("sort") \
.partial_flag(True) \
.attr("axis", "optional", "int", "all", "-1") \
.attr("descending", "optional", "bool", "all", "false") \
.input(0, "x", False, "required", "all") \
.output(0, "y1", False, "required", "all") \
.output(1, "y2", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I32_Default) \
.get_op_info()
@op_info_register(sort_op_info)
def _sort_tbe():
"""Sort TBE register"""
return

@ -29,7 +29,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
Shape, DynamicShape, Size, Slice, Split, TransShape, ParallelConcat, Padding,
ScatterNdAdd, ScatterNdSub, ScatterNonAliasingAdd, ReverseV2, Rint,
Squeeze, StridedSlice, Tile, TensorScatterUpdate, EditDistance,
Squeeze, StridedSlice, Tile, TensorScatterUpdate, EditDistance, Sort,
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd,
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup,
@ -93,6 +93,7 @@ from .sparse_ops import SparseToDense
__all__ = [
'Unique',
'ReverseSequence',
'Sort',
'EditDistance',
'CropAndResize',
'TensorAdd',

@ -3670,6 +3670,44 @@ class TransShape(PrimitiveWithInfer):
'value': None}
class Sort(PrimitiveWithInfer):
"""
Sorts the elements of the input tensor along a given dimension in ascending order by value.
Args:
axis (int): The dimension to sort along. Default: -1.
descending (bool): Controls the sorting order. If descending is True then the elements
are sorted in descending order by value. Default: False.
Inputs:
- **x** (Tensor) - The input to sort, with float16 or float32 data type.
Outputs:
- **y1** (Tensor) - A tensor whose values are the sorted values, with the same shape and data type as input.
- **y2** (Tensor) - The indices of the elements in the original input tensor. Data type is int32.
Examples:
>>> x = Tensor(np.array([[8, 2, 1], [5, 9, 3], [4, 6, 7]]), mindspore.float16)
>>> sort = P.Sort()
>>> sort(x)
>>> ([[1.0, 2.0, 8.0], [3.0, 5.0, 9.0], [4.0, 6.0 ,7.0]],
[[2, 1, 0], [2, 0, 1], [0, 1, 2]])
"""
@prim_attr_register
def __init__(self, axis=-1, descending=False):
"""init Sort"""
self.axis = validator.check_value_type("axis", axis, [int], self.name)
self.descending = validator.check_value_type("descending", descending, [bool], self.name)
def infer_shape(self, x_shape):
return x_shape, x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x_dtype": x_dtype}, [mstype.float32, mstype.float16], self.name)
return x_dtype, mstype.tensor_type(mstype.int32)
class EmbeddingLookup(PrimitiveWithInfer):
"""
Returns a slice of input tensor based on the specified indices.

Loading…
Cancel
Save