commit
24ef3efcb8
@ -0,0 +1,83 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/optimizer/ascend/mindir/space_batch_nd_attr_update.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "runtime/device/kernel_info.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "base/core_ops.h"
|
||||
#include "utils/utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr size_t kBlockShapeDimNum = 2;
|
||||
constexpr auto kAttrBlockShape = "block_shape";
|
||||
constexpr auto kAttrPaddings = "paddings";
|
||||
constexpr auto kAttrCrops = "crops";
|
||||
} // namespace
|
||||
|
||||
const BaseRef SpaceToBatchNDAttrUpdate::DefinePattern() const {
|
||||
VarPtr X = std::make_shared<Var>();
|
||||
VectorRef pattern({prim::kPrimSpaceToBatchND, X});
|
||||
return pattern;
|
||||
}
|
||||
|
||||
const AnfNodePtr SpaceToBatchNDAttrUpdate::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
auto block_shape = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, kAttrBlockShape);
|
||||
if (block_shape.size() == kBlockShapeDimNum) {
|
||||
block_shape.insert(block_shape.begin(), 1);
|
||||
AnfAlgo::SetNodeAttr(kAttrBlockShape, MakeValue(block_shape), node);
|
||||
}
|
||||
auto paddings = AnfAlgo::GetNodeAttr<std::vector<std::vector<int64_t>>>(node, kAttrPaddings);
|
||||
if (paddings.size() == kBlockShapeDimNum) {
|
||||
paddings.emplace(paddings.begin(), std::vector<int64_t>{0, 0});
|
||||
AnfAlgo::SetNodeAttr(kAttrPaddings, MakeValue(paddings), node);
|
||||
}
|
||||
return node;
|
||||
}
|
||||
|
||||
const BaseRef BatchToSpaceNDAttrUpdate::DefinePattern() const {
|
||||
VarPtr X = std::make_shared<Var>();
|
||||
VectorRef pattern({prim::kPrimBatchToSpaceND, X});
|
||||
return pattern;
|
||||
}
|
||||
|
||||
const AnfNodePtr BatchToSpaceNDAttrUpdate::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
auto block_shape = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, kAttrBlockShape);
|
||||
if (block_shape.size() == kBlockShapeDimNum) {
|
||||
block_shape.insert(block_shape.begin(), 1);
|
||||
AnfAlgo::SetNodeAttr(kAttrBlockShape, MakeValue(block_shape), node);
|
||||
}
|
||||
auto crops = AnfAlgo::GetNodeAttr<std::vector<std::vector<int64_t>>>(node, kAttrCrops);
|
||||
if (crops.size() == kBlockShapeDimNum) {
|
||||
crops.emplace(crops.begin(), std::vector<int64_t>{0, 0});
|
||||
AnfAlgo::SetNodeAttr(kAttrCrops, MakeValue(crops), node);
|
||||
}
|
||||
return node;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
@ -0,0 +1,43 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_SPACE_BATCH_ND_ATTR_UPDATE_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_SPACE_BATCH_ND_ATTR_UPDATE_H_
|
||||
|
||||
#include <memory>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class SpaceToBatchNDAttrUpdate : public PatternProcessPass {
|
||||
public:
|
||||
explicit SpaceToBatchNDAttrUpdate(bool multigraph = true)
|
||||
: PatternProcessPass("space_to_batch_nd_attr_update", multigraph) {}
|
||||
~SpaceToBatchNDAttrUpdate() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
};
|
||||
|
||||
class BatchToSpaceNDAttrUpdate : public PatternProcessPass {
|
||||
public:
|
||||
explicit BatchToSpaceNDAttrUpdate(bool multigraph = true)
|
||||
: PatternProcessPass("batch_to_space_nd_attr_update", multigraph) {}
|
||||
~BatchToSpaceNDAttrUpdate() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_SPACE_BATCH_ND_ATTR_UPDATE_H_
|
Loading…
Reference in new issue