!9547 bug fix for 3d format support

From: @liubuyu
Reviewed-by: @jjfeing,@kisnwang
Signed-off-by: @kisnwang
pull/9547/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit da87017671

@ -32,6 +32,7 @@ namespace mindspore {
namespace opt {
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
namespace {
const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D};
const std::set<std::string> kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NCDHW};
AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
const KernelSelectPtr &kernel_select, const std::vector<size_t> &dst_shape) {
@ -64,20 +65,30 @@ void SetTransNodeAttr(const CNodePtr &trans_node) {
}
}
AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) {
AnfNodePtr trans_node = nullptr;
CNodePtr trans_data = nullptr;
std::string InitDefaultFormat(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
// Init
std::string default_format = kOpFormat_DEFAULT;
if (node->isa<CNode>() && AnfAlgo::HasNodeAttr("io_format", node->cast<CNodePtr>())) {
auto attr = AnfAlgo::GetNodeAttr<std::string>(node, "io_format");
if (attr == kOpFormat_NCDHW) {
default_format = kOpFormat_NCDHW;
}
} else if (node->isa<ValueNode>() || node->isa<Parameter>()) {
auto out_format = AnfAlgo::GetOutputFormat(node, 0);
if (k3DFormatSet.find(out_format) != k3DFormatSet.end()) {
default_format = kOpFormat_NCDHW;
}
}
return default_format;
}
AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) {
AnfNodePtr trans_node = nullptr;
CNodePtr trans_data = nullptr;
MS_EXCEPTION_IF_NULL(node);
// Init
std::string default_format = InitDefaultFormat(node);
AnfNodePtr input_node = is_insert_input ? AnfAlgo::GetInputNode(node->cast<CNodePtr>(), insert_index) : node;
std::string input_format = is_insert_input ? default_format : AnfAlgo::GetOutputFormat(node, insert_index);
std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, insert_index) : default_format;

@ -575,7 +575,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const ShapeVector &sh
host_shape.emplace_back(1);
}
std::vector<size_t> device_shape;
if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NCDHW) {
if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NCDHW || kOpFormat_NDC1HWC0) {
device_shape = trans::TransShapeToDevice(host_shape, format_);
} else {
host_shape = trans::PaddingShapeTo4d(host_shape);

Loading…
Cancel
Save