|
|
|
@ -31,6 +31,7 @@ namespace mindspore {
|
|
|
|
|
namespace opt {
|
|
|
|
|
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
|
|
|
|
|
namespace {
|
|
|
|
|
const std::set<std::string> kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW};
|
|
|
|
|
AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
|
|
|
|
|
const KernelSelectPtr &kernel_select, const std::vector<size_t> &dst_shape) {
|
|
|
|
|
std::vector<AnfNodePtr> trans_inputs;
|
|
|
|
@ -110,13 +111,9 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input_node);
|
|
|
|
|
AnfAlgo::SetNodeInput(node, input_node, index);
|
|
|
|
|
}
|
|
|
|
|
if (AnfAlgo::GetInputFormat(node, index) == kOpFormat_NC1KHKWHWC0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "got the format " << AnfAlgo::GetInputFormat(node, index)
|
|
|
|
|
<< "when inserting the transdata node " << node->DebugString();
|
|
|
|
|
}
|
|
|
|
|
std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index);
|
|
|
|
|
std::string dest_format = AnfAlgo::GetInputFormat(node, index);
|
|
|
|
|
if (kNeedTransFormatSet.find(dest_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
|
|
|
|
|
if (kCommonFormatSet.find(dest_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
|
|
|
|
|
MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index)
|
|
|
|
|
<< " To DefaultFormat , index: " << index;
|
|
|
|
|
return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true);
|
|
|
|
@ -133,7 +130,7 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An
|
|
|
|
|
MS_LOG(EXCEPTION) << "got the hw format " << output_format << "when insert the transdata node "
|
|
|
|
|
<< node->DebugString();
|
|
|
|
|
}
|
|
|
|
|
if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
|
|
|
|
|
if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
|
|
|
|
|
MS_LOG(DEBUG) << "Inserted Transdata " << output_format << " To default , index :0";
|
|
|
|
|
return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, false);
|
|
|
|
|
}
|
|
|
|
@ -154,7 +151,7 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
|
|
|
|
|
}
|
|
|
|
|
auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx);
|
|
|
|
|
std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
|
|
|
|
|
if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
|
|
|
|
|
if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
|
|
|
|
|
make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false));
|
|
|
|
|
} else {
|
|
|
|
|
// No need insert trans op.
|
|
|
|
|