update io_format to format

pull/11997/head
liubuyu 4 years ago
parent 9557bef491
commit 91f5b1b68e

@ -65,8 +65,8 @@ void SetTransNodeAttr(const CNodePtr &trans_node) {
std::string InitDefaultFormat(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>() && AnfAlgo::HasNodeAttr("io_format", node->cast<CNodePtr>())) {
auto attr = AnfAlgo::GetNodeAttr<std::string>(node, "io_format");
if (node->isa<CNode>() && AnfAlgo::HasNodeAttr(kAttrFormat, node->cast<CNodePtr>())) {
auto attr = AnfAlgo::GetNodeAttr<std::string>(node, kAttrFormat);
if (attr == kOpFormat_NCDHW) {
return kOpFormat_NCDHW;
}
@ -127,11 +127,11 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An
std::string output_format = AnfAlgo::GetOutputFormat(node, 0);
std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, 0);
if (output_format == kOpFormat_NC1KHKWHWC0) {
MS_LOG(EXCEPTION) << "got the hw format " << output_format << "when insert the transdata node "
MS_LOG(EXCEPTION) << "Got the hw format " << output_format << "when insert the transdata node "
<< node->DebugString() << " trace: " << trace::DumpSourceLines(node);
}
if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
MS_LOG(DEBUG) << "Inserted Transdata " << output_format << " To default , index :0";
MS_LOG(DEBUG) << "Inserted transdata " << output_format << " to default , index :0";
return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, false);
}
return node;
@ -364,7 +364,7 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod
const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index);
const std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(prev_node.first, prev_node.second);
// In graph kernel, we check parameter,
// the eliminate pass will not eliminate this case, so we just do not insert the noused cast.
// the eliminate pass will not eliminate this case, so we just do not insert the no used cast.
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && IsValueNode<tensor::Tensor>(cur_input)) {
new_inputs.push_back(cur_input);
} else if (TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); origin_type != device_type) {

@ -20,6 +20,7 @@
#include <set>
#include "base/core_ops.h"
#include "ir/param_info.h"
#include "utils/utils.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "runtime/device/kernel_info.h"
#include "backend/kernel_compiler/kernel_build_info.h"
@ -400,8 +401,8 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode);
}
SetKernelInfoForNode(cnode);
if (AnfAlgo::HasNodeAttr("io_format", cnode)) {
auto attr = AnfAlgo::GetNodeAttr<std::string>(cnode, "io_format");
if (AnfAlgo::HasNodeAttr(kAttrFormat, cnode)) {
auto attr = AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrFormat);
if (attr == kOpFormat_NCDHW) {
ResetInFormat(cnode, kOpFormat_NCDHW);
}

Loading…
Cancel
Save