!12427 [GraphKernel] Absorb real scalar tensor into graph kernel cnode.

From: @tronzhang
pull/12427/MERGE
Committed by mindspore-ci-bot via Gitee
commit 8f94d80f1f
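
Summary of the change: a constant input tensor whose elements are all identical is no longer promoted to a parameter of the fused graph kernel. It is instead absorbed into the sub-graph as a tensor of shape [1, ..., 1] holding a single element, and a BroadcastTo cnode restores the original shape inside the kernel. The empty-shape fixups that CreateAtomicCleanCompositeNode previously inlined move into GetShape and a new GetDeviceShape helper.

The detection is a dtype-agnostic, byte-wise comparison against element 0. A minimal self-contained sketch of that check (the free-function interface here is illustrative only, not the MindSpore tensor API):

#include <cstddef>
#include <cstring>
#include <iostream>

// Compare each element's bytes with element 0. Working on raw bytes
// avoids switching on the dtype, which is also what the new
// TensorElementAllTheSame helper below relies on.
bool AllElementsTheSame(const void *data, size_t count, size_t itemsize) {
  const char *bytes = static_cast<const char *>(data);
  for (size_t i = 1; i < count; ++i) {
    if (std::memcmp(bytes, bytes + i * itemsize, itemsize) != 0) {
      return false;
    }
  }
  return true;
}

int main() {
  float uniform[] = {0.5F, 0.5F, 0.5F, 0.5F};
  float mixed[] = {0.5F, 0.5F, 2.0F, 0.5F};
  std::cout << AllElementsTheSame(uniform, 4, sizeof(float)) << "\n";  // prints 1
  std::cout << AllElementsTheSame(mixed, 4, sizeof(float)) << "\n";    // prints 0
  return 0;
}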

@@ -348,22 +348,10 @@ CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const KernelGraphPtr
   // Create broadcast basic op.
   auto dst_shape_vec = GetShape(atomic_add_node_);
-  if (dst_shape_vec.empty()) {
-    dst_shape_vec.push_back(1);
-  }
-  AnfNodePtrList atomic_clean_inputs = {NewValueNode(std::make_shared<Primitive>(kBroadcastToOpName)),
-                                        broadcast_input_node};
+  AnfNodePtrList atomic_clean_inputs = {NewValueNode(prim::kPrimBroadcastTo), broadcast_input_node};
   auto broadcast_to_node_inner = CreateCNode(
     atomic_clean_inputs, new_sub_graph, {.format = format, .shape = dst_shape_vec, .type = GetType(atomic_add_node_)});
-  auto device_shape = AnfAlgo::GetOutputDeviceShape(atomic_add_node_, 0);
-  dst_shape_vec.clear();
-  if (device_shape.empty()) {
-    dst_shape_vec.push_back(1);
-  } else {
-    std::transform(device_shape.begin(), device_shape.end(), std::back_inserter(dst_shape_vec), SizeToLong);
-  }
-  SetNodeAttrSafely("shape", MakeValue(dst_shape_vec), broadcast_to_node_inner);
+  SetNodeAttrSafely("shape", MakeValue(GetDeviceShape(atomic_add_node_)), broadcast_to_node_inner);
   // Makeup sub-graph.
   new_sub_graph->set_output(broadcast_to_node_inner);
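
Note: the deleted empty-shape handling is not lost. GetShape now pads an empty (scalar) shape to {1}, and the new GetDeviceShape helper does the same for device shapes; see the graph_kernel_helper.cc hunks below.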

@@ -15,6 +15,7 @@
  */
 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
 #include <algorithm>
+#include <map>
 #include <set>
 #include <tuple>
@@ -121,6 +122,84 @@ bool GenJson(const AnfNodePtrList &op_nodes, const AnfNodePtrList &inputs, const
   MS_LOG(INFO) << "Collect fusion json: " << fused_name;
   return true;
 }
+
+bool TensorElementAllTheSame(const tensor::TensorPtr &tensor) {
+  MS_EXCEPTION_IF_NULL(tensor);
+  if (tensor->DataSize() == 1) {
+    return true;
+  }
+
+  auto data = static_cast<char *>(tensor->data_c());
+  auto itemsize = static_cast<size_t>(tensor->data().itemsize());
+  auto total_cnt = static_cast<size_t>(tensor->DataSize());
+  for (size_t i = 1; i < total_cnt; ++i) {
+    for (size_t ei = 0; ei < itemsize; ++ei) {
+      if (data[ei] != data[i * itemsize + ei]) {
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
+AnfNodePtr ConvertToScalarTensor(const AnfNodePtr &value_node) {
+  auto tensor = GetValueNode<tensor::TensorPtr>(value_node);
+  MS_EXCEPTION_IF_NULL(tensor);
+  auto type_id = tensor->data_type();
+  ShapeVector new_shape;
+  auto origin_ndim = static_cast<size_t>(tensor->DataDim());
+  for (size_t i = 0; i < origin_ndim; ++i) {
+    new_shape.push_back(1);
+  }
+
+  tensor::TensorPtr scalar_tensor = std::make_shared<tensor::Tensor>(type_id, new_shape);
+  scalar_tensor->set_device_info(tensor->device_info());
+  auto data_ptr = scalar_tensor->data_c();
+  MS_EXCEPTION_IF_NULL(data_ptr);
+  auto itemsize = static_cast<size_t>(tensor->data().itemsize());
+  if (memcpy_s(data_ptr, static_cast<size_t>(itemsize), tensor->data_c(), itemsize) != 0) {
+    MS_LOG(EXCEPTION) << "Failed to copy data from tensor into scalar.";
+  }
+
+  ValueNodePtr new_value_node = std::make_shared<ValueNode>(scalar_tensor);
+  new_value_node->set_abstract(scalar_tensor->ToAbstract());
+  new_value_node->set_kernel_info(std::make_shared<device::KernelInfo>());
+  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
+  kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{GetFormat(value_node)});
+  kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{type_id});
+  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
+  return new_value_node;
+}
+
+void ReplaceTensorWithScalar(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &scalar_tensors) {
+  MS_EXCEPTION_IF_NULL(fg);
+  if (scalar_tensors.empty()) {
+    return;
+  }
+
+  auto sub_mng = fg->manager();
+  if (sub_mng == nullptr) {
+    sub_mng = Manage(fg, true);
+    fg->set_manager(sub_mng);
+  }
+
+  std::map<AnfNodePtr, AnfNodePtr> to_be_replaced;
+  for (auto scalar_tensor_node : scalar_tensors) {
+    auto scalar = ConvertToScalarTensor(scalar_tensor_node);
+    auto format = GetFormat(scalar_tensor_node);
+    auto dst_shape_vec = GetShape(scalar_tensor_node);
+    AnfNodePtrList new_broadcast_inputs = {NewValueNode(prim::kPrimBroadcastTo), scalar};
+    auto broadcast_node = CreateCNode(new_broadcast_inputs, fg,
+                                      {.format = format, .shape = dst_shape_vec, .type = GetType(scalar_tensor_node)});
+    auto device_shape = GetDeviceShape(scalar_tensor_node);
+    SetNodeAttrSafely("shape", MakeValue(device_shape), broadcast_node);
+    to_be_replaced[scalar_tensor_node] = broadcast_node;
+  }
+
+  for (auto [old_value_node, new_node] : to_be_replaced) {
+    sub_mng->Replace(old_value_node, new_node);
+  }
+}
 }  // namespace
 
 bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *inputs_ptr) {
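
Taken together, the three helpers above implement the absorption: TensorElementAllTheSame byte-compares every element against the first; ConvertToScalarTensor builds a [1, ..., 1] tensor with the same dtype, device info, and output format, carrying one copy of that element; ReplaceTensorWithScalar then rewires each occurrence through a BroadcastTo cnode whose "shape" attribute is the original device shape, swapping the value nodes out via the graph manager. ConvertNonscalarTensorToParameter now only promotes the remaining, genuinely non-uniform tensors to parameters: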
@@ -128,20 +207,28 @@ bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *inputs_ptr
   auto nodes = TopoSort(fg->get_return());
   OrderedMap<ValuePtr, AnfNodePtrList> vmap;
+  std::vector<AnfNodePtr> scalar_tensors;
   for (const auto &node : nodes) {
     if (!node->isa<CNode>()) {
       continue;
     }
     auto &inputs = node->cast<CNodePtr>()->inputs();
     for (size_t i = 1; i < inputs.size(); ++i) {
-      auto tnode = inputs[i];
+      const auto &tnode = inputs[i];
       auto tensor = GetValueNode<tensor::TensorPtr>(tnode);
-      if (tensor && (tensor->DataSize() > 1)) {
+      if (tensor == nullptr || tensor->DataSize() == 1) {
+        continue;
+      }
+      if (TensorElementAllTheSame(tensor)) {
+        scalar_tensors.emplace_back(tnode);
+      } else {
         vmap[GetValueNode(tnode)].push_back(tnode);
       }
     }
   }
+
+  ReplaceTensorWithScalar(fg, scalar_tensors);
   if (vmap.empty()) {
     return false;
   }
@@ -169,6 +256,7 @@ bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *inputs_ptr
     inputs.push_back(vnode);
   }
   return true;
 }
@@ -660,7 +748,22 @@ ShapeVector GetShape(const AnfNodePtr &node) {
   if (shape == nullptr || !shape->isa<abstract::Shape>()) {
     MS_LOG(EXCEPTION) << "Cannot get shape from " << node->fullname_with_scope();
   }
-  return shape->cast<abstract::ShapePtr>()->shape();
+  auto shape_vec = shape->cast<abstract::ShapePtr>()->shape();
+  if (shape_vec.empty()) {
+    shape_vec.push_back(1);
+  }
+  return shape_vec;
 }
+
+ShapeVector GetDeviceShape(const AnfNodePtr &node) {
+  ShapeVector res_device_shape;
+  auto device_shape = AnfAlgo::GetOutputDeviceShape(node, 0);
+  if (device_shape.empty()) {
+    res_device_shape.push_back(1);
+  } else {
+    std::transform(device_shape.begin(), device_shape.end(), std::back_inserter(res_device_shape), SizeToLong);
+  }
+  return res_device_shape;
+}
 
 std::vector<int64_t> GetReduceAxis(const AnfNodePtr &node) {

@@ -85,6 +85,7 @@ void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNo
 std::string GetFormat(const AnfNodePtr &node);
 TypePtr GetType(const AnfNodePtr &node);
 ShapeVector GetShape(const AnfNodePtr &node);
+ShapeVector GetDeviceShape(const AnfNodePtr &node);
 std::vector<int64_t> GetReduceAxis(const AnfNodePtr &node);
 CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info);

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -23,7 +23,7 @@ namespace mindspore {
 namespace opt {
 class TensorPromotion : public Pass {
  public:
-  TensorPromotion() : Pass("graph_kernel_tensor_promotion") {}
+  TensorPromotion() : Pass("tensor_promotion") {}
   ~TensorPromotion() override = default;
   bool Run(const FuncGraphPtr &func_graph);
 };
