|
|
|
@ -15,6 +15,7 @@
|
|
|
|
|
*/
|
|
|
|
|
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
|
|
|
|
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <map>
|
|
|
|
|
#include <set>
|
|
|
|
|
#include <tuple>
|
|
|
|
@ -121,6 +122,84 @@ bool GenJson(const AnfNodePtrList &op_nodes, const AnfNodePtrList &inputs, const
|
|
|
|
|
MS_LOG(INFO) << "Collect fusion json: " << fused_name;
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Returns true when every element of `tensor` is bitwise identical to the
// first element (compared raw byte-by-byte, so it works for any dtype).
bool TensorElementAllTheSame(const tensor::TensorPtr &tensor) {
  MS_EXCEPTION_IF_NULL(tensor);
  const auto element_count = static_cast<size_t>(tensor->DataSize());
  if (element_count == 1) {
    return true;
  }
  const auto element_size = static_cast<size_t>(tensor->data().itemsize());
  const auto *first = static_cast<char *>(tensor->data_c());
  // Compare each subsequent element against the first one, byte by byte.
  for (size_t idx = 1; idx < element_count; ++idx) {
    const char *current = first + idx * element_size;
    for (size_t byte = 0; byte < element_size; ++byte) {
      if (current[byte] != first[byte]) {
        return false;
      }
    }
  }
  return true;
}
|
|
|
|
|
|
|
|
|
|
// Builds a new ValueNode holding a "scalar" tensor: same dtype and rank as the
// tensor inside `value_node`, but with every dimension set to 1, containing a
// copy of the first element only. Callers appear to pass only tensors whose
// elements are all identical (gated by TensorElementAllTheSame) — TODO confirm.
AnfNodePtr ConvertToScalarTensor(const AnfNodePtr &value_node) {
  auto tensor = GetValueNode<tensor::TensorPtr>(value_node);
  MS_EXCEPTION_IF_NULL(tensor);
  auto type_id = tensor->data_type();
  // Keep the original rank, but collapse every axis to length 1.
  ShapeVector new_shape;
  auto origin_ndim = static_cast<size_t>(tensor->DataDim());
  for (size_t i = 0; i < origin_ndim; ++i) {
    new_shape.push_back(1);
  }
  tensor::TensorPtr scalar_tensor = std::make_shared<tensor::Tensor>(type_id, new_shape);
  scalar_tensor->set_device_info(tensor->device_info());
  auto data_ptr = scalar_tensor->data_c();
  MS_EXCEPTION_IF_NULL(data_ptr);
  // Copy exactly one element (itemsize bytes) from the source tensor.
  auto itemsize = static_cast<size_t>(tensor->data().itemsize());
  if (memcpy_s(data_ptr, static_cast<size_t>(itemsize), tensor->data_c(), itemsize) != 0) {
    MS_LOG(EXCEPTION) << "Failed to copy data from tensor into scalar.";
  }

  // Wrap the scalar tensor in a ValueNode and attach abstract + kernel info so
  // it can be selected/used like any other value node in the graph.
  ValueNodePtr new_value_node = std::make_shared<ValueNode>(scalar_tensor);
  new_value_node->set_abstract(scalar_tensor->ToAbstract());
  new_value_node->set_kernel_info(std::make_shared<device::KernelInfo>());
  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  // Output format is taken from the original value node; device type matches the tensor dtype.
  kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{GetFormat(value_node)});
  kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{type_id});
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());

  return new_value_node;
}
|
|
|
|
|
|
|
|
|
|
void ReplaceTensorWithScalar(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &scalar_tensors) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(fg);
|
|
|
|
|
if (scalar_tensors.empty()) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto sub_mng = fg->manager();
|
|
|
|
|
if (sub_mng == nullptr) {
|
|
|
|
|
sub_mng = Manage(fg, true);
|
|
|
|
|
fg->set_manager(sub_mng);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::map<AnfNodePtr, AnfNodePtr> to_be_replaced;
|
|
|
|
|
for (auto scalar_tensor_node : scalar_tensors) {
|
|
|
|
|
auto scalar = ConvertToScalarTensor(scalar_tensor_node);
|
|
|
|
|
auto format = GetFormat(scalar_tensor_node);
|
|
|
|
|
auto dst_shape_vec = GetShape(scalar_tensor_node);
|
|
|
|
|
AnfNodePtrList new_broadcast_inputs = {NewValueNode(prim::kPrimBroadcastTo), scalar};
|
|
|
|
|
auto broadcast_node = CreateCNode(new_broadcast_inputs, fg,
|
|
|
|
|
{.format = format, .shape = dst_shape_vec, .type = GetType(scalar_tensor_node)});
|
|
|
|
|
auto device_shape = GetDeviceShape(scalar_tensor_node);
|
|
|
|
|
SetNodeAttrSafely("shape", MakeValue(device_shape), broadcast_node);
|
|
|
|
|
to_be_replaced[scalar_tensor_node] = broadcast_node;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto [old_value_node, new_node] : to_be_replaced) {
|
|
|
|
|
sub_mng->Replace(old_value_node, new_node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *inputs_ptr) {
|
|
|
|
@ -128,20 +207,28 @@ bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *i
|
|
|
|
|
auto nodes = TopoSort(fg->get_return());
|
|
|
|
|
|
|
|
|
|
OrderedMap<ValuePtr, AnfNodePtrList> vmap;
|
|
|
|
|
std::vector<AnfNodePtr> scalar_tensors;
|
|
|
|
|
for (const auto &node : nodes) {
|
|
|
|
|
if (!node->isa<CNode>()) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto &inputs = node->cast<CNodePtr>()->inputs();
|
|
|
|
|
for (size_t i = 1; i < inputs.size(); ++i) {
|
|
|
|
|
auto tnode = inputs[i];
|
|
|
|
|
const auto &tnode = inputs[i];
|
|
|
|
|
auto tensor = GetValueNode<tensor::TensorPtr>(tnode);
|
|
|
|
|
if (tensor && (tensor->DataSize() > 1)) {
|
|
|
|
|
if (tensor == nullptr || tensor->DataSize() == 1) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (TensorElementAllTheSame(tensor)) {
|
|
|
|
|
scalar_tensors.emplace_back(tnode);
|
|
|
|
|
} else {
|
|
|
|
|
vmap[GetValueNode(tnode)].push_back(tnode);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ReplaceTensorWithScalar(fg, scalar_tensors);
|
|
|
|
|
|
|
|
|
|
if (vmap.empty()) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
@ -169,6 +256,7 @@ bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *i
|
|
|
|
|
|
|
|
|
|
inputs.push_back(vnode);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -660,7 +748,22 @@ ShapeVector GetShape(const AnfNodePtr &node) {
|
|
|
|
|
if (shape == nullptr || !shape->isa<abstract::Shape>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Cannot get shape from " << node->fullname_with_scope();
|
|
|
|
|
}
|
|
|
|
|
return shape->cast<abstract::ShapePtr>()->shape();
|
|
|
|
|
auto shape_vec = shape->cast<abstract::ShapePtr>()->shape();
|
|
|
|
|
if (shape_vec.empty()) {
|
|
|
|
|
shape_vec.push_back(1);
|
|
|
|
|
}
|
|
|
|
|
return shape_vec;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Returns the device shape of `node`'s output 0 as a ShapeVector (int64),
// mapping an empty device shape (scalar) to {1}.
ShapeVector GetDeviceShape(const AnfNodePtr &node) {
  const auto device_shape = AnfAlgo::GetOutputDeviceShape(node, 0);
  ShapeVector result;
  if (device_shape.empty()) {
    // Scalar output: represent it as a rank-1 shape of length 1.
    result.push_back(1);
    return result;
  }
  result.reserve(device_shape.size());
  for (const auto dim : device_shape) {
    result.push_back(SizeToLong(dim));
  }
  return result;
}
|
|
|
|
|
|
|
|
|
|
std::vector<int64_t> GetReduceAxis(const AnfNodePtr &node) {
|
|
|
|
|