!12947 Add MaxPool3D, MaxPool3DGrad and MaxPool3DGradGrad ops for Ascend.

From: @liu_xiao_93
Reviewed-by: @liangchenghui
Signed-off-by: @liangchenghui
pull/12947/MERGE
Committed by mindspore-ci-bot (via Gitee), 4 years ago
commit 54c37bcd61

@ -622,6 +622,14 @@ void TbeKernelJsonCreator::ParseAttrDefaultValue(const std::string &type, const
    (*attr_obj)[kJValue] = attr_value;
  } else if (type == kVTypeFloat) {
    (*attr_obj)[kJValue] = std::stof(value);
  } else if (type == kVTypeListInt) {
    std::stringstream string_value(value);
    std::string list_elem;
    std::vector<int64_t> attr_value;
    while (std::getline(string_value, list_elem, ',')) {
      attr_value.push_back(std::stoll(list_elem));
    }
    (*attr_obj)[kJValue] = attr_value;
  } else {
    MS_LOG(EXCEPTION) << "Type: " << type << " is not supported";
  }
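
For reference, the new kVTypeListInt branch above turns a comma-separated default-value string (such as the "0,0,0" and "1,1,1" defaults declared in the TBE registrations further down) into a list of integers. A minimal Python sketch of the same parsing, for illustration only:

def parse_list_int(value):
    """Mirror of the std::getline(..., ',') loop: "1,1,1" -> [1, 1, 1]."""
    return [int(elem) for elem in value.split(',') if elem]

assert parse_list_int("0,0,0") == [0, 0, 0]
assert parse_list_int("1,1,1") == [1, 1, 1]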

@ -55,6 +55,7 @@
#include "backend/optimizer/ascend/ir_fission/topk_split.h"
#include "backend/optimizer/ascend/ir_fission/lin_space_fission.h"
#include "backend/optimizer/ascend/ir_fission/space_to_depth_split.h"
#include "backend/optimizer/ascend/ir_fission/max_pool3d_grad_grad_fission.h"
#include "backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.h"
#include "backend/optimizer/ascend/ir_fusion/mul_add_fusion.h"
#include "backend/optimizer/ascend/ir_fusion/mul_addn_fusion.h"
@ -173,6 +174,7 @@ void AddAscendIRFusionPass(PassManager *ir_fusion_pm) {
  ir_fusion_pm->AddPass(std::make_shared<TransposeReshapeFusion>());
  ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
  ir_fusion_pm->AddPass(std::make_shared<LinSpaceFission>());
  ir_fusion_pm->AddPass(std::make_shared<MaxPool3DGradGradFission>());
  ir_fusion_pm->AddPass(std::make_shared<MomentumLossscaleFusion>());
  ir_fusion_pm->AddPass(std::make_shared<MulAddFusion>());
  ir_fusion_pm->AddPass(std::make_shared<MulAddNFusion>());
@ -325,6 +327,7 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
  ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
  ir_fusion_pm->AddPass(std::make_shared<LinSpaceFission>());
  ir_fusion_pm->AddPass(std::make_shared<SpaceToDepthSplit>());
  ir_fusion_pm->AddPass(std::make_shared<MaxPool3DGradGradFission>());
  ir_fusion_pm->AddPass(std::make_shared<AddnFission>());
  ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
  ir_fusion_pm->AddPass(std::make_shared<TensorScatterUpdateFission>());

@ -0,0 +1,120 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/ir_fission/max_pool3d_grad_grad_fission.h"
#include <vector>
#include <memory>
#include <string>
#include "backend/session/anf_runtime_algorithm.h"
#include "frontend/optimizer/opt.h"
#include "backend/optimizer/common/helper.h"
namespace mindspore::opt {
constexpr size_t kInputNum = 3;
constexpr size_t kFloat16Len = 2;  // size of float16 in bytes
namespace {
tensor::TensorPtr CreateTensor(const AnfNodePtr &node) {
  // 1 get attr ksize
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto ksize = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode, "kernel_size");
  auto data_format = AnfAlgo::GetNodeAttr<std::string>(cnode, "format");
  if (data_format != kOpFormat_NCDHW) {
    MS_LOG(ERROR) << "MaxPool3DGradGrad only supports the NCDHW format.";
  }
  MS_LOG(DEBUG) << "ksize of MaxPool3DGradGrad:" << ksize;
  int64_t D = ksize[2];
  int64_t H = ksize[3];
  int64_t W = ksize[4];
  // 2 create tensor
  std::vector<int64_t> assist_shape = {1, 1, D, H, W};  // shape:NCDHW
  TensorTypePtr tensor_type = std::make_shared<TensorType>(kFloat16);
  MS_EXCEPTION_IF_NULL(tensor_type);
  tensor::DeviceInfo device_info{kOpFormat_NDC1HWC0, tensor_type};
  tensor::TensorPtr assist_tensor = std::make_shared<tensor::Tensor>(kFloat16->type_id(), assist_shape);
  assist_tensor->set_device_info(device_info);
  // 3 set value of tensor
  auto data_ptr = assist_tensor->data_c();
  MS_EXCEPTION_IF_NULL(data_ptr);
  std::vector<float16> half_data;
  int64_t dims = 1 * 1 * D * H * W;
  int64_t counter = dims;
  for (int64_t i = 0; i < dims; i++) {
    half_data.emplace_back(float16(static_cast<float>(counter)));
    counter--;
  }
  auto elem_num = dims * kFloat16Len;
  auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(assist_tensor->data().nbytes()), half_data.data(), elem_num);
  if (ret_code != 0) {
    MS_LOG(ERROR) << "Failed to copy data into Tensor.";
    return nullptr;
  }
  return assist_tensor;
}

ValueNodePtr CreateValueNode(const AnfNodePtr &node) {
  tensor::TensorPtr assist_tensor = CreateTensor(node);
  MS_EXCEPTION_IF_NULL(assist_tensor);
  auto assist_const = std::make_shared<ValueNode>(assist_tensor);
  MS_EXCEPTION_IF_NULL(assist_const);
  auto assist_abstract = assist_tensor->ToAbstract();
  assist_const->set_abstract(assist_abstract);
  auto assist_kernel_info = std::make_shared<device::KernelInfo>();
  MS_EXCEPTION_IF_NULL(assist_kernel_info);
  assist_const->set_kernel_info(assist_kernel_info);
  kernel::KernelBuildInfo::KernelBuildInfoBuilder op_builder;
  op_builder.SetOutputsFormat({kOpFormat_NDC1HWC0});
  op_builder.SetOutputsDeviceType({kNumberTypeFloat16});
  AnfAlgo::SetSelectKernelBuildInfo(op_builder.Build(), assist_const.get());
  return assist_const;
}
}  // namespace

const BaseRef MaxPool3DGradGradFission::DefinePattern() const {
  VarPtr Xs = std::make_shared<SeqVar>();
  auto max_pool3d_grad_grad_prim = std::make_shared<Primitive>(kMaxPool3DGradGradOpName);
  return VectorRef({max_pool3d_grad_grad_prim, Xs});
}

const AnfNodePtr MaxPool3DGradGradFission::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                                   const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_graph = graph->cast<KernelGraphPtr>();
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (cnode->size() != kInputNum + 1) {
    MS_LOG(INFO) << "The node " << cnode->DebugString() << " does not have " << kInputNum << " inputs";
    return nullptr;
  }
  std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(kMaxPool3DGradGradOpName))};
  auto assist_const = CreateValueNode(cnode);
  new_inputs.insert(new_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
  new_inputs.push_back(assist_const);
  CNodePtr new_cnode = graph->NewCNode(new_inputs);
  MS_EXCEPTION_IF_NULL(new_cnode);
  new_cnode->set_abstract(cnode->abstract());
  new_cnode->set_scope(cnode->scope());
  AnfAlgo::CopyNodeAttrs(cnode, new_cnode);
  if (kernel_graph != nullptr) {
    kernel_graph->AddValueNodeToGraph(assist_const);
    MS_LOG(INFO) << "Split MaxPool3DGradGrad op successfully.";
  }
  return new_cnode;
}
}  // namespace mindspore::opt
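
To make CreateTensor above easier to follow: the assist constant is a float16 tensor of shape (1, 1, D, H, W) filled with the values D*H*W down to 1, where D, H and W are taken from the kernel_size attribute. A NumPy sketch of the same values, assuming a kernel_size of (1, 1, 2, 2, 2) purely for illustration:

import numpy as np

D, H, W = 2, 2, 2  # assumed kernel depth, height and width
assist = np.arange(D * H * W, 0, -1, dtype=np.float16).reshape(1, 1, D, H, W)
print(assist.shape)    # (1, 1, 2, 2, 2)
print(assist.ravel())  # [8. 7. 6. 5. 4. 3. 2. 1.]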

@ -0,0 +1,34 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_MAX_POOL3D_GRAD_GRAD_FISSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_MAX_POOL3D_GRAD_GRAD_FISSION_H_
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class MaxPool3DGradGradFission : public PatternProcessPass {
public:
explicit MaxPool3DGradGradFission(bool multigraph = true)
: PatternProcessPass("max_pool3d_grad_grad_fission", multigraph) {}
~MaxPool3DGradGradFission() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_MAX_POOL3D_GRAD_GRAD_FISSION_H_

@ -221,6 +221,7 @@ constexpr auto kmaxPoolGradOpName = "MaxPoolGrad";
constexpr auto kMaxPoolWithArgmaxOpName = "MaxPoolWithArgmax";
constexpr auto kMaxPoolGradWithArgmaxOpName = "MaxPoolGradWithArgmax";
constexpr auto kTensorAddOpName = "Add";
constexpr auto kMaxPool3DGradGradOpName = "MaxPool3DGradGrad";
constexpr auto kCastOpName = "Cast";
constexpr auto kGreaterEqualOpName = "GreaterEqual";
constexpr auto kAbsOpName = "Abs";

@ -249,6 +249,55 @@ def get_bprop_max_pool_grad(self):
    return bprop


@bprop_getters.register(P.MaxPool3D)
def get_bprop_max_pool3d_grad(self):
    """Grad definition for `MaxPool3D` operation."""
    max_pool3d_grad = G.MaxPool3DGrad(
        kernel_size=self.kernel_size,
        strides=self.strides,
        pad_mode=self.pad_mode,
        data_format=self.data_format)

    def bprop(x, out, dout):
        dx = max_pool3d_grad(x, out, dout)
        return (dx,)

    return bprop


@bprop_getters.register(G.MaxPool3DGrad)
def get_bprop_max_pool3d_grad_grad(self):
    """Grad definition for `MaxPool3DGrad` operation."""
    max_pool3d_grad_grad = G.MaxPool3DGradGrad(
        kernel_size=self.kernel_size,
        strides=self.strides,
        pad_mode=self.pad_mode,
        data_format=self.data_format)

    def bprop(x, y, grad, out, dout):
        dgrad = max_pool3d_grad_grad(x, y, dout)
        return zeros_like(x), zeros_like(y), dgrad

    return bprop


@bprop_getters.register(G.MaxPool3DGradGrad)
def get_bprop_max_pool3d_grad_grad_grad(self):
    """Grad definition for `MaxPool3DGradGrad` operation."""
    max_pool3d_grad = G.MaxPool3DGrad(
        kernel_size=self.kernel_size,
        strides=self.strides,
        pad_mode=self.pad_mode,
        data_format=self.data_format)

    def bprop(x, y, grad, out, dout):
        dgrad = max_pool3d_grad(x, y, dout)
        return zeros_like(x), zeros_like(y), dgrad

    return bprop


@bprop_getters.register(P.AvgPool)
def get_bprop_avg_pool_grad(self):
    """Grad definition for `AvgPool` operation."""
@ -65,6 +65,9 @@ from .max_pool import _max_pool_tbe
from .max_pool_grad import _max_pool_grad_tbe
from .max_pool_grad_with_argmax import _max_pool_grad_with_argmax_tbe
from .max_pool_with_argmax import _max_pool_with_argmax_tbe
from .max_pool3d import _max_pool_3d_tbe
from .max_pool3d_grad import _max_pool_3d_grad_tbe
from .max_pool3d_grad_grad import _max_pool_3d_grad_grad_tbe
from .mul import _mul_tbe
from .mul_ds import _mul_ds_tbe
from .real_div import _real_div_tbe

@ -0,0 +1,44 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MaxPool3D op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
max_pool3d_op_info = TBERegOp("MaxPool3D") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("max_pool3d.so") \
    .compute_cost(10) \
    .kernel_name("max_pool3d") \
    .partial_flag(True) \
    .attr("kernel_size", "required", "listInt", "all") \
    .attr("strides", "required", "listInt", "all") \
    .attr("pad_mode", "required", "str", "all") \
    .attr("pad_list", "optional", "listInt", "all", "0,0,0") \
    .attr("dilation", "optional", "listInt", "all", "1,1,1") \
    .attr("ceil_mode", "optional", "int", "all", "0") \
    .attr("format", "optional", "str", "all") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .op_pattern("dynamicFormat") \
    .dtype_format(DataType.None_None, DataType.None_None) \
    .get_op_info()


@op_info_register(max_pool3d_op_info)
def _max_pool_3d_tbe():
    """MaxPool3D TBE register"""
    return

@ -0,0 +1,42 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MaxPool3DGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
max_pool3d_grad_op_info = TBERegOp("MaxPool3DGrad") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("max_pool3d_grad.so") \
    .compute_cost(10) \
    .kernel_name("max_pool3d_grad") \
    .partial_flag(True) \
    .attr("kernel_size", "required", "listInt", "all") \
    .attr("strides", "required", "listInt", "all") \
    .attr("pad_list", "required", "listInt", "all") \
    .attr("format", "optional", "str", "all") \
    .input(0, "orig_x", False, "required", "all") \
    .input(1, "orig_y", False, "required", "all") \
    .input(2, "grads", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F32_NDC1HWC0) \
    .get_op_info()


@op_info_register(max_pool3d_grad_op_info)
def _max_pool_3d_grad_tbe():
    """MaxPool3DGrad TBE register"""
    return

@ -0,0 +1,44 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MaxPool3DGradGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
max_pool3d_grad_grad_op_info = TBERegOp("MaxPool3DGradGrad") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("max_pool3d_grad_grad_d.so") \
    .compute_cost(10) \
    .kernel_name("max_pool3d_grad_grad_d") \
    .partial_flag(True) \
    .attr("kernel_size", "required", "listInt", "all") \
    .attr("strides", "required", "listInt", "all") \
    .attr("pad_list", "required", "listInt", "all") \
    .attr("format", "optional", "str", "all") \
    .input(0, "orig_in", False, "required", "all") \
    .input(1, "orig_out", False, "required", "all") \
    .input(2, "grads", False, "required", "all") \
    .input(3, "assist", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0,
                  DataType.F16_NDC1HWC0) \
    .get_op_info()


@op_info_register(max_pool3d_grad_grad_op_info)
def _max_pool_3d_grad_grad_tbe():
    """MaxPool3DGradGrad TBE register"""
    return

@ -69,7 +69,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam
                     GeLU, Gelu, FastGeLU, FastGelu, Elu,
                     GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCGreedyDecoder,
                     LogSoftmax,
                     LogSoftmax, MaxPool3D,
                     MaxPool, DataFormatDimMap,
                     AvgPool, Conv2DBackpropInput, ComputeAccidentalHits,
                     MaxPoolWithArgmax, OneHot, Pad, MirrorPad, Mish, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
@ -118,6 +118,7 @@ __all__ = [
    'TensorAdd',
    'Argmax',
    'Argmin',
    'MaxPool3D',
    'ArgMaxWithValue',
    'ArgMinWithValue',
    'AddN',

@ -984,6 +984,89 @@ class MaxPoolGradGrad(_PoolGrad):
        return x1_dtype


def _get_max_pool3d_grad_pads_by_pad_mode(input_shape, kernel_size, strides, pad_mode):
    """Helper that computes the MaxPool3DGrad pad_list from pad_mode."""

    def get_pad(origin_shape, ksize, stride):
        tail = origin_shape % stride
        pad = (ksize - tail) if tail > 0 else (ksize - stride)
        pad = max(pad, 0)
        pad1 = int(pad / 2)
        pad2 = int(pad / 2) + pad % 2
        return pad1, pad2

    _, _, d, h, w = input_shape
    _, _, kd, kh, kw = kernel_size
    _, _, strd, strh, strw = strides
    pads = (0, 0, 0, 0, 0, 0)
    if pad_mode == 'SAME':
        pads_d = get_pad(d, kd, strd)
        pads_h = get_pad(h, kh, strh)
        pads_w = get_pad(w, kw, strw)
        pads = pads_d + pads_h + pads_w
    return pads


class MaxPool3DGrad(PrimitiveWithInfer):
    """Gradients of the max pool3d operation."""

    @prim_attr_register
    def __init__(self, kernel_size=(1, 1, 1, 1, 1), strides=(1, 1, 1, 1, 1), pad_mode='VALID', data_format="NCDHW"):
        validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
        validator.check_value_type('strides', strides, [int, tuple], self.name)
        validator.check_value_type('pad_mode', pad_mode, [str], self.name)
        self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
        self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
        self.kernel_size = _check_3d_int_or_tuple("kernel_size", kernel_size, self.name, allow_five=True, ret_five=True)
        self.add_prim_attr("kernel_size", self.kernel_size)
        self.strides = _check_3d_int_or_tuple("strides", strides, self.name, allow_five=True, ret_five=True)
        self.add_prim_attr("strides", self.strides)

    def infer_shape(self, x_shape, y_shape, grad_shape):
        validator.check_equal_int(len(x_shape), 5, "x rank", self.name)
        pad_list = _get_max_pool3d_grad_pads_by_pad_mode(x_shape, self.kernel_size, self.strides, self.pad_mode)
        for pad in pad_list:
            validator.check_non_negative_int(pad, 'element of pad_list', self.name)
        self.add_prim_attr("pad_list", pad_list)
        return x_shape

    def infer_dtype(self, x_dtype, y_dtype, grad_dtype):
        args = {'x_dtype': x_dtype, 'y_dtype': y_dtype, 'grad_dtype': grad_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
        return mstype.tensor_type(mstype.float32)


class MaxPool3DGradGrad(PrimitiveWithInfer):
    """Gradients of the max pool3d grad operation."""

    @prim_attr_register
    def __init__(self, kernel_size=(1, 1, 1, 1, 1), strides=(1, 1, 1, 1, 1), pad_mode='VALID', data_format="NCDHW"):
        validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
        validator.check_value_type('strides', strides, [int, tuple], self.name)
        validator.check_value_type('pad_mode', pad_mode, [str], self.name)
        self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
        self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
        self.kernel_size = _check_3d_int_or_tuple("kernel_size", kernel_size, self.name, allow_five=True, ret_five=True)
        self.add_prim_attr("kernel_size", self.kernel_size)
        self.strides = _check_3d_int_or_tuple("strides", strides, self.name, allow_five=True, ret_five=True)
        self.add_prim_attr("strides", self.strides)

    def infer_shape(self, x_shape, y_shape, grad_shape):
        validator.check_equal_int(len(x_shape), 5, "x rank", self.name)
        pad_list = _get_max_pool3d_grad_pads_by_pad_mode(x_shape, self.kernel_size, self.strides, self.pad_mode)
        for pad in pad_list:
            validator.check_non_negative_int(pad, 'element of pad_list', self.name)
        self.add_prim_attr("pad_list", pad_list)
        return y_shape

    def infer_dtype(self, x_dtype, y_dtype, grad_dtype):
        args = {'x_dtype': x_dtype, 'y_dtype': y_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
        validator.check_tensor_dtype_valid('grad_dtype', grad_dtype, [mstype.float16, mstype.float32], self.name)
        return x_dtype


class MaximumGrad(Primitive):
    """Grad for maximum."""

@ -1833,6 +1833,105 @@ class MaxPoolWithArgmax(_Pool):
        return x_dtype, argmax_dtype


class MaxPool3D(PrimitiveWithInfer):
    r"""
    Max pooling operation.

    Applies a 3D max pooling over an input Tensor which can be regarded as a composition of 3D planes.

    Typically the input is of shape :math:`(N_{in}, C_{in}, D_{in}, H_{in}, W_{in})`, MaxPool3D outputs
    regional maximum in the :math:`(D_{in}, H_{in}, W_{in})`-dimension. Given kernel size
    :math:`ks = (d_{ker}, h_{ker}, w_{ker})` and stride :math:`s = (s_0, s_1, s_2)`, the operation is as follows.

    .. math::
        \text{output}(N_i, C_j, d, h, w) =
        \max_{l=0, \ldots, d_{ker}-1} \max_{m=0, \ldots, h_{ker}-1} \max_{n=0, \ldots, w_{ker}-1}
        \text{input}(N_i, C_j, s_0 \times d + l, s_1 \times h + m, s_2 \times w + n)

    Args:
        kernel_size (Union[int, tuple[int]]): The size of kernel used to take the maximum value,
            is an int number that represents the depth, height and width of the kernel, or a tuple
            of three int numbers that represent depth, height and width respectively. Default: 1.
        strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
            the depth, height and width of movement are all strides, or a tuple of three int numbers that
            represent depth, height and width of movement respectively. Default: 1.
        pad_mode (str): The optional value for pad mode, is "same" or "valid", not case sensitive.
            Default: "valid".

            - same: Adopts the way of completion. The depth, height and width of the output will be the
              same as the input. The total number of padding will be calculated in depth, horizontal and
              vertical directions and evenly distributed to head and tail, top and bottom, left and right
              if possible. Otherwise, the last extra padding will be done from the tail, bottom and the
              right side.
            - valid: Adopts the way of discarding. The possible largest depth, height and width of the
              output will be returned without padding. Extra pixels will be discarded.
        data_format (str): The optional value for data format. Currently only supports 'NCDHW'.
            Default: 'NCDHW'.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C, D_{in}, H_{in}, W_{in})`. Data type must be
          float16 or float32.

    Outputs:
        Tensor, with shape :math:`(N, C, D_{out}, H_{out}, W_{out})`. Has the same data type as `input`.

    Raises:
        TypeError: If `kernel_size` or `strides` is neither an int nor a tuple.
        TypeError: If `pad_mode` or `data_format` is not a string.
        ValueError: If numbers in `kernel_size` or `strides` are not positive.
        ValueError: If `pad_mode` is not one of 'same', 'valid'.
        ValueError: If `kernel_size` or `strides` is a tuple whose length is not equal to 3 or 5.
        ValueError: If `data_format` is not 'NCDHW'.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> input = Tensor(np.arange(1 * 2 * 2 * 2 * 3).reshape((1, 2, 2, 2, 3)), mindspore.float32)
        >>> max_pool3d = ops.MaxPool3D(kernel_size=2, strides=1, pad_mode="valid")
        >>> output = max_pool3d(input)
        >>> print(output)
        [[[[[10. 11.]]]
          [[[22. 23.]]]]]
    """

    @prim_attr_register
    def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", data_format="NCDHW"):
        self.init_prim_io_names(inputs=['x'], outputs=['output'])
        validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
        validator.check_value_type('strides', strides, [int, tuple], self.name)
        validator.check_value_type('pad_mode', pad_mode, [str], self.name)
        self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
        self.add_prim_attr("pad_mode", self.pad_mode)
        self.data_format = validator.check_string(data_format, ['NCDHW'], 'data_format', self.name)
        self.kernel_size = _check_3d_int_or_tuple("kernel_size", kernel_size, self.name, allow_five=True, ret_five=True)
        self.add_prim_attr("kernel_size", self.kernel_size)
        self.strides = _check_3d_int_or_tuple("strides", strides, self.name, allow_five=True, ret_five=True)
        self.add_prim_attr("strides", self.strides)

    def infer_shape(self, x_shape):
        validator.check_equal_int(len(x_shape), 5, "x rank", self.name)
        batch, channel, input_d, input_h, input_w = x_shape
        self.add_prim_attr("x_shape", x_shape)
        _, _, kernel_d, kernel_h, kernel_w = self.kernel_size
        _, _, stride_d, stride_h, stride_w = self.strides
        if self.pad_mode == "VALID":
            out_d = math.ceil((input_d - (kernel_d - 1)) / stride_d)
            out_h = math.ceil((input_h - (kernel_h - 1)) / stride_h)
            out_w = math.ceil((input_w - (kernel_w - 1)) / stride_w)
        elif self.pad_mode == "SAME":
            out_d = math.ceil(input_d / stride_d)
            out_h = math.ceil(input_h / stride_h)
            out_w = math.ceil(input_w / stride_w)
        out_shape = [batch, channel, out_d, out_h, out_w]
        _check_shape('output', out_shape, self.name)
        return out_shape

    def infer_dtype(self, x_dtype):
        validator.check_tensor_dtype_valid("x", x_dtype, [mstype.float16, mstype.float32], self.name)
        return x_dtype


class AvgPool(_Pool):
    r"""
    Average pooling operation.
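
A quick arithmetic check of the VALID-mode output shape used by MaxPool3D.infer_shape above, matching the ('MaxPool3D', ...) test case added below (input (100, 3, 28, 28, 28), kernel_size=2, strides=2). Illustration only:

import math

def valid_out(size, kernel, stride):
    return math.ceil((size - (kernel - 1)) / stride)

print([valid_out(28, 2, 2)] * 3)  # [14, 14, 14] -> output shape (100, 3, 14, 14, 14)
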
@ -2097,8 +2196,8 @@ class BiasAdd(PrimitiveWithCheck):
    def check_shape(self, x_shape, b_shape):
        validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name)
        if self.format == "NCDHW" and len(x_shape) != 5:
            raise ValueError("NCDHW format only support 5-dims input.")
        if self.format == "NCDHW" and (len(x_shape) != 5 or context.get_context("device_target") != "Ascend"):
            raise ValueError("NCDHW format only supports 5-dim inputs on the Ascend target.")
        validator.check_equal_int(len(b_shape), 1, "bias rank", self.name)
        x_channel = x_shape[-1] if self.format == "NHWC" else x_shape[1]
        if np.all(np.array(x_shape) != -1):

@ -1696,6 +1696,14 @@ test_case_nn_ops = [
        'desc_inputs': [[3, 4, 6, 6], [3, 4, 3, 3], [3, 4, 3, 3]],
        'desc_bprop': [[3, 4, 6, 6]],
        'skip': ['backward']}),
    ('MaxPool3D', {
        'block': P.MaxPool3D(kernel_size=2, strides=2, pad_mode="VALID"),
        'desc_inputs': [[100, 3, 28, 28, 28]],
        'desc_bprop': [[100, 3, 14, 14, 14]]}),
    ('MaxPool3DGrad', {
        'block': G.MaxPool3DGrad(kernel_size=2, strides=2, pad_mode="VALID"),
        'desc_inputs': [[3, 4, 6, 6, 6], [3, 4, 3, 3, 3], [3, 4, 3, 3, 3]],
        'desc_bprop': [[3, 4, 6, 6, 6]]}),
    ('AvgPool', {
        'block': P.AvgPool(kernel_size=(2, 2), strides=(2, 2), pad_mode="VALID"),
        'desc_inputs': [[100, 3, 28, 28]],
