!732 dock AcoshGrad for GE and AvgPool/AvgPoolGrad for VM

Merge pull request !732 from zhangbuxue/dock_AcoshGrad_AvgPool_AvgPoolGrad_for_vm
Committed by mindspore-ci-bot via Gitee
commit 9ffb3993e8

@@ -38,6 +38,7 @@ static std::map<string, string> tbe_func_adapter_map = {
  {"reduce_mean", "reduce_mean_d"},
  {"reduce_max", "reduce_max_d"},
  {"reduce_min", "reduce_min_d"},
+ {"avg_pool_grad", "avg_pool_grad_d"},
  {"conv2d_backprop_filter", "conv2d_backprop_filter_d"},
  {"conv2d_backprop_input", "conv2d_backprop_input_d"},
  {"depthwise_conv2d_native", "depthwise_conv2d"},

@@ -170,6 +170,7 @@ const PrimitivePtr kPrimPooling = std::make_shared<Primitive>("Pooling");
const PrimitivePtr kPrimPoolingGrad = std::make_shared<Primitive>("PoolingGrad");
const PrimitivePtr kPrimMaxPool = std::make_shared<Primitive>("MaxPool");
const PrimitivePtr kPrimMaxPoolGrad = std::make_shared<Primitive>("MaxPoolGrad");
+const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoolGrad");
const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm");
const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D");
const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared<Primitive>("FusedBatchNormGrad");

@@ -178,6 +178,7 @@ extern const PrimitivePtr kPrimFusedBatchNorm;
extern const PrimitivePtr kPrimConv2D;
extern const PrimitivePtr kPrimMaxPool;
extern const PrimitivePtr kPrimMaxPoolGrad;
+extern const PrimitivePtr kPrimAvgPoolGrad;
extern const PrimitivePtr kPrimFusedBatchNormGrad;
extern const PrimitivePtr kPrimReluGrad;
extern const PrimitivePtr kPrimConv2DBackpropInput;

@@ -25,6 +25,7 @@ namespace mindspore {
namespace opt {
ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
  Register(prim::kPrimCast->name(), {1});
+ Register(prim::kPrimAvgPoolGrad->name(), {0});
  Register(prim::kPrimConv2DBackpropInput->name(), {2});
  Register(prim::kPrimConv2DBackpropFilter->name(), {2});
  Register(prim::kPrimDepthwiseConv2dNativeBackpropFilter->name(), {1});
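Registering kPrimAvgPoolGrad with index {0} tells the const-input-to-attr pass that input 0 of AvgPoolGrad (the origin shape of the forward input) is a compile-time constant to be folded into a kernel attribute, which is how the TBE kernel registered below receives its x_origin attr. A rough sketch of the idea, with hypothetical names; the real pass operates on graph nodes in C++:

# Sketch (illustrative only): fold constant inputs into kernel attributes.
def const_input_to_attr(inputs, const_indices, attr_names):
    attrs = {}
    remaining = []
    for i, value in enumerate(inputs):
        if i in const_indices:
            attrs[attr_names[i]] = value  # becomes a kernel attribute
        else:
            remaining.append(value)
    return remaining, attrs

inputs = [(1, 16, 32, 32), "dout_tensor"]    # index 0 is the const shape
remaining, attrs = const_input_to_attr(inputs, {0}, {0: "x_origin"})
assert attrs == {"x_origin": (1, 16, 32, 32)}
assert remaining == ["dout_tensor"]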

@@ -178,6 +178,7 @@ const char kNameBinaryCrossEntropy[] = "BinaryCrossEntropy";
const char kNameBinaryCrossEntropyGrad[] = "BinaryCrossEntropyGrad";
const char kNameSparseApplyAdagrad[] = "SparseApplyAdagrad";
const char kNameAcosh[] = "Acosh";
+const char kNameAcoshGrad[] = "AcoshGrad";
const char kNameFloorMod[] = "FloorMod";
const char kNameSpaceToDepth[] = "SpaceToDepth";
const char kNameDepthToSpace[] = "DepthToSpace";
@@ -375,6 +376,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameBinaryCrossEntropyGrad), ADPT_DESC(BinaryCrossEntropyGrad)},
{string(kNameSparseApplyAdagrad), ADPT_DESC(SparseApplyAdagradD)},
{string(kNameAcosh), ADPT_DESC(Acosh)},
+{string(kNameAcoshGrad), ADPT_DESC(AcoshGrad)},
{string(kNameFloorMod), ADPT_DESC(FloorMod)},
{string(kNameSpaceToDepth), ADPT_DESC(SpaceToDepth)},
{string(kNameDepthToSpace), ADPT_DESC(DepthToSpace)},

@@ -357,6 +357,11 @@ INPUT_MAP(Acosh) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Acosh) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Acosh) = {{0, OUTPUT_DESC(y)}};
+// AcoshGrad
+INPUT_MAP(AcoshGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}};
+ATTR_MAP(AcoshGrad) = EMPTY_ATTR_MAP;
+OUTPUT_MAP(AcoshGrad) = {{0, OUTPUT_DESC(z)}};
// Floor
INPUT_MAP(Floor) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Floor) = EMPTY_ATTR_MAP;
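The INPUT_MAP wires AcoshGrad to the forward output y and the incoming gradient dy. That is sufficient because for y = acosh(x) the derivative 1/sqrt(x^2 - 1) equals 1/sinh(y), so the backward pass never needs the forward input x. A quick numerical check of that identity (plain Python, illustrative only):

import math

def acosh_grad(y, dy):
    # d/dx acosh(x) = 1/sqrt(x^2 - 1) = 1/sinh(acosh(x)) = 1/sinh(y),
    # so the gradient is computable from the forward output alone.
    return dy / math.sinh(y)

x = 2.5
y = math.acosh(x)
eps = 1e-6
numeric = (math.acosh(x + eps) - math.acosh(x - eps)) / (2 * eps)
assert abs(acosh_grad(y, 1.0) - numeric) < 1e-6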

@@ -327,13 +327,15 @@ DECLARE_OP_ADAPTER(Const)
DECLARE_OP_USE_OUTPUT(Const)
DECLARE_OP_ADAPTER(Cos)
DECLARE_OP_USE_OUTPUT(Cos)
DECLARE_OP_ADAPTER(Acos)
DECLARE_OP_USE_OUTPUT(Acos)
DECLARE_OP_ADAPTER(AcosGrad)
DECLARE_OP_USE_OUTPUT(AcosGrad)
DECLARE_OP_ADAPTER(Acosh)
DECLARE_OP_USE_OUTPUT(Acosh)
+DECLARE_OP_ADAPTER(AcoshGrad)
+DECLARE_OP_USE_OUTPUT(AcoshGrad)
DECLARE_OP_ADAPTER(Floor)
DECLARE_OP_USE_OUTPUT(Floor)

@@ -21,6 +21,7 @@ from mindspore._checkparam import check_int_positive, check_bool
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.functional import identity
+from mindspore.ops.operations import _inner_ops as inner
from mindspore.common.parameter import Parameter
from mindspore._extends import cell_attr_register
from mindspore.common.api import ms_function
@@ -480,7 +481,7 @@ class Unfold(Cell):
    """
    def __init__(self, ksizes, strides, rates, padding="valid"):
        super(Unfold, self).__init__()
-       self.extract_image_patches = P.ExtractImagePatches(ksizes, strides, rates, padding)
+       self.extract_image_patches = inner.ExtractImagePatches(ksizes, strides, rates, padding)
        self.transpose = P.Transpose()
        self.format_NHWC = (0, 2, 3, 1)
        self.format_NCHW = (0, 3, 1, 2)
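Unfold's behaviour is unchanged by this edit; only the import path of ExtractImagePatches moves from the public P.* namespace to _inner_ops. A usage sketch, assuming a configured Ascend/graph context and the nn.Unfold interface shown in the hunk above:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

# ksizes/strides/rates follow the [1, row, col, 1] layout the op checks for.
net = nn.Unfold(ksizes=[1, 2, 2, 1], strides=[1, 1, 1, 1], rates=[1, 1, 1, 1])
image = Tensor(np.ones([2, 3, 6, 6], np.float16))
patches = net(image)  # the Cell transposes to NHWC internally and back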

(File diff suppressed because it is too large.)

@@ -151,3 +151,5 @@ from .greater_equal import _greater_equal_tbe
from .not_equal import _not_equal_tbe
from .floor_mod import _floor_mod_tbe
from .scatter_nd_update import _scatter_nd_update_tbe
+from .avg_pool import _avg_pool_tbe
+from .avg_pool_grad import _avg_pool_grad_tbe

@@ -0,0 +1,39 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""AvgPool op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

avg_pool_op_info = TBERegOp("AvgPool") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("avg_pool.so") \
    .compute_cost(10) \
    .kernel_name("avg_pool") \
    .partial_flag(True) \
    .attr("ksize", "required", "listInt", "all") \
    .attr("strides", "required", "listInt", "all") \
    .attr("padding", "required", "str", "all") \
    .attr("data_format", "optional", "str", "all") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
    .get_op_info()


@op_info_register(avg_pool_op_info)
def _avg_pool_tbe():
    """AvgPool TBE register"""
    return
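The P.AvgPool primitive already existed; this file only adds its Ascend/TBE binding. A usage sketch, assuming the ksize/strides/padding constructor signature of this era and an Ascend backend (float16 matches the F16_5HD dtype_format registered above):

import numpy as np
import mindspore.ops.operations as P
from mindspore import Tensor

avg_pool = P.AvgPool(ksize=3, strides=2, padding="valid")
x = Tensor(np.ones([1, 16, 32, 32], np.float16))
y = avg_pool(x)  # on Ascend, now dispatched to the avg_pool TBE kernel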

@@ -0,0 +1,42 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""AvgPoolGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

avg_pool_grad_op_info = TBERegOp("AvgPoolGrad") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("avg_pool_grad_d.so") \
    .compute_cost(10) \
    .kernel_name("avg_pool_grad_d") \
    .partial_flag(True) \
    .attr("x_origin", "required", "listInt", "all") \
    .attr("ksize", "required", "listInt", "all") \
    .attr("strides", "required", "listInt", "all") \
    .attr("padding", "required", "str", "all") \
    .attr("data_format", "optional", "str", "all") \
    .input(0, "input_grad", False, "required", "all") \
    .input(1, "mean_matrix", False, "optional", "all") \
    .input(2, "kernel_matrix", False, "optional", "all") \
    .output(0, "out_grad", True, "required", "all") \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_C1HWNCoC0, DataType.F16_5HD) \
    .get_op_info()


@op_info_register(avg_pool_grad_op_info)
def _avg_pool_grad_tbe():
    """AvgPoolGrad TBE register"""
    return
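Note that x_origin is an attribute here rather than an input: on the MindSpore side the origin shape arrives as a const input and is folded into this attribute by the ConstInputToAttrInfoRegistry change above. Conceptually, AvgPoolGrad spreads each output-gradient element uniformly over its pooling window, and it needs x_origin to size the result. A pure-NumPy illustration of that idea (not the kernel itself):

import numpy as np

def avg_pool_grad_2d(dout, x_origin, ksize, stride):
    """Distribute each dout element evenly over its (non-padded) window."""
    h, w = x_origin
    dx = np.zeros((h, w), dout.dtype)
    for i in range(dout.shape[0]):
        for j in range(dout.shape[1]):
            dx[i*stride:i*stride+ksize, j*stride:j*stride+ksize] += \
                dout[i, j] / (ksize * ksize)
    return dx

dx = avg_pool_grad_2d(np.ones((2, 2), np.float32), (4, 4), ksize=2, stride=2)
assert dx.sum() == 4.0  # total gradient mass is conserved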

@@ -57,7 +57,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
                     Gelu, Elu,
                     GetNext, L2Normalize, LayerNorm, L2Loss,
                     LogSoftmax,
-                    MaxPool, ExtractImagePatches,
+                    MaxPool,
                     AvgPool, Conv2DBackpropInput, ConfusionMulGrad,
                     MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
                     ResizeBilinear, Sigmoid,
@@ -89,7 +89,6 @@ __all__ = [
    'Sqrt',
    'Square',
    'Conv2D',
-   'ExtractImagePatches',
    'Flatten',
    'MaxPoolWithArgmax',
    'FusedBatchNorm',

@@ -59,6 +59,23 @@ class ACosGrad(PrimitiveWithInfer):
        return x


+class AcoshGrad(PrimitiveWithInfer):
+    """Performs grad of Acosh operation."""
+
+    @prim_attr_register
+    def __init__(self):
+        """init AcoshGrad"""
+
+    def infer_shape(self, x, dout):
+        validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name)
+        return x
+
+    def infer_dtype(self, x, dout):
+        args = {"x": x, "dout": dout}
+        validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        return x
+
+
class BatchNormGrad(PrimitiveWithInfer):
    """Performs grad of BatchNorm operation."""

@@ -0,0 +1,98 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Inner operators."""
from ..._checkparam import Validator as validator
from ...common import dtype as mstype
from ..primitive import PrimitiveWithInfer, prim_attr_register
class ExtractImagePatches(PrimitiveWithInfer):
    """
    Extract patches from images.

    The input tensor must be a 4-D tensor and the data format is NHWC.

    Args:
        ksizes (Union[tuple[int], list[int]]): The size of sliding window, should be a tuple or list of int,
            and the format is [1, ksize_row, ksize_col, 1].
        strides (Union[tuple[int], list[int]]): Distance between the centers of the two consecutive patches,
            should be a tuple or list of int, and the format is [1, stride_row, stride_col, 1].
        rates (Union[tuple[int], list[int]]): In each extracted patch, the gap between the corresponding dim
            pixel positions, should be a tuple or list of int, and the format is [1, rate_row, rate_col, 1].
        padding (str): The type of padding algorithm, either "same" or "valid" (not case sensitive).
            Default: "valid".

            - same: Means that the patch can take the part beyond the original image, and this part is filled with 0.
            - valid: Means that the patch area taken must be completely contained in the original image.

    Inputs:
        - **input_x** (Tensor) - A 4-D tensor whose shape is [in_batch, in_row, in_col, in_depth] and
          data type is int8, float16 or float32.

    Outputs:
        Tensor, a 4-D tensor whose data type is same as 'input_x',
        and the shape is [out_batch, out_row, out_col, out_depth], the out_batch is same as the in_batch.
    """

    @prim_attr_register
    def __init__(self, ksizes, strides, rates, padding="valid"):
        """init"""
        def _check_tuple_or_list(arg_name, arg_val, prim_name):
            validator.check_value_type(f"{arg_name}s", arg_val, [tuple, list], self.name)
            if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1:
                raise ValueError(f"For '{prim_name}' the format of {arg_name}s should be [1, {arg_name}_row, "
                                 f"{arg_name}_col, 1], but got {arg_val}.")
            if not isinstance(arg_val[1], int) or not isinstance(arg_val[2], int) or arg_val[1] < 1 or arg_val[2] < 1:
                raise ValueError(f"For '{prim_name}' the {arg_name}_row and {arg_name}_col in {arg_name}s should be "
                                 f"positive integers, but got {arg_name}_row is {arg_val[1]}, {arg_name}_col "
                                 f"is {arg_val[2]}")

        _check_tuple_or_list("ksize", ksizes, self.name)
        _check_tuple_or_list("stride", strides, self.name)
        _check_tuple_or_list("rate", rates, self.name)
        self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name)
        self.add_prim_attr("padding", self.padding)

    def infer_shape(self, input_x):
        """infer shape"""
        if len(input_x) != 4:
            raise ValueError("The `input_x` should be a 4-D tensor, "
                             f"but got a {len(input_x)}-D tensor whose shape is {input_x}")
        in_batch, in_row, in_col, in_depth = input_x
        _, ksize_row, ksize_col, _ = self.ksizes
        _, stride_row, stride_col, _ = self.strides
        _, rate_row, rate_col, _ = self.rates
        out_batch = in_batch
        out_depth = ksize_row * ksize_col * in_depth

        if self.padding == "VALID":
            out_row = \
                (in_row - (ksize_row + (ksize_row - 1) * (rate_row - 1))) // stride_row + 1
            out_col = \
                (in_col - (ksize_col + (ksize_col - 1) * (rate_col - 1))) // stride_col + 1
        else:
            out_row = (in_row - 1) // stride_row + 1
            out_col = (in_col - 1) // stride_col + 1

        out_shape = [out_batch, out_row, out_col, out_depth]
        return out_shape

    def infer_dtype(self, input_x):
        """infer dtype"""
        validator.check_tensor_type_same({"input_x": input_x}, (mstype.int8, mstype.float16, mstype.float32), self.name)
        return input_x
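A worked example of the infer_shape arithmetic above, for both padding modes (plain Python, same formulas):

def out_dim_valid(in_dim, ksize, stride, rate):
    # Effective kernel extent after dilation, then a no-padding window count.
    effective = ksize + (ksize - 1) * (rate - 1)
    return (in_dim - effective) // stride + 1

# 28 rows, 3x3 window dilated by rate 2 -> effective extent 5, stride 2:
assert out_dim_valid(28, 3, 2, 2) == 12  # (28 - 5) // 2 + 1
assert (28 - 1) // 2 + 1 == 14           # SAME padding: ceil(28 / 2)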

@@ -2705,82 +2705,6 @@ class ApplyFtrl(PrimitiveWithInfer):
        return var_type


-class ExtractImagePatches(PrimitiveWithInfer):
-    """
-    Extract patches from images.
-
-    The input tensor must be a 4-D tensor and the data format is NHWC.
-
-    Args:
-        ksizes (Union[tuple[int], list[int]]): The size of sliding window, should be a tuple or list of int,
-            and the format is [1, ksize_row, ksize_col, 1].
-        strides (Union[tuple[int], list[int]]): Distance between the centers of the two consecutive patches,
-            should be a tuple or list of int, and the format is [1, stride_row, stride_col, 1].
-        rates (Union[tuple[int], list[int]]): In each extracted patch, the gap between the corresponding dim
-            pixel positions, should be a tuple or list of int, and the format is [1, rate_row, rate_col, 1].
-        padding (str): The type of padding algorithm, is a string whose value is "same" or "valid",
-            not case sensitive. Default: "valid".
-
-            - same: Means that the patch can take the part beyond the original image, and this part is filled with 0.
-            - valid: Means that the patch area taken must be completely contained in the original image.
-
-    Inputs:
-        - **input_x** (Tensor) - A 4-D tensor whose shape is [in_batch, in_row, in_col, in_depth] and
-          data type is int8, float16, uint8.
-
-    Outputs:
-        Tensor, a 4-D tensor whose data type is same as 'input_x',
-        and the shape is [out_batch, out_row, out_col, out_depth], the out_batch is same as the in_batch.
-    """
-
-    @prim_attr_register
-    def __init__(self, ksizes, strides, rates, padding="valid"):
-        """init"""
-        def _check_tuple_or_list(arg_name, arg_val, prim_name):
-            validator.check_value_type(f"{arg_name}s", ksizes, [tuple, list], self.name)
-            if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1:
-                raise ValueError(f"For \'{prim_name}\' the format of {arg_name}s should be [1, {arg_name}_row, "
-                                 f"{arg_name}_col, 1], but got {arg_val}.")
-            if not isinstance(arg_val[1], int) or not isinstance(arg_val[2], int) or arg_val[1] < 1 or arg_val[2] < 1:
-                raise ValueError(f"For '{prim_name}' the {arg_name}_row and {arg_name}_col in {arg_name}s should be an "
-                                 f"positive integer number, but got {arg_name}_row is {arg_val[1]}, {arg_name}_col "
-                                 f"is {arg_val[2]}")
-
-        _check_tuple_or_list("ksize", ksizes, self.name)
-        _check_tuple_or_list("stride", strides, self.name)
-        _check_tuple_or_list("rate", rates, self.name)
-        self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name)
-        self.add_prim_attr("padding", self.padding)
-
-    def infer_shape(self, input_x):
-        in_batch, in_row, in_col, in_depth = input_x
-        _, ksize_row, ksize_col, _ = self.ksizes
-        _, stride_row, stride_col, _ = self.strides
-        _, rate_row, rate_col, _ = self.rates
-        if len(input_x) != 4:
-            raise ValueError("The `input_x` should be a 4-D tensor, "
-                             f"but got a {len(input_x)}-D tensor whose shape is {input_x}")
-        out_batch = in_batch
-        out_depth = ksize_row * ksize_col * in_depth
-
-        if self.padding == "VALID":
-            out_row = \
-                (in_row - (ksize_row + (ksize_row - 1) * (rate_row - 1))) // stride_row + 1
-            out_col = \
-                (in_col - (ksize_col + (ksize_col - 1) * (rate_col - 1))) // stride_col + 1
-        else:
-            out_row = (in_row - 1) // stride_row + 1
-            out_col = (in_col - 1) // stride_col + 1
-
-        out_shape = [out_batch, out_row, out_col, out_depth]
-        return out_shape
-
-    def infer_dtype(self, input_x):
-        validator.check_tensor_type_same({"input_x": input_x}, (mstype.int8, mstype.float16, mstype.float32), self.name)
-        return input_x


class ConfusionMulGrad(PrimitiveWithInfer):
    """
    `output0` is the result of which input0 dot multily input1.

@@ -265,8 +265,8 @@ test_case_math_ops = [
        'desc_bprop': [[2, 3]]}),
    ('Acosh', {
        'block': P.Acosh(),
-       'desc_inputs': [Tensor(np.random.rand(4).astype(np.float16))],
-       'skip': ['backward']}),
+       'desc_inputs': [[3, 4, 5]],
+       'desc_bprop': [[3, 4, 5]]}),
    ('Sin', {
        'block': P.Sin(),
        'desc_inputs': [[2, 3]],
