add Padding op

pull/4243/head
wuxuejian 5 years ago
parent d6fcf731ec
commit c713382798

@@ -15,6 +15,7 @@
"""aicpu ops"""
from .init_data_set_queue import _init_data_set_queue_aicpu
from .embedding_lookup import _embedding_lookup_aicpu
from .padding import _padding_aicpu
from .dropout_genmask import _dropout_genmask_aicpu
from .get_next import _get_next_aicpu
from .print_tensor import _print_aicpu

@@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Padding op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

padding_op_info = AiCPURegOp("Padding") \
    .fusion_type("OPAQUE") \
    .input(0, "x", "required") \
    .output(0, "y", "required") \
    .attr("pad_dim_size", "int") \
    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.U16_Default, DataType.U16_Default) \
    .dtype_format(DataType.U32_Default, DataType.U32_Default) \
    .dtype_format(DataType.U64_Default, DataType.U64_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.F64_Default) \
    .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
    .get_op_info()


@op_info_register(padding_op_info)
def _padding_aicpu():
    """Padding AiCPU register"""
    return
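As a quick sanity check (an illustrative sketch, not part of this commit; assumes a configured Ascend environment where the primitive can be invoked directly), any of the dtype pairs registered above can be exercised the same way as in the primitive's docstring example:

import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.ops import operations as P

# Exercise the F16_Default pair registered above.
x = Tensor(np.array([[8], [10]]), mindspore.float16)
out = P.Padding(4)(x)  # expected: [[8, 0, 0, 0], [10, 0, 0, 0]]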

@@ -27,7 +27,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                        Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
                        SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,
                        ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
                        Shape, Size, Slice, Split, TransShape, ParallelConcat,
                        Shape, Size, Slice, Split, TransShape, ParallelConcat, Padding,
                        Squeeze, StridedSlice, Tile, TensorScatterUpdate,
                        Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd,
                        UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
@@ -137,6 +137,7 @@ __all__ = [
    'GatherV2',
    'SparseGatherV2',
    'EmbeddingLookup',
    'Padding',
    'Concat',
    'Pack',
    'Unpack',

@@ -602,6 +602,46 @@ class SparseGatherV2(GatherV2):
"""

class Padding(PrimitiveWithInfer):
    """
    Extends the last dimension of the input tensor from 1 to pad_dim_size, filling the extension with 0.

    Args:
        pad_dim_size (int): The size to which the last dimension of x is extended, must be positive.

    Inputs:
        - **x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The rank of x should be at least 2.
          The last dimension of x should be 1.

    Outputs:
        Tensor, the shape of tensor is :math:`(x_1, x_2, ..., x_{R-1}, pad\_dim\_size)`.

    Examples:
        >>> x = Tensor(np.array([[8], [10]]), mindspore.float32)
        >>> pad_dim_size = 4
        >>> out = P.Padding(pad_dim_size)(x)
        [[8, 0, 0, 0], [10, 0, 0, 0]]
    """

    @prim_attr_register
    def __init__(self, pad_dim_size=8):
        """init padding"""
        validator.check_value_type("pad_dim_size", pad_dim_size, [int], self.name)
        validator.check_integer("pad_dim_size", pad_dim_size, 0, Rel.GT, self.name)
        self.pad_dim_size = pad_dim_size

    def __infer__(self, x):
        validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
        x_shape = list(x['shape'])
        validator.check_integer("rank of x", len(x_shape), 1, Rel.GT, self.name)
        validator.check_integer("last dim of x", x_shape[-1], 1, Rel.EQ, self.name)
        out_shape = x_shape
        out_shape[-1] = self.pad_dim_size
        out = {'shape': out_shape,
               'dtype': x['dtype'],
               'value': None}
        return out


class Split(PrimitiveWithInfer):
    """
    Splits the input tensor into output_num tensors along the given axis.

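For reference, the shape inference and fill semantics of the Padding primitive above amount to the following NumPy sketch (padding_reference is a hypothetical helper name, not part of this commit):

import numpy as np

def padding_reference(x, pad_dim_size):
    # Hypothetical NumPy equivalent of Padding: extend the last dimension
    # from 1 to pad_dim_size and fill the new positions with zeros.
    out = np.zeros(x.shape[:-1] + (pad_dim_size,), dtype=x.dtype)
    out[..., :1] = x  # copy the single input column into position 0
    return out

# padding_reference(np.array([[8], [10]]), 4) -> [[8, 0, 0, 0], [10, 0, 0, 0]]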
@@ -43,4 +43,4 @@ def test_net():
    tx, ty = Tensor(x), Tensor(y)
    output = mask(tx, ty)
    print(output.asnumpy())
    assert ([255, 255, 255, 255] == output.asnumpy()).all()
    assert ([255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255] == output.asnumpy()).all()

@@ -0,0 +1,40 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE,
                    device_target="Ascend")


class Net(nn.Cell):
    def __init__(self, pad_dim_size):
        super(Net, self).__init__()
        self.padding = P.Padding(pad_dim_size)

    def construct(self, x):
        return self.padding(x)


def test_padding():
    x = Tensor(np.array([[8], [10]]), mstype.int32)
    padding = Net(4)
    out = padding(x)
    assert (out.asnumpy() == [[8, 0, 0, 0], [10, 0, 0, 0]]).all()
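A further cross-check against plain NumPy could look like this (an illustrative extra test, not part of this commit; it reuses Net and the imports from above):

def test_padding_against_numpy():
    # Hypothetical cross-check: build the expected result directly in NumPy.
    x_np = np.array([[8], [10]], dtype=np.int32)
    expected = np.zeros((2, 4), dtype=np.int32)
    expected[..., :1] = x_np
    out = Net(4)(Tensor(x_np))
    assert (out.asnumpy() == expected).all()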