parent
49da4e799c
commit
3618b0843d
@ -0,0 +1,92 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "pre_activate/pass/replace_node_by_proxy.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "device/kernel_info.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
#include "kernel/kernel_build_info.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// Builds the kernel build info that the proxy node will carry.
//
// The returned info mirrors `cnode`: the device format and device data type of
// each of its inputs and outputs, together with its fusion type, processor and
// kernel type, are read through AnfAlgo and fed into a fresh builder.
//
// @param cnode  Node whose selected kernel information is copied; must not be null.
// @return       The KernelBuildInfo produced by the builder.
kernel::KernelBuildInfoPtr ReplaceNodeByProxy::GenerateKernelBuildInfo(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);

  // The tensor counts do not change while we iterate, so query them only once.
  const size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
  const size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);

  std::vector<std::string> inputs_device_format;
  std::vector<TypeId> inputs_device_type;
  inputs_device_format.reserve(input_num);
  inputs_device_type.reserve(input_num);
  for (size_t input_index = 0; input_index < input_num; ++input_index) {
    inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index));
    inputs_device_type.push_back(AnfAlgo::GetInputDeviceDataType(cnode, input_index));
  }

  std::vector<std::string> outputs_device_format;
  std::vector<TypeId> outputs_device_type;
  outputs_device_format.reserve(output_num);
  outputs_device_type.reserve(output_num);
  for (size_t output_index = 0; output_index < output_num; ++output_index) {
    outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index));
    outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index));
  }

  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  builder.SetFusionType(AnfAlgo::GetFusionType(cnode));
  builder.SetProcessor(AnfAlgo::GetProcessor(cnode));
  builder.SetKernelType(AnfAlgo::GetKernelType(cnode));

  builder.SetInputsFormat(inputs_device_format);
  builder.SetOutputsFormat(outputs_device_format);
  builder.SetInputsDeviceType(inputs_device_type);
  builder.SetOutputsDeviceType(outputs_device_type);
  return builder.Build();
}
|
||||
|
||||
bool ReplaceNodeByProxy::Run(const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
|
||||
for (auto node : node_list) {
|
||||
if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kEmbeddingLookupOpName) {
|
||||
CNodePtr cnode = node->cast<CNodePtr>();
|
||||
auto prim = std::make_shared<Primitive>(kEmbeddingLookupProxyOpName);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::vector<AnfNodePtr> proxy_inputs = {NewValueNode(prim)};
|
||||
proxy_inputs.insert(proxy_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
|
||||
AnfNodePtr proxy_node = func_graph->NewCNode(proxy_inputs);
|
||||
MS_EXCEPTION_IF_NULL(proxy_node);
|
||||
|
||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
proxy_node->set_kernel_info(kernel_info);
|
||||
|
||||
AbstractBasePtrList abstract_list;
|
||||
AnfAlgo::CopyNodeAttr(kAttrPsKey, cnode, proxy_node);
|
||||
AnfAlgo::CopyNodeAttr("reduce_scatter_flag", cnode, proxy_node);
|
||||
AnfAlgo::CopyNodeAttr("offset", cnode, proxy_node);
|
||||
abstract_list.push_back(cnode->abstract());
|
||||
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
||||
MS_EXCEPTION_IF_NULL(abstract_tuple);
|
||||
proxy_node->set_abstract(abstract_tuple);
|
||||
|
||||
auto kernel_build_info = GenerateKernelBuildInfo(cnode);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, proxy_node.get());
|
||||
|
||||
if (!manager->Replace(cnode, proxy_node)) {
|
||||
MS_LOG(EXCEPTION) << "Replace node by proxy node failed.";
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
@ -0,0 +1,41 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REPLACE_NODE_BY_PROXY_H_
|
||||
#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REPLACE_NODE_BY_PROXY_H_
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "pre_activate/common/pass.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/anf.h"
|
||||
#include "utils/utils.h"
|
||||
#include "kernel/kernel_build_info.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class ReplaceNodeByProxy : public Pass {
|
||||
public:
|
||||
explicit ReplaceNodeByProxy(const std::string &name) : Pass(name) {}
|
||||
~ReplaceNodeByProxy() override = default;
|
||||
bool Run(const FuncGraphPtr &graph) override;
|
||||
|
||||
private:
|
||||
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CNodePtr &cnode);
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REPLACE_NODE_BY_PROXY_H_
|
Loading…
Reference in new issue