!13818 revert nn.BatchNorm3d.
From: @liu_xiao_93 Reviewed-by: @liangchenghui Signed-off-by: @liangchenghui (pull/13818/MERGE)
commit
fcce705bcb
@ -1,85 +0,0 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/ascend/ir_fusion/batchnorm_grad_to_batchnorm3d_grad.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "utils/utils.h"
|
||||
#include "base/core_ops.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr size_t kBN3DGradInputXIndex = 2;
|
||||
// Builds a BatchNorm3DGrad CNode from an existing BatchNormGrad node.
// All of the original node's inputs except the trailing one are carried over,
// and scope / abstract / attributes are cloned from the original.
CNodePtr CreateBatchNorm3DGrad(const FuncGraphPtr &graph, const CNodePtr &batchnorm_grad) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(batchnorm_grad);
  std::vector<AnfNodePtr> bn3d_grad_inputs;
  bn3d_grad_inputs.emplace_back(NewValueNode(std::make_shared<Primitive>(kBatchNorm3DGradOpName)));
  // Copy every input of the original node except the last one.
  const size_t input_num = batchnorm_grad->size() - 1;
  for (size_t idx = 1; idx < input_num; ++idx) {
    bn3d_grad_inputs.emplace_back(batchnorm_grad->input(idx));
  }
  auto bn3d_grad = graph->NewCNode(bn3d_grad_inputs);
  MS_EXCEPTION_IF_NULL(bn3d_grad);
  bn3d_grad->set_scope(batchnorm_grad->scope());
  bn3d_grad->set_abstract(batchnorm_grad->abstract());
  AnfAlgo::CopyNodeAttrs(batchnorm_grad, bn3d_grad);
  return bn3d_grad;
}
|
||||
|
||||
bool NeedFusion(const FuncGraphPtr &graph, const CNodePtr &batchnorm_grad) {
|
||||
MS_EXCEPTION_IF_NULL(batchnorm_grad);
|
||||
if (AnfAlgo::GetInputTensorNum(batchnorm_grad) < kBNGradInputTensorNum) {
|
||||
MS_LOG(INFO) << "BatchNormGrad's input less than " << kBNGradInputTensorNum;
|
||||
return false;
|
||||
}
|
||||
auto format = AnfAlgo::GetNodeAttr<std::string>(batchnorm_grad, kAttrFormat);
|
||||
const auto &ori_inputs = batchnorm_grad->inputs();
|
||||
auto x_shape = AnfAlgo::GetOutputInferShape(ori_inputs[kBN3DGradInputXIndex], 0);
|
||||
if (format != kOpFormat_NCDHW || x_shape.size() != 5) {
|
||||
MS_LOG(INFO) << "Only format is NCDHW and the input dim of BatchNormGrad is 5, then do fusion. But format is: "
|
||||
<< format << ", size of x_shape is: " << x_shape.size();
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Matches a BatchNormGrad node with any number of inputs.
const BaseRef BatchNormGrad2BatchNorm3DGRAD::DefinePattern() const {
  VarPtr inputs_seq = std::make_shared<SeqVar>();
  MS_EXCEPTION_IF_NULL(inputs_seq);
  return VectorRef({prim::kPrimBatchNormGrad, inputs_seq});
}
|
||||
|
||||
// Pass entry point: rewrites a matched BatchNormGrad into BatchNorm3DGrad when
// NeedFusion approves, transferring any Depend edges to the new node.
// Returns nullptr when the node does not qualify, leaving it unchanged.
const AnfNodePtr BatchNormGrad2BatchNorm3DGRAD::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                                        const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto cnode_bn_grad = node->cast<CNodePtr>();
  // cast<CNodePtr>() yields nullptr for non-CNode nodes; fail fast here so the
  // error points at this pass instead of surfacing inside NeedFusion.
  MS_EXCEPTION_IF_NULL(cnode_bn_grad);
  if (!NeedFusion(graph, cnode_bn_grad)) {
    return nullptr;
  }
  auto bn_3d_grad = CreateBatchNorm3DGrad(graph, cnode_bn_grad);
  TransferDepend(cnode_bn_grad, graph, bn_3d_grad);
  return bn_3d_grad;
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
@ -1,34 +0,0 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORM_GRAD_TO_BATCHNORM_3D_GRAD_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORM_GRAD_TO_BATCHNORM_3D_GRAD_H_
|
||||
|
||||
#include <memory>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// Pattern pass that rewrites a BatchNormGrad node into BatchNorm3DGrad.
// The fusion conditions (NCDHW format, 5-D input) live in the matching .cc.
class BatchNormGrad2BatchNorm3DGRAD : public PatternProcessPass {
 public:
  explicit BatchNormGrad2BatchNorm3DGRAD(bool multigraph = true)
      : PatternProcessPass("batchnorm_grad_to_batchnorm3d_grad", multigraph) {}
  ~BatchNormGrad2BatchNorm3DGRAD() override = default;
  // Matches a BatchNormGrad primitive with any sequence of inputs.
  const BaseRef DefinePattern() const override;
  // Returns the replacement BatchNorm3DGrad node, or nullptr to skip fusion.
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORM_GRAD_TO_BATCHNORM_3D_GRAD_H_
|
@ -1,104 +0,0 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/ascend/ir_fusion/batchnorm_to_batchnorm3d.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "utils/utils.h"
|
||||
#include "base/core_ops.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr size_t kBN3InputXIndex = 1;
|
||||
constexpr size_t kBn3DTrainInputTensorNum = 3;
|
||||
// Builds a BatchNorm3D CNode mirroring `batchnorm`. When is_training is true
// only the first kBn3DTrainInputTensorNum inputs are carried over (trailing
// optional inputs are dropped); otherwise all inputs are copied. Scope,
// abstract and attributes are cloned from the original node.
CNodePtr CreateBatchNorm3D(const FuncGraphPtr &graph, const CNodePtr &batchnorm) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(batchnorm);
  auto prim = std::make_shared<Primitive>(kBatchNorm3DOpName);
  std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
  auto is_training = AnfAlgo::GetNodeAttr<bool>(batchnorm, kAttrIsTraining);
  for (size_t i = 1; i < batchnorm->size(); ++i) {
    // Positive condition replaces the original `continue`/`else` pair:
    // copy the input unless we are training and past the training-input count.
    if (!is_training || i <= kBn3DTrainInputTensorNum) {
      inputs.push_back(batchnorm->input(i));
    }
  }
  auto new_node = graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(new_node);
  new_node->set_scope(batchnorm->scope());
  new_node->set_abstract(batchnorm->abstract());
  AnfAlgo::CopyNodeAttrs(batchnorm, new_node);
  return new_node;
}
|
||||
|
||||
bool NeedFusion(const FuncGraphPtr &graph, const CNodePtr &batchnorm) {
|
||||
MS_EXCEPTION_IF_NULL(batchnorm);
|
||||
if (!AnfAlgo::HasNodeAttr(kAttrIsTraining, batchnorm)) {
|
||||
MS_LOG(INFO) << "BatchNorm has no is_training attr.";
|
||||
return false;
|
||||
}
|
||||
auto is_training = AnfAlgo::GetNodeAttr<bool>(batchnorm, kAttrIsTraining);
|
||||
auto format = AnfAlgo::GetNodeAttr<std::string>(batchnorm, kAttrFormat);
|
||||
if (is_training && format == kOpFormat_NCDHW) {
|
||||
if (AnfAlgo::GetInputTensorNum(batchnorm) < kBn3DTrainInputTensorNum) {
|
||||
MS_LOG(INFO) << "When data format is NCDHW and is_training is true, BatchNorm's input less than "
|
||||
<< kBn3DTrainInputTensorNum;
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (AnfAlgo::GetInputTensorNum(batchnorm) < kBnInputTensorNum) {
|
||||
MS_LOG(INFO) << "BatchNorm's input less than " << kBnInputTensorNum;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
const auto &ori_inputs = batchnorm->inputs();
|
||||
auto x_shape = AnfAlgo::GetOutputInferShape(ori_inputs[kBN3InputXIndex], 0);
|
||||
if (format != kOpFormat_NCDHW || x_shape.size() != 5) {
|
||||
MS_LOG(INFO) << "Only format is NCDHW and the input dim of BatchNorm is 5, then do fusion. But format is: "
|
||||
<< format << ", size of x_shape is: " << x_shape.size();
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Matches a BatchNorm node with any number of inputs.
const BaseRef BatchNorm2BatchNorm3D::DefinePattern() const {
  VarPtr inputs_seq = std::make_shared<SeqVar>();
  MS_EXCEPTION_IF_NULL(inputs_seq);
  return VectorRef({prim::kPrimBatchNorm, inputs_seq});
}
|
||||
|
||||
// Pass entry point: rewrites a matched BatchNorm into BatchNorm3D when
// NeedFusion approves, transferring any Depend edges to the new node.
// Returns nullptr when the node does not qualify, leaving it unchanged.
const AnfNodePtr BatchNorm2BatchNorm3D::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                                const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto cnode_bn = node->cast<CNodePtr>();
  // cast<CNodePtr>() yields nullptr for non-CNode nodes; fail fast here so the
  // error points at this pass instead of surfacing inside NeedFusion.
  MS_EXCEPTION_IF_NULL(cnode_bn);
  if (!NeedFusion(graph, cnode_bn)) {
    return nullptr;
  }
  auto bn_3d = CreateBatchNorm3D(graph, cnode_bn);
  TransferDepend(cnode_bn, graph, bn_3d);
  return bn_3d;
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
@ -1,33 +0,0 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORM_TO_BATCHNORM_3D_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORM_TO_BATCHNORM_3D_H_
|
||||
|
||||
#include <memory>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// Pattern pass that rewrites a BatchNorm node into BatchNorm3D.
// The fusion conditions (NCDHW format, 5-D input, mode-dependent input count)
// live in the matching .cc.
class BatchNorm2BatchNorm3D : public PatternProcessPass {
 public:
  explicit BatchNorm2BatchNorm3D(bool multigraph = true) : PatternProcessPass("batchnorm_to_batchnorm3d", multigraph) {}
  ~BatchNorm2BatchNorm3D() override = default;
  // Matches a BatchNorm primitive with any sequence of inputs.
  const BaseRef DefinePattern() const override;
  // Returns the replacement BatchNorm3D node, or nullptr to skip fusion.
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BATCHNORM_TO_BATCHNORM_3D_H_
|
@ -1,51 +0,0 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""BatchNorm3D op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
# TBE registration record for the BatchNorm3D kernel: declares its attributes
# (epsilon / format / is_training), five inputs (mean and variance optional),
# five outputs, and the supported dtype/format combinations — fp16 or fp32
# data with fp32 statistics, all in NDC1HWC0 layout.
batch_norm3d_op_info = TBERegOp("BatchNorm3D") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("batch_norm3d.so") \
    .compute_cost(10) \
    .kernel_name("batch_norm3d") \
    .partial_flag(True) \
    .attr("epsilon", "optional", "float", "all") \
    .attr("format", "optional", "str", "all") \
    .attr("is_training", "optional", "bool", "all") \
    .input(0, "x", False, "required", "all") \
    .input(1, "scale", False, "required", "all", reshape_type="C") \
    .input(2, "offset", False, "required", "all", reshape_type="C") \
    .input(3, "mean", False, "optional", "all", reshape_type="C") \
    .input(4, "variance", False, "optional", "all", reshape_type="C") \
    .output(0, "y", False, "required", "all") \
    .output(1, "batch_mean", False, "required", "all") \
    .output(2, "batch_variance", False, "required", "all") \
    .output(3, "reserve_space_1", False, "optional", "all") \
    .output(4, "reserve_space_2", False, "optional", "all") \
    .dtype_format(DataType.F16_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0,
                  DataType.F32_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0,
                  DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0) \
    .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0,
                  DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0,
                  DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0) \
    .get_op_info()
|
||||
|
||||
|
||||
@op_info_register(batch_norm3d_op_info)
def _batch_norm3d_tbe():
    """Register the BatchNorm3D op info with the TBE backend (registration is
    performed by the decorator; the function body is intentionally empty)."""
    return
|
@ -1,51 +0,0 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""BatchNorm3DGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
# TBE registration record for the BatchNorm3DGrad kernel: declares its
# attributes (epsilon / format / is_training), five inputs (reserve spaces
# optional), five outputs, and the supported dtype/format combinations —
# fp16 or fp32 gradients with fp32 statistics, all in NDC1HWC0 layout.
batch_norm3d_grad_op_info = TBERegOp("BatchNorm3DGrad") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("batch_norm3d_grad.so") \
    .compute_cost(10) \
    .kernel_name("batch_norm3d_grad") \
    .partial_flag(True) \
    .attr("epsilon", "optional", "float", "all") \
    .attr("format", "optional", "str", "all") \
    .attr("is_training", "optional", "bool", "all") \
    .input(0, "y_backprop", False, "required", "all") \
    .input(1, "x", False, "required", "all") \
    .input(2, "scale", False, "required", "all", reshape_type="C") \
    .input(3, "reserve_space_1", False, "optional", "all") \
    .input(4, "reserve_space_2", False, "optional", "all") \
    .output(0, "x_backprop", False, "required", "all") \
    .output(1, "scale_backprop", False, "required", "all") \
    .output(2, "offset_backprop", False, "required", "all") \
    .output(3, "reserve_space_4", False, "optional", "all") \
    .output(4, "reserve_space_5", False, "optional", "all") \
    .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0,
                  DataType.F32_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0,
                  DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0) \
    .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0,
                  DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0,
                  DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0) \
    .get_op_info()
|
||||
|
||||
|
||||
@op_info_register(batch_norm3d_grad_op_info)
def _batch_norm3d_grad_tbe():
    """Register the BatchNorm3DGrad op info with the TBE backend (registration
    is performed by the decorator; the function body is intentionally empty)."""
    return
|
Loading…
Reference in new issue