|
|
|
@ -32,6 +32,7 @@ namespace mindspore {
|
|
|
|
|
namespace opt {
|
|
|
|
|
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
|
|
|
|
|
namespace {
|
|
|
|
|
const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D};
|
|
|
|
|
const std::set<std::string> kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NCDHW};
|
|
|
|
|
AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
|
|
|
|
|
const KernelSelectPtr &kernel_select, const std::vector<size_t> &dst_shape) {
|
|
|
|
@ -64,20 +65,30 @@ void SetTransNodeAttr(const CNodePtr &trans_node) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
|
|
|
|
const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) {
|
|
|
|
|
AnfNodePtr trans_node = nullptr;
|
|
|
|
|
CNodePtr trans_data = nullptr;
|
|
|
|
|
std::string InitDefaultFormat(const AnfNodePtr &node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
// Init
|
|
|
|
|
std::string default_format = kOpFormat_DEFAULT;
|
|
|
|
|
|
|
|
|
|
if (node->isa<CNode>() && AnfAlgo::HasNodeAttr("io_format", node->cast<CNodePtr>())) {
|
|
|
|
|
auto attr = AnfAlgo::GetNodeAttr<std::string>(node, "io_format");
|
|
|
|
|
if (attr == kOpFormat_NCDHW) {
|
|
|
|
|
default_format = kOpFormat_NCDHW;
|
|
|
|
|
}
|
|
|
|
|
} else if (node->isa<ValueNode>() || node->isa<Parameter>()) {
|
|
|
|
|
auto out_format = AnfAlgo::GetOutputFormat(node, 0);
|
|
|
|
|
if (k3DFormatSet.find(out_format) != k3DFormatSet.end()) {
|
|
|
|
|
default_format = kOpFormat_NCDHW;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return default_format;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
|
|
|
|
const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) {
|
|
|
|
|
AnfNodePtr trans_node = nullptr;
|
|
|
|
|
CNodePtr trans_data = nullptr;
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
// Init
|
|
|
|
|
std::string default_format = InitDefaultFormat(node);
|
|
|
|
|
AnfNodePtr input_node = is_insert_input ? AnfAlgo::GetInputNode(node->cast<CNodePtr>(), insert_index) : node;
|
|
|
|
|
std::string input_format = is_insert_input ? default_format : AnfAlgo::GetOutputFormat(node, insert_index);
|
|
|
|
|
std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, insert_index) : default_format;
|
|
|
|
|