Add a pass `tensor_promotion`. Fix a bug in CreateKernelInfoFromNewParameter, which reset the KernelInfo by mistake. what's more: Update akg Fixbug in model_builder when reduce axis is an interger.pull/7892/head
parent
b3b553245f
commit
b6c2812a29
@ -1 +1 @@
|
||||
Subproject commit 03ef896b90a34ebdb7eeb3fa77d7d4252d021011
|
||||
Subproject commit f308919c39811c2c3e07fb0dcc8054a533c84cbc
|
@ -0,0 +1,57 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/graph_kernel/tensor_promotion.h"
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "ir/func_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
bool TensorPromotion::Run(const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto mng = func_graph->manager();
|
||||
if (mng == nullptr) {
|
||||
mng = Manage(func_graph, true);
|
||||
func_graph->set_manager(mng);
|
||||
}
|
||||
auto todos = TopoSort(func_graph->get_return());
|
||||
|
||||
bool changed = false;
|
||||
for (auto iter = todos.crbegin(); iter != todos.crend(); ++iter) {
|
||||
auto node = *iter;
|
||||
if (!AnfAlgo::IsGraphKernel(node)) {
|
||||
continue;
|
||||
}
|
||||
auto args = node->cast<CNodePtr>()->inputs();
|
||||
auto fg = GetValueNode<FuncGraphPtr>(args[kAnfPrimitiveIndex]);
|
||||
if (!ConvertNonscalarTensorToParameter(fg, &args)) {
|
||||
continue;
|
||||
}
|
||||
AnfNodePtrList inputs, outputs;
|
||||
inputs.insert(inputs.end(), args.begin() + 1, args.end());
|
||||
kernel::GetFuncGraphOutputNodes(fg, &outputs);
|
||||
auto new_cnode = CreateNewFuseCNode(func_graph, fg, inputs, outputs, false);
|
||||
SetNewKernelInfo(new_cnode, fg, inputs, outputs, AnfAlgo::GetProcessor(node));
|
||||
mng->Replace(node, new_cnode);
|
||||
changed = true;
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
@ -0,0 +1,33 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_TENSOR_PROMOTION_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_TENSOR_PROMOTION_H_
|
||||
#include <memory>
|
||||
#include "ir/func_graph.h"
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class TensorPromotion : public Pass {
|
||||
public:
|
||||
TensorPromotion() : Pass("graph_kernel_tensor_promotion") {}
|
||||
~TensorPromotion() override = default;
|
||||
bool Run(const FuncGraphPtr &func_graph);
|
||||
};
|
||||
using TensorPromotionPtr = std::shared_ptr<TensorPromotion>;
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_TENSOR_PROMOTION_H_
|
Loading…
Reference in new issue