|
|
@ -65,8 +65,8 @@ void SetTransNodeAttr(const CNodePtr &trans_node) {
|
|
|
|
|
|
|
|
|
|
|
|
std::string InitDefaultFormat(const AnfNodePtr &node) {
|
|
|
|
std::string InitDefaultFormat(const AnfNodePtr &node) {
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
if (node->isa<CNode>() && AnfAlgo::HasNodeAttr("io_format", node->cast<CNodePtr>())) {
|
|
|
|
if (node->isa<CNode>() && AnfAlgo::HasNodeAttr(kAttrFormat, node->cast<CNodePtr>())) {
|
|
|
|
auto attr = AnfAlgo::GetNodeAttr<std::string>(node, "io_format");
|
|
|
|
auto attr = AnfAlgo::GetNodeAttr<std::string>(node, kAttrFormat);
|
|
|
|
if (attr == kOpFormat_NCDHW) {
|
|
|
|
if (attr == kOpFormat_NCDHW) {
|
|
|
|
return kOpFormat_NCDHW;
|
|
|
|
return kOpFormat_NCDHW;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -127,11 +127,11 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An
|
|
|
|
std::string output_format = AnfAlgo::GetOutputFormat(node, 0);
|
|
|
|
std::string output_format = AnfAlgo::GetOutputFormat(node, 0);
|
|
|
|
std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, 0);
|
|
|
|
std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, 0);
|
|
|
|
if (output_format == kOpFormat_NC1KHKWHWC0) {
|
|
|
|
if (output_format == kOpFormat_NC1KHKWHWC0) {
|
|
|
|
MS_LOG(EXCEPTION) << "got the hw format " << output_format << "when insert the transdata node "
|
|
|
|
MS_LOG(EXCEPTION) << "Got the hw format " << output_format << "when insert the transdata node "
|
|
|
|
<< node->DebugString() << " trace: " << trace::DumpSourceLines(node);
|
|
|
|
<< node->DebugString() << " trace: " << trace::DumpSourceLines(node);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
|
|
|
|
if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
|
|
|
|
MS_LOG(DEBUG) << "Inserted Transdata " << output_format << " To default , index :0";
|
|
|
|
MS_LOG(DEBUG) << "Inserted transdata " << output_format << " to default , index :0";
|
|
|
|
return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, false);
|
|
|
|
return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, false);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return node;
|
|
|
|
return node;
|
|
|
@ -364,7 +364,7 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod
|
|
|
|
const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index);
|
|
|
|
const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index);
|
|
|
|
const std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(prev_node.first, prev_node.second);
|
|
|
|
const std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(prev_node.first, prev_node.second);
|
|
|
|
// In graph kernel, we check parameter,
|
|
|
|
// In graph kernel, we check parameter,
|
|
|
|
// the eliminate pass will not eliminate this case, so we just do not insert the noused cast.
|
|
|
|
// the eliminate pass will not eliminate this case, so we just do not insert the no used cast.
|
|
|
|
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && IsValueNode<tensor::Tensor>(cur_input)) {
|
|
|
|
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && IsValueNode<tensor::Tensor>(cur_input)) {
|
|
|
|
new_inputs.push_back(cur_input);
|
|
|
|
new_inputs.push_back(cur_input);
|
|
|
|
} else if (TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); origin_type != device_type) {
|
|
|
|
} else if (TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); origin_type != device_type) {
|
|
|
|