!7030 add Meshgrid ops for aicpu

Merge pull request !7030 from yanzhenxiang2020/br_meshgrid
pull/7030/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit cfb131b844

@ -38,10 +38,10 @@ void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<
return;
}
// For compatibility with the current framework
if (op_name == kPrint || op_name == kGetNext || op_name == kPack) {
if (op_name == kPrint || op_name == kGetNext || op_name == kPack || op_name == kMeshgrid) {
std::vector<std::string> inputs_format{};
std::vector<TypeId> inputs_type{};
if (op_name == kPrint || op_name == kPack) {
if (op_name == kPrint || op_name == kPack || op_name == kMeshgrid) {
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
inputs_format.emplace_back(kOpFormat_DEFAULT);
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index));

@ -29,6 +29,7 @@ constexpr auto kInitData = "InitData";
constexpr auto kGetNext = "GetNext";
constexpr auto kPrint = "Print";
constexpr auto kPack = "Pack";
constexpr auto kMeshgrid = "Meshgrid";
constexpr auto kOutputTypes = "output_types";
constexpr auto kOutputShapes = "output_shapes";
constexpr auto kChannelName = "channel_name";
@ -46,7 +47,7 @@ constexpr auto kEditDistance = "EditDistance";
constexpr auto kGatherD = "GatherD";
constexpr auto kIdentity = "Identity";
constexpr auto kCustRunApi = "RunCpuKernel";
const std::set<std::string> kCustAiCpuKernelOps{kEditDistance, kGatherD, kIdentity};
const std::set<std::string> kCustAiCpuKernelOps{kEditDistance, kGatherD, kIdentity, kMeshgrid};
struct AicpuParamHead {
uint32_t length; // Total length: include cunstom message

@ -55,3 +55,4 @@ from .fused_sparse_adam import _fused_sparse_adam_aicpu
from .fused_sparse_lazy_adam import _fused_sparse_lazy_adam_aicpu
from .fused_sparse_ftrl import _fused_sparse_ftrl_aicpu
from .fused_sparse_proximal_adagrad import _fused_sparse_proximal_adagrad_aicpu
from .meshgrid import _meshgrid_aicpu

@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Meshgrid op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
meshgrid_op_info = AiCPURegOp("Meshgrid") \
.fusion_type("OPAQUE") \
.attr("indexing", "str") \
.input(0, "x", "dynamic") \
.output(0, "y", "dynamic") \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.get_op_info()
@op_info_register(meshgrid_op_info)
def _meshgrid_aicpu():
"""Meshgrid AiCPU register"""
return

@ -24,7 +24,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Diag, DiagPart, DType, ExpandDims, Eye,
Fill, GatherNd, GatherV2, SparseGatherV2, InvertPermutation,
IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Meshgrid,
SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,
ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
Shape, DynamicShape, Size, Slice, Split, TransShape, ParallelConcat, Padding,
@ -110,6 +110,7 @@ __all__ = [
'MatMul',
'BatchMatMul',
'Mul',
'Meshgrid',
'Pow',
'Exp',
'Expm1',

@ -3509,6 +3509,103 @@ class BroadcastTo(PrimitiveWithInfer):
return x_dtype
class Meshgrid(PrimitiveWithInfer):
"""
Generates coordinate matrices from given coordinate tensors.
Given N one-dimensional coordinate tensors, returns a list outputs of N N-D
coordinate tensors for evaluating expressions on an N-D grid.
Args:
indexing (str): Either 'xy' or 'ij'. Default: 'xy'.
When the indexing argument is set to 'xy' (the default),
the broadcasting instructions for the first two dimensions are swapped.
Inputs:
- **input_x** (Union[tuple, list]) - A Tuple or list of N 1-D Tensor objects.
The length of input_x should be greater than 1
Outputs:
Tensors, A Tuple of N N-D Tensor objects.
Examples:
>>> x = np.array([1, 2, 3, 4]).astype(np.int32)
>>> y = np.array([5, 6, 7]).astype(np.int32)
>>> z = np.array([8, 9, 0, 1, 2]).astype(np.int32)
>>> inputs = (x, y, z)
>>> meshgrid = P.Meshgrid(indexing="xy")
>>> meshgrid(inputs)
(Tensor(shape=[3, 4, 6], dtype=UInt32, value=
[[[1, 1, 1, 1, 1],
[2, 2, 2, 2, 2],
[3, 3, 3, 3, 3],
[4, 4, 4, 4, 4]],
[[1, 1, 1, 1, 1],
[2, 2, 2, 2, 2],
[3, 3, 3, 3, 3],
[4, 4, 4, 4, 4]],
[[1, 1, 1, 1, 1],
[2, 2, 2, 2, 2],
[3, 3, 3, 3, 3],
[4, 4, 4, 4, 4]]]),
Tensor(shape=[3, 4, 6], dtype=UInt32, value=
[[[5, 5, 5, 5, 5],
[5, 5, 5, 5, 5],
[5, 5, 5, 5, 5],
[5, 5, 5, 5, 5]],
[[6, 6, 6, 6, 6],
[6, 6, 6, 6, 6],
[6, 6, 6, 6, 6],
[6, 6, 6, 6, 6]],
[[7, 7, 7, 7, 7],
[7, 7, 7, 7, 7],
[7, 7, 7, 7, 7],
[7, 7, 7, 7, 7]]]),
Tensor(shape=[3, 4, 6], dtype=UInt32, value=
[[[8, 9, 0, 1, 2],
[8, 9, 0, 1, 2],
[8, 9, 0, 1, 2],
[8, 9, 0, 1, 2]],
[[8, 9, 0, 1, 2],
[8, 9, 0, 1, 2],
[8, 9, 0, 1, 2],
[8, 9, 0, 1, 2]],
[[8, 9, 0, 1, 2],
[8, 9, 0, 1, 2],
[8, 9, 0, 1, 2],
[8, 9, 0, 1, 2]]]))
"""
@prim_attr_register
def __init__(self, indexing="xy"):
"""Init Meshgrid"""
validator.check_value_type("indexing", indexing, (str), self.name)
if indexing not in ("xy", "ij"):
raise ValueError("indexing parameter must be either 'xy' or 'ij'")
self.indexing = indexing
def infer_shape(self, x_shape):
validator.check_value_type("shape", x_shape, [tuple, list], self.name)
validator.check_integer("len of input_x", len(x_shape), 2, Rel.GE, self.name)
n = len(x_shape)
shape_0 = []
for s in x_shape:
validator.check_integer('each_input_rank', len(s), 1, Rel.EQ, self.name)
shape_0.append(s[0])
if self.indexing == "xy":
shape_0[0], shape_0[1] = shape_0[1], shape_0[0]
out_shape = tuple(tuple(shape_0) for _ in range(n))
return out_shape
def infer_dtype(self, x_type):
validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, self.name)
n = len(x_type)
for i in range(1, n):
validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0], Rel.EQ, self.name, TypeError)
return x_type
class InplaceUpdate(PrimitiveWithInfer):
r"""
Updates specified rows with values in `v`.

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save