diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
index 73171e5c9b..d8b105dfca 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
@@ -64,8 +64,6 @@
 #include "backend/optimizer/ascend/ir_fusion/derelu_fusion.h"
 #include "backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.h"
 #include "backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h"
-#include "backend/optimizer/ascend/ir_fusion/batchnorm_to_batchnorm3d.h"
-#include "backend/optimizer/ascend/ir_fusion/batchnorm_grad_to_batchnorm3d_grad.h"
 #include "backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.h"
 #include "backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.h"
 #include "backend/optimizer/ascend/format_type/insert_trans_op.h"
@@ -278,8 +276,6 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   auto optimizer = std::make_shared<GraphOptimizer>();
   auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
-  ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BatchNorm3D>());
-  ir_fusion_pm->AddPass(std::make_shared<BatchNormGrad2BatchNorm3DGRAD>());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
@@ -325,8 +321,6 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
-  ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BatchNorm3D>());
-  ir_fusion_pm->AddPass(std::make_shared<BatchNormGrad2BatchNorm3DGRAD>());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
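Note: with the two IR-fusion passes unregistered above, a rank-5 NCDHW `BatchNorm` is no longer rewritten into a dedicated `BatchNorm3D` node on Ascend. The 3-D case is instead handled at the `nn` layer by folding the depth axis into a plain 4-D batch norm (see the `normalization.py` hunk below). That rewrite is safe because training-mode batch norm computes per-channel statistics over all remaining axes, so folding D into H leaves them unchanged. A minimal NumPy sketch of this invariance (`bn_train` is an illustrative reference implementation, not MindSpore code):

```python
import numpy as np

def bn_train(x, gamma, beta, eps=1e-5):
    """Training-mode batch norm: per-channel stats over every axis except axis 1."""
    axes = (0,) + tuple(range(2, x.ndim))
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    param_shape = (1, -1) + (1,) * (x.ndim - 2)
    return (x - mean) / np.sqrt(var + eps) * gamma.reshape(param_shape) + beta.reshape(param_shape)

n, c, d, h, w = 2, 3, 4, 5, 6
x = np.random.randn(n, c, d, h, w).astype(np.float32)
gamma = np.random.randn(c).astype(np.float32)
beta = np.random.randn(c).astype(np.float32)

out_5d = bn_train(x, gamma, beta)                          # normalize NCDHW directly
out_4d = bn_train(x.reshape(n, c, d * h, w), gamma, beta)  # fold D into H, run as NCHW
np.testing.assert_allclose(out_5d.reshape(n, c, d * h, w), out_4d, rtol=1e-6)
```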
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_grad_to_batchnorm3d_grad.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_grad_to_batchnorm3d_grad.cc
deleted file mode 100644
index 82debab205..0000000000
--- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_grad_to_batchnorm3d_grad.cc
+++ /dev/null
@@ -1,85 +0,0 @@
-/**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "backend/optimizer/ascend/ir_fusion/batchnorm_grad_to_batchnorm3d_grad.h"
-#include <memory>
-#include <vector>
-#include <string>
-#include "backend/session/anf_runtime_algorithm.h"
-#include "ir/primitive.h"
-#include "utils/utils.h"
-#include "base/core_ops.h"
-#include "abstract/abstract_value.h"
-#include "backend/optimizer/common/helper.h"
-
-namespace mindspore {
-namespace opt {
-namespace {
-constexpr size_t kBN3DGradInputXIndex = 2;
-CNodePtr CreateBatchNorm3DGrad(const FuncGraphPtr &graph, const CNodePtr &batchnorm_grad) {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(batchnorm_grad);
-  auto prim = std::make_shared<Primitive>(kBatchNorm3DGradOpName);
-  std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
-  for (size_t i = 1; i < batchnorm_grad->size() - 1; ++i) {
-    inputs.push_back(batchnorm_grad->input(i));
-  }
-  auto new_node = graph->NewCNode(inputs);
-  MS_EXCEPTION_IF_NULL(new_node);
-  new_node->set_scope(batchnorm_grad->scope());
-  new_node->set_abstract(batchnorm_grad->abstract());
-  AnfAlgo::CopyNodeAttrs(batchnorm_grad, new_node);
-  return new_node;
-}
-
-bool NeedFusion(const FuncGraphPtr &graph, const CNodePtr &batchnorm_grad) {
-  MS_EXCEPTION_IF_NULL(batchnorm_grad);
-  if (AnfAlgo::GetInputTensorNum(batchnorm_grad) < kBNGradInputTensorNum) {
-    MS_LOG(INFO) << "BatchNormGrad's input less than " << kBNGradInputTensorNum;
-    return false;
-  }
-  auto format = AnfAlgo::GetNodeAttr<std::string>(batchnorm_grad, kAttrFormat);
-  const auto &ori_inputs = batchnorm_grad->inputs();
-  auto x_shape = AnfAlgo::GetOutputInferShape(ori_inputs[kBN3DGradInputXIndex], 0);
-  if (format != kOpFormat_NCDHW || x_shape.size() != 5) {
-    MS_LOG(INFO) << "Only format is NCDHW and the input dim of BatchNormGrad is 5, then do fusion. But format is: "
-                 << format << ", size of x_shape is: " << x_shape.size();
-    return false;
-  }
-  return true;
-}
-}  // namespace
-
-const BaseRef BatchNormGrad2BatchNorm3DGRAD::DefinePattern() const {
-  VarPtr Xs = std::make_shared<SeqVar>();
-  MS_EXCEPTION_IF_NULL(Xs);
-  VectorRef pattern({prim::kPrimBatchNormGrad, Xs});
-  return pattern;
-}
-
-const AnfNodePtr BatchNormGrad2BatchNorm3DGRAD::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
-                                                        const EquivPtr &) const {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(node);
-  auto cnode_bn_grad = node->cast<CNodePtr>();
-  if (!NeedFusion(graph, cnode_bn_grad)) {
-    return nullptr;
-  }
-  auto bn_3d_grad = CreateBatchNorm3DGrad(graph, cnode_bn_grad);
-  TransferDepend(cnode_bn_grad, graph, bn_3d_grad);
-  return bn_3d_grad;
-}
-}  // namespace opt
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_grad_to_batchnorm3d_grad.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_grad_to_batchnorm3d_grad.h
deleted file mode 100644
index e2f7530fba..0000000000
--- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_grad_to_batchnorm3d_grad.h
+++ /dev/null
@@ -1,34 +0,0 @@
-/**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORM_GRAD_TO_BATCHNORM_3D_GRAD_H_
-#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORM_GRAD_TO_BATCHNORM_3D_GRAD_H_
-
-#include <memory>
-#include "backend/optimizer/common/optimizer.h"
-
-namespace mindspore {
-namespace opt {
-class BatchNormGrad2BatchNorm3DGRAD : public PatternProcessPass {
- public:
-  explicit BatchNormGrad2BatchNorm3DGRAD(bool multigraph = true)
-      : PatternProcessPass("batchnorm_grad_to_batchnorm3d_grad", multigraph) {}
-  ~BatchNormGrad2BatchNorm3DGRAD() override = default;
-  const BaseRef DefinePattern() const override;
-  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
-};
-}  // namespace opt
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORM_GRAD_TO_BATCHNORM_3D_GRAD_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_batchnorm3d.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_batchnorm3d.cc
deleted file mode 100644
index a01f752424..0000000000
--- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_batchnorm3d.cc
+++ /dev/null
@@ -1,104 +0,0 @@
-/**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "backend/optimizer/ascend/ir_fusion/batchnorm_to_batchnorm3d.h"
-#include <memory>
-#include <vector>
-#include <string>
-#include "backend/session/anf_runtime_algorithm.h"
-#include "ir/primitive.h"
-#include "utils/utils.h"
-#include "base/core_ops.h"
-#include "abstract/abstract_value.h"
-#include "backend/optimizer/common/helper.h"
-
-namespace mindspore {
-namespace opt {
-namespace {
-constexpr size_t kBN3InputXIndex = 1;
-constexpr size_t kBn3DTrainInputTensorNum = 3;
-CNodePtr CreateBatchNorm3D(const FuncGraphPtr &graph, const CNodePtr &batchnorm) {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(batchnorm);
-  auto prim = std::make_shared<Primitive>(kBatchNorm3DOpName);
-  std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
-  auto is_training = AnfAlgo::GetNodeAttr<bool>(batchnorm, kAttrIsTraining);
-  for (size_t i = 1; i < batchnorm->size(); ++i) {
-    if (is_training && i > kBn3DTrainInputTensorNum) {
-      continue;
-    } else {
-      inputs.push_back(batchnorm->input(i));
-    }
-  }
-  auto new_node = graph->NewCNode(inputs);
-  MS_EXCEPTION_IF_NULL(new_node);
-  new_node->set_scope(batchnorm->scope());
-  new_node->set_abstract(batchnorm->abstract());
-  AnfAlgo::CopyNodeAttrs(batchnorm, new_node);
-  return new_node;
-}
-
-bool NeedFusion(const FuncGraphPtr &graph, const CNodePtr &batchnorm) {
-  MS_EXCEPTION_IF_NULL(batchnorm);
-  if (!AnfAlgo::HasNodeAttr(kAttrIsTraining, batchnorm)) {
-    MS_LOG(INFO) << "BatchNorm has no is_training attr.";
-    return false;
-  }
-  auto is_training = AnfAlgo::GetNodeAttr<bool>(batchnorm, kAttrIsTraining);
-  auto format = AnfAlgo::GetNodeAttr<std::string>(batchnorm, kAttrFormat);
-  if (is_training && format == kOpFormat_NCDHW) {
-    if (AnfAlgo::GetInputTensorNum(batchnorm) < kBn3DTrainInputTensorNum) {
-      MS_LOG(INFO) << "When data format is NCDHW and is_training is true, BatchNorm's input less than "
-                   << kBn3DTrainInputTensorNum;
-      return false;
-    }
-  } else {
-    if (AnfAlgo::GetInputTensorNum(batchnorm) < kBnInputTensorNum) {
-      MS_LOG(INFO) << "BatchNorm's input less than " << kBnInputTensorNum;
-      return false;
-    }
-  }
-  const auto &ori_inputs = batchnorm->inputs();
-  auto x_shape = AnfAlgo::GetOutputInferShape(ori_inputs[kBN3InputXIndex], 0);
-  if (format != kOpFormat_NCDHW || x_shape.size() != 5) {
-    MS_LOG(INFO) << "Only format is NCDHW and the input dim of BatchNorm is 5, then do fusion. But format is: "
-                 << format << ", size of x_shape is: " << x_shape.size();
-    return false;
-  }
-  return true;
-}
-}  // namespace
-
-const BaseRef BatchNorm2BatchNorm3D::DefinePattern() const {
-  VarPtr Xs = std::make_shared<SeqVar>();
-  MS_EXCEPTION_IF_NULL(Xs);
-  VectorRef pattern({prim::kPrimBatchNorm, Xs});
-  return pattern;
-}
-
-const AnfNodePtr BatchNorm2BatchNorm3D::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
-                                                const EquivPtr &) const {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(node);
-  auto cnode_bn = node->cast<CNodePtr>();
-  if (!NeedFusion(graph, cnode_bn)) {
-    return nullptr;
-  }
-  auto bn_3d = CreateBatchNorm3D(graph, cnode_bn);
-  TransferDepend(cnode_bn, graph, bn_3d);
-  return bn_3d;
-}
-}  // namespace opt
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_batchnorm3d.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_batchnorm3d.h
deleted file mode 100644
index 19203ac484..0000000000
--- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_batchnorm3d.h
+++ /dev/null
@@ -1,33 +0,0 @@
-/**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORM_TO_BATCHNORM_3D_H_
-#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORM_TO_BATCHNORM_3D_H_
-
-#include <memory>
-#include "backend/optimizer/common/optimizer.h"
-
-namespace mindspore {
-namespace opt {
-class BatchNorm2BatchNorm3D : public PatternProcessPass {
- public:
-  explicit BatchNorm2BatchNorm3D(bool multigraph = true) : PatternProcessPass("batchnorm_to_batchnorm3d", multigraph) {}
-  ~BatchNorm2BatchNorm3D() override = default;
-  const BaseRef DefinePattern() const override;
-  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
-};
-}  // namespace opt
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORM_TO_BATCHNORM_3D_H_
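The deleted forward pass above only fired for a training-mode, rank-5, NCDHW `BatchNorm`, and `CreateBatchNorm3D` then dropped the trailing `mean`/`variance` inputs, because a training kernel computes batch statistics itself. A hypothetical Python rendering of that input-selection loop (the helper name is mine, not from the source):

```python
def select_bn3d_inputs(bn_inputs, is_training, train_input_num=3):
    """Mirrors CreateBatchNorm3D: keep only (x, scale, offset) when training,
    pass all of (x, scale, offset, mean, variance) through for inference."""
    return bn_inputs[:train_input_num] if is_training else bn_inputs

assert select_bn3d_inputs(("x", "scale", "offset", "mean", "var"), True) == ("x", "scale", "offset")
assert select_bn3d_inputs(("x", "scale", "offset", "mean", "var"), False) == ("x", "scale", "offset", "mean", "var")
```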
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index 5d418ac108..6077ea46e9 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -141,8 +141,6 @@ constexpr auto kAdamApplyOneWithDecayOpName = "AdamApplyOneWithDecay";
 constexpr auto kAdamApplyOneWithDecayAssignOpName = "AdamApplyOneWithDecayAssign";
 constexpr auto kBatchNormGradOpName = "BatchNormGrad";
 constexpr auto kBNInferOpName = "BNInfer";
-constexpr auto kBatchNorm3DOpName = "BatchNorm3D";
-constexpr auto kBatchNorm3DGradOpName = "BatchNorm3DGrad";
 constexpr auto kAdamApplyOneOpName = "AdamApplyOne";
 constexpr auto kAdamApplyOneAssignOpName = "AdamApplyOneAssign";
 constexpr auto kResizeNearestNeighborGradOpName = "ResizeNearestNeighborGrad";
diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py
index f51b3cd6ed..0d7dc01bfd 100644
--- a/mindspore/nn/layer/normalization.py
+++ b/mindspore/nn/layer/normalization.py
@@ -65,12 +65,7 @@ class _BatchNorm(Cell):
         if momentum < 0 or momentum > 1:
             raise ValueError("momentum should be a number in range [0, 1], but got {}".format(momentum))
         self.input_dims = input_dims
-        if self.input_dims == "3d":
-            self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.cls_name)
-        else:
-            self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
-        if context.get_context("device_target") != "Ascend" and self.format == "NCDHW":
-            raise ValueError("NCDHW format only support in Ascend target.")
+        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
         if context.get_context("device_target") != "GPU" and self.format == "NHWC":
             raise ValueError("NHWC format only support in GPU target.")
         self.use_batch_statistics = use_batch_statistics
@@ -441,7 +436,7 @@ def _check_3d_shape(input_shape):
         raise ValueError("For BatchNorm3d, input data must be 5-dimensional.")
 
 
-class BatchNorm3d(_BatchNorm):
+class BatchNorm3d(Cell):
     r"""
     Batch normalization layer over a 5D input.
 
@@ -493,7 +488,7 @@ class BatchNorm3d(_BatchNorm):
         ValueError: If `data_format` is not 'NCDHW'.
 
     Supported Platforms:
-        ``Ascend``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> net = nn.BatchNorm3d(num_features=3)
@@ -515,21 +510,27 @@
                  moving_var_init='ones',
                  use_batch_statistics=None,
                  data_format='NCDHW'):
-        super(BatchNorm3d, self).__init__(num_features,
-                                          eps,
-                                          momentum,
-                                          affine,
-                                          gamma_init,
-                                          beta_init,
-                                          moving_mean_init,
-                                          moving_var_init,
-                                          use_batch_statistics,
-                                          input_dims='3d',
-                                          data_format=data_format)
+        super(BatchNorm3d, self).__init__()
+        self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.cls_name)
+        self.reshape = P.Reshape()
+        self.bn2d = BatchNorm2d(num_features=num_features,
+                                eps=eps,
+                                momentum=momentum,
+                                affine=affine,
+                                gamma_init=gamma_init,
+                                beta_init=beta_init,
+                                moving_mean_init=moving_mean_init,
+                                moving_var_init=moving_var_init,
+                                use_batch_statistics=use_batch_statistics,
+                                data_format="NCHW")
 
-    def _check_data_dim(self, x):
-        if x.ndim != 5:
-            pass
+    def construct(self, input_x):
+        x_shape = F.shape(input_x)
+        _check_3d_shape(x_shape)
+        input_x = self.reshape(input_x, (x_shape[0], x_shape[1], x_shape[2]*x_shape[3], x_shape[4]))
+        bn2d_out = self.bn2d(input_x)
+        bn3d_out = self.reshape(bn2d_out, x_shape)
+        return bn3d_out
 
 
 class GlobalBatchNorm(_BatchNorm):
diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py
index 700903ce0a..84458a4903 100644
--- a/mindspore/ops/_op_impl/tbe/__init__.py
+++ b/mindspore/ops/_op_impl/tbe/__init__.py
@@ -48,8 +48,6 @@ from .assign_sub import _assign_sub_tbe
 from .batch_matmul import _batch_matmul_tbe
 from .batchnorm import _batch_norm_tbe
 from .batchnorm_grad import _batch_norm_grad_tbe
-from .batchnorm3d import _batch_norm3d_tbe
-from .batchnorm3d_grad import _batch_norm3d_grad_tbe
 from .bias_add import _bias_add_tbe
 from .bias_add_grad import _bias_add_grad_tbe
 from .cast import _cast_tbe
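From the caller's side the reworked `BatchNorm3d` should be a drop-in replacement: `construct` reshapes (N, C, D, H, W) to (N, C, D*H, W), runs `BatchNorm2d`, and restores the original shape. An illustrative check along the lines of the class docstring example:

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

net = nn.BatchNorm3d(num_features=3)
x = Tensor(np.ones([16, 3, 10, 32, 32]).astype(np.float32))
output = net(x)
print(output.shape)
# Expected: (16, 3, 10, 32, 32); internally the input travels through
# BatchNorm2d as (16, 3, 320, 32) and is reshaped back afterwards.
```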
diff --git a/mindspore/ops/_op_impl/tbe/batchnorm3d.py b/mindspore/ops/_op_impl/tbe/batchnorm3d.py
deleted file mode 100644
index bd66a35352..0000000000
--- a/mindspore/ops/_op_impl/tbe/batchnorm3d.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# Copyright 2021 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
-"""BatchNorm3D op"""
-from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
-
-batch_norm3d_op_info = TBERegOp("BatchNorm3D") \
-    .fusion_type("OPAQUE") \
-    .async_flag(False) \
-    .binfile_name("batch_norm3d.so") \
-    .compute_cost(10) \
-    .kernel_name("batch_norm3d") \
-    .partial_flag(True) \
-    .attr("epsilon", "optional", "float", "all") \
-    .attr("format", "optional", "str", "all") \
-    .attr("is_training", "optional", "bool", "all") \
-    .input(0, "x", False, "required", "all") \
-    .input(1, "scale", False, "required", "all", reshape_type="C") \
-    .input(2, "offset", False, "required", "all", reshape_type="C") \
-    .input(3, "mean", False, "optional", "all", reshape_type="C") \
-    .input(4, "variance", False, "optional", "all", reshape_type="C") \
-    .output(0, "y", False, "required", "all") \
-    .output(1, "batch_mean", False, "required", "all") \
-    .output(2, "batch_variance", False, "required", "all") \
-    .output(3, "reserve_space_1", False, "optional", "all") \
-    .output(4, "reserve_space_2", False, "optional", "all") \
-    .dtype_format(DataType.F16_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0,
-                  DataType.F32_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0,
-                  DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0) \
-    .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0,
-                  DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0,
-                  DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0) \
-    .get_op_info()
-
-
-@op_info_register(batch_norm3d_op_info)
-def _batch_norm3d_tbe():
-    """BatchNorm3D TBE register"""
-    return
-# ============================================================================ - -"""BatchNorm3DGrad op""" -from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType - -batch_norm3d_grad_op_info = TBERegOp("BatchNorm3DGrad") \ - .fusion_type("OPAQUE") \ - .async_flag(False) \ - .binfile_name("batch_norm3d_grad.so") \ - .compute_cost(10) \ - .kernel_name("batch_norm3d_grad") \ - .partial_flag(True) \ - .attr("epsilon", "optional", "float", "all") \ - .attr("format", "optional", "str", "all") \ - .attr("is_training", "optional", "bool", "all") \ - .input(0, "y_backprop", False, "required", "all") \ - .input(1, "x", False, "required", "all") \ - .input(2, "scale", False, "required", "all", reshape_type="C") \ - .input(3, "reserve_space_1", False, "optional", "all") \ - .input(4, "reserve_space_2", False, "optional", "all") \ - .output(0, "x_backprop", False, "required", "all") \ - .output(1, "scale_backprop", False, "required", "all") \ - .output(2, "offset_backprop", False, "required", "all") \ - .output(3, "reserve_space_4", False, "optional", "all") \ - .output(4, "reserve_space_5", False, "optional", "all") \ - .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, - DataType.F32_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, - DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0) \ - .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, - DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, - DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0) \ - .get_op_info() - - -@op_info_register(batch_norm3d_grad_op_info) -def _batch_norm3d_grad_tbe(): - """BatchNorm3DGrad TBE register""" - return diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index 908e952fdd..0bc44bcce0 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -195,7 +195,7 @@ class BatchNormGrad(PrimitiveWithInfer): def __init__(self, is_training=False, epsilon=1e-5, data_format='NCHW'): self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) - self.data_format = validator.check_string(data_format, ['NCHW', 'NHWC', "NCDHW"], 'format', self.name) + self.data_format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name) def infer_shape(self, y_backprop_shape, x_shape, scale_shape, save_mean_shape, save_variance_shape, reserve): validator.check("BatchNorm y_backprop_shape", y_backprop_shape, "BatchNorm x_shape", x_shape) diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 8a815f3ebc..49c284162f 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -1201,11 +1201,9 @@ class BatchNorm(PrimitiveWithInfer): validator.check_value_type('is_training', is_training, (bool,), self.name) validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name) - self.format = validator.check_string(data_format, ['NCHW', 'NHWC', "NCDHW"], 'format', self.name) + self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name) if context.get_context("device_target") != "GPU" and self.format == "NHWC": raise ValueError("NHWC format only support in 
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 8a815f3ebc..49c284162f 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -1201,11 +1201,9 @@ class BatchNorm(PrimitiveWithInfer):
         validator.check_value_type('is_training', is_training, (bool,), self.name)
         validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
         validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
-        self.format = validator.check_string(data_format, ['NCHW', 'NHWC', "NCDHW"], 'format', self.name)
+        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
         if context.get_context("device_target") != "GPU" and self.format == "NHWC":
             raise ValueError("NHWC format only support in GPU target.")
-        if context.get_context("device_target") != "Ascend" and self.format == "NCDHW":
-            raise ValueError("NCDHW format only support in Ascend target.")
         self.add_prim_attr('data_format', self.format)
         self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'],
                                 outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2'])
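With 'NCDHW' removed from the accepted formats of the raw primitives, a 5-D format now fails at validation time instead of being routed to an Ascend-only kernel, and `nn.BatchNorm3d` remains the entry point for 5-D inputs. The expected behavior after this change (illustrative, not taken from a test in this diff):

```python
from mindspore.ops import operations as P

try:
    P.BatchNorm(is_training=True, epsilon=1e-5, data_format="NCDHW")
except ValueError as err:
    print("rejected:", err)  # data_format must now be 'NCHW' or 'NHWC'
```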