!9309 support 3d format

From: @liubuyu
Reviewed-by: 
Signed-off-by:
pull/9309/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 4f032cf3f8

@ -95,6 +95,16 @@ constexpr auto kJSocVersion = "socVersion";
constexpr auto kSOC_VERSION = "SOC_VERSION";
constexpr auto kJIsDynamicShape = "is_dynamic_shape";
// Returns true when the node's "io_format" attribute requests the 3-D default
// layout (NCDHW) instead of the usual 4-D default (NCHW).
bool IsNeedChangeDefaultFormat(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_LOG(INFO) << "Check if need change default format";
  // `cnode` is already a CNodePtr; the original re-cast it redundantly.
  if (AnfAlgo::HasNodeAttr("io_format", cnode)) {
    auto attr = AnfAlgo::GetNodeAttr<std::string>(cnode, "io_format");
    return attr == kOpFormat_NCDHW;
  }
  return false;
}
bool TbeKernelJsonCreator::GenTbeSingleKernelJson(const std::shared_ptr<mindspore::AnfNode> &anf_node,
nlohmann::json *kernel_json) {
MS_EXCEPTION_IF_NULL(anf_node);
@ -161,10 +171,14 @@ void TbeKernelJsonCreator::GenValidInputDescJson(const std::shared_ptr<AnfNode>
bool value, const std::shared_ptr<OpIOInfo> &input_ptr,
const string &op_input_name, size_t input_i,
std::vector<nlohmann::json> *input_list) {
auto def_format = kOpFormat_NCHW;
auto dtype = GetDeviceInputType(anf_node, real_input_index);
auto format = GetDeviceInputFormat(anf_node, real_input_index);
auto shape = GetDeviceInputShape(anf_node, real_input_index);
auto ori_shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_input_index);
if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) {
def_format = kOpFormat_NCDHW;
}
if (ori_shape.empty()) {
ori_shape.emplace_back(1);
}
@ -172,7 +186,7 @@ void TbeKernelJsonCreator::GenValidInputDescJson(const std::shared_ptr<AnfNode>
input_desc_json[kJDtype] = dtype;
input_desc_json[kJName] = op_input_name + std::to_string(input_i);
input_desc_json[kJOriShape] = ori_shape;
input_desc_json[kJOriFormat] = kOpFormat_NCHW;
input_desc_json[kJOriFormat] = def_format;
input_desc_json[kJShape] = shape;
input_desc_json[kJFormat] = format;
input_desc_json[kJValid] = value;
@ -379,6 +393,10 @@ void TbeKernelJsonCreator::GenOutputList(const std::shared_ptr<AnfNode> &anf_nod
std::vector<nlohmann::json> *output_list) {
MS_EXCEPTION_IF_NULL(output_idx);
MS_EXCEPTION_IF_NULL(output_list);
auto def_format = kOpFormat_NCHW;
if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) {
def_format = kOpFormat_NCDHW;
}
for (size_t i = 0; i < output_obj_num; i++) {
auto dtype = GetDeviceOutputType(anf_node, *output_idx);
auto format = GetDeviceOutputFormat(anf_node, *output_idx);
@ -397,7 +415,7 @@ void TbeKernelJsonCreator::GenOutputList(const std::shared_ptr<AnfNode> &anf_nod
output_obj[kJShape] = shape;
output_obj[kJFormat] = format;
output_obj[kJOriShape] = ori_shape;
output_obj[kJOriFormat] = kOpFormat_NCHW;
output_obj[kJOriFormat] = def_format;
output_obj[kJName] = output_ptr->name();
output_obj[kJValid] = true;
output_obj[kJParamType] = output_ptr->param_type();
@ -580,6 +598,9 @@ std::string TbeKernelJsonCreator::GetDeviceInputFormat(const AnfNodePtr &anf_nod
format = kOpFormat_NCHW;
}
}
if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) {
format = kOpFormat_NCDHW;
}
return format;
}
@ -619,6 +640,9 @@ std::string TbeKernelJsonCreator::GetDeviceOutputFormat(const AnfNodePtr &anf_no
format = kOpFormat_NCHW;
}
}
if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) {
format = kOpFormat_NCDHW;
}
return format;
}
@ -818,6 +842,10 @@ void TbeKernelBuild::GenSuffixDescJson(nlohmann::json *output_desc) {
void TbeKernelBuild::GenDescJson(const std::shared_ptr<mindspore::AnfNode> &anf_node, size_t node_out_idx,
size_t desc_output_idx, nlohmann::json *output_desc, FusionDataType fusion_data_type) {
GenPreDescJson(output_desc);
auto def_format = kOpFormat_NCHW;
if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) {
def_format = kOpFormat_NCDHW;
}
// data_type
auto type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, node_out_idx);
(*output_desc)[kJDataType] = tbe::TypeIdToString(type_id);
@ -828,7 +856,7 @@ void TbeKernelBuild::GenDescJson(const std::shared_ptr<mindspore::AnfNode> &anf_
}
(*output_desc)[kJName] = output_desc_name;
// ori_format
(*output_desc)[kJOriFormat] = kOpFormat_NCHW;
(*output_desc)[kJOriFormat] = def_format;
// ori_shape
auto ori_shape = AnfAlgo::GetOutputInferShape(anf_node, node_out_idx);
if (ori_shape.empty()) {

@ -248,13 +248,57 @@ bool TbeKernelBroadCastSelecter::IsBroadCastSupportFracNZ(SupportFormat *support
// Checks whether this broadcast node can run in the NDC1HWC0 device format.
// On success, appends the per-input/per-output format choice to *support_format
// and returns true; returns false when any input disqualifies the format.
bool TbeKernelBroadCastSelecter::IsBroadCastSupportNDC1HWC0(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  // NOTE(fix): removed an unconditional `return false;` here (diff residue)
  // that made the whole support check below unreachable dead code.
  if (IsSameShape()) {
    // All inputs share one shape: NDC1HWC0 works as long as none is a scalar.
    if (!HasScalarInput()) {
      AssignSupportFormat(kOpFormat_NDC1HWC0, support_format);
      return true;
    }
    return false;
  }
  SupportFormatItem input_support_format;
  SupportFormatItem output_support_format;
  if (HasScalarInput()) {
    // Mixed scalar / tensor inputs: scalars stay NCDHW, 5-D tensors whose C is
    // 16-aligned may use NDC1HWC0; anything else rejects the format.
    for (const auto &shape : input_shapes_) {
      if (IsScalarShape(shape)) {
        input_support_format.emplace_back(kOpFormat_NCDHW);
      } else if (!Is5DShape(shape)) {
        return false;
      } else if (shape[kChannelC] % kAlignmented16 != 0) {
        return false;
      } else {
        input_support_format.emplace_back(kOpFormat_NDC1HWC0);
      }
    }
  } else {
    // Pure tensor inputs: all must be 5-D and must not broadcast along C,
    // because C is split into C1*C0 blocks on device.
    for (const auto &shape : input_shapes_) {
      if (!Is5DShape(shape)) {
        return false;
      }
    }
    auto shape_tmp = input_shapes_[0];
    auto broadcast_c_axis = std::any_of(
      input_shapes_.begin(), input_shapes_.end(),
      [&shape_tmp](const std::vector<size_t> &elem) { return shape_tmp.at(kChannelC) != elem.at(kChannelC); });
    if (broadcast_c_axis) {
      MS_LOG(INFO) << "This node broadcast c channel.";
      return false;
    }
    input_support_format.assign(input_num_, kOpFormat_NDC1HWC0);
  }
  GenOutputSupportFormat(kOpFormat_NDC1HWC0, &output_support_format);
  support_format->input_format.emplace_back(input_support_format);
  support_format->output_format.emplace_back(output_support_format);
  return true;
}
// True when the shape's rank matches the canonical 4-D (NCHW) layout.
bool TbeKernelBroadCastSelecter::Is4DShape(const std::vector<size_t> &shape) const {
  return kShape4dDims == shape.size();
}
// True when the shape's rank matches the canonical 5-D (NCDHW) layout.
bool TbeKernelBroadCastSelecter::Is5DShape(const std::vector<size_t> &shape) const {
  return kShape5dDims == shape.size();
}
bool TbeKernelBroadCastSelecter::IsSameShape() const {
auto shape = input_shapes_.begin();
for (const auto &item : input_shapes_) {

@ -40,6 +40,7 @@ class TbeKernelBroadCastSelecter {
bool IsSameShape() const;
void PadScalarShape(std::vector<size_t> *shape) const;
bool Is4DShape(const std::vector<size_t> &shape) const;
bool Is5DShape(const std::vector<size_t> &shape) const;
bool IsScalarShape(const std::vector<size_t> &shape) const;
bool HasScalarInput() const;
void GenOutputSupportFormat(const std::string &support_format, SupportFormatItem *output_support_item) const;

@ -72,8 +72,18 @@ bool TbeKernelReduceSelecter::IsReduceSupport5HD(SupportFormat *support_format)
// Checks whether this reduce node can run in the NDC1HWC0 device format
// (analogous to the 5HD check for 4-D inputs). Requires a 5-D input,
// keep_dims with a non-empty axis list, and no reduction along C (since C is
// split into C1*C0 blocks on device). Appends the choice to *support_format.
bool TbeKernelReduceSelecter::IsReduceSupportNDC1HWC0(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  // like to 5HD
  // NOTE(fix): removed an unconditional `return false;` here (diff residue)
  // that made the rest of this function unreachable dead code.
  if (!Is5DShape(input_shape_)) {
    return false;
  }
  if (!keep_dims_ || axis_.empty()) {
    return false;
  }
  auto reduce_c_axis = std::any_of(axis_.begin(), axis_.end(), [](const size_t &elem) { return (elem == kChannelC); });
  if (reduce_c_axis) {
    return false;
  }
  AssignSupportFormat(kOpFormat_NDC1HWC0, support_format);
  return true;
}
bool TbeKernelReduceSelecter::IsReduceSupportFracZ(SupportFormat *support_format) const {
@ -142,6 +152,8 @@ void TbeKernelReduceSelecter::AssignSupportFormat(const std::string &support_for
// True when the input rank matches the 4-D (NCHW) layout.
bool TbeKernelReduceSelecter::Is4DShape(const std::vector<size_t> &shape) const {
  return kShape4dDims == shape.size();
}
// True when the input rank matches the 5-D (NCDHW) layout.
bool TbeKernelReduceSelecter::Is5DShape(const std::vector<size_t> &shape) const {
  return kShape5dDims == shape.size();
}
void TbeKernelReduceSelecter::PadScalarShape(std::vector<size_t> *shape) const {
MS_EXCEPTION_IF_NULL(shape);
if (shape->empty()) {

@ -39,6 +39,7 @@ class TbeKernelReduceSelecter {
void GetReduceAttrKeepDim();
void AssignSupportFormat(const std::string &support_format_str, SupportFormat *support_format) const;
bool Is4DShape(const std::vector<size_t> &shape) const;
bool Is5DShape(const std::vector<size_t> &shape) const;
void PadScalarShape(std::vector<size_t> *shape) const;
CNodePtr cnode_ptr_;
std::vector<size_t> input_shape_{};

@ -187,6 +187,9 @@ void TbeKernelSelect::GetBroadcastPatternKernelInfo(const OpInfo &op_info) {
if (!broadcast_selecter.IsBroadCastSupportFracNZ(&support_format)) {
MS_LOG(INFO) << "Node(" << node_name_ << ") does not support FracNZ.";
}
if (!broadcast_selecter.IsBroadCastSupportNDC1HWC0(&support_format)) {
MS_LOG(INFO) << "Node(" << node_name_ << ") does not support NDC1HWC0.";
}
PrintSupportedFormat(support_format);
OpInfo op_info_new;
CreateNewOpInfo(op_info, support_format, &op_info_new);
@ -281,10 +284,8 @@ bool TbeKernelSelect::IsShapeMatchFormat(const std::vector<size_t> &shape, const
return true;
}
// not support format:
// 1 NDHWC with shape size != 5
// 3 !NDHWC with shape size > 4
if ((format == kOpFormat_NDHWC && shape.size() != kShape5dDims) ||
(format != kOpFormat_NDHWC && shape.size() > kShape4dDims)) {
// 1 NCDHW with shape size != 5
if (format == kOpFormat_NCDHW && shape.size() != kShape5dDims) {
MS_LOG(INFO) << "Warning: Shape format check failed, format: " << format << ", size: " << shape.size();
return false;
}

@ -32,7 +32,7 @@ namespace mindspore {
namespace opt {
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
namespace {
const std::set<std::string> kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW};
const std::set<std::string> kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NCDHW};
AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
const KernelSelectPtr &kernel_select, const std::vector<size_t> &dst_shape) {
std::vector<AnfNodePtr> trans_inputs;
@ -70,9 +70,17 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
CNodePtr trans_data = nullptr;
MS_EXCEPTION_IF_NULL(node);
// Init
std::string default_format = kOpFormat_DEFAULT;
if (node->isa<CNode>() && AnfAlgo::HasNodeAttr("io_format", node->cast<CNodePtr>())) {
auto attr = AnfAlgo::GetNodeAttr<std::string>(node, "io_format");
if (attr == kOpFormat_NCDHW) {
default_format = kOpFormat_NCDHW;
}
}
AnfNodePtr input_node = is_insert_input ? AnfAlgo::GetInputNode(node->cast<CNodePtr>(), insert_index) : node;
std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, insert_index);
std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, insert_index) : kOpFormat_DEFAULT;
std::string input_format = is_insert_input ? default_format : AnfAlgo::GetOutputFormat(node, insert_index);
std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, insert_index) : default_format;
std::vector<Axis> padding_axis = is_insert_input ? AnfAlgo::GetInputReshapeType(node, insert_index)
: AnfAlgo::GetOutputReshapeType(node, insert_index);
auto input_node_out_shape = is_insert_input ? AnfAlgo::GetPrevNodeOutputInferShape(node, insert_index)

@ -369,6 +369,26 @@ void KernelGraph::CheckLoop() {
}
}
// Rebuilds the selected kernel build info of a Parameter/ValueNode so that its
// single output advertises `format`, keeping the node's inferred data type.
void ReSetParameterValueNodeFormatAndType(const AnfNodePtr &node, const std::string &format) {
  MS_EXCEPTION_IF_NULL(node);
  auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  MS_EXCEPTION_IF_NULL(builder);
  builder->SetOutputsFormat({format});
  builder->SetOutputsDeviceType({AnfAlgo::GetOutputInferDataType(node, 0)});
  AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
}
// For every Parameter/ValueNode input of `node`, resets its selected output
// format to `format` (used to propagate the NCDHW/3-D default to leaf inputs).
void KernelGraph::ResetInFormat(const AnfNodePtr &node, const std::string &format) const {
  MS_EXCEPTION_IF_NULL(node);
  // Hoisted out of the loop: the original re-cast `node` on every iteration
  // without checking the cast result, and re-queried the input count each time.
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  const size_t input_num = AnfAlgo::GetInputTensorNum(node);
  for (size_t i = 0; i < input_num; i++) {
    auto in_node = AnfAlgo::GetInputNode(cnode, i);
    MS_EXCEPTION_IF_NULL(in_node);
    // Only leaf inputs (parameters / constants) carry a format that must follow.
    if (in_node->isa<Parameter>() || in_node->isa<ValueNode>()) {
      ReSetParameterValueNodeFormatAndType(in_node, format);
    }
  }
}
CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
auto cnode = FuncGraph::NewCNode(inputs);
MS_EXCEPTION_IF_NULL(cnode);
@ -378,6 +398,12 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode);
}
SetKernelInfoForNode(cnode);
if (AnfAlgo::HasNodeAttr("io_format", cnode)) {
auto attr = AnfAlgo::GetNodeAttr<std::string>(cnode, "io_format");
if (attr == kOpFormat_NCDHW) {
ResetInFormat(cnode, kOpFormat_NCDHW);
}
}
AnfAlgo::SetGraphId(graph_id_, cnode.get());
return cnode;
}

@ -273,6 +273,7 @@ class KernelGraph : public FuncGraph {
// remove value node form graph
bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node);
void SetKernelInfoForNode(const AnfNodePtr &node) const;
void ResetInFormat(const AnfNodePtr &node, const std::string &format) const;
AnfNodePtr MakeValueNode(const AnfNodePtr &node);
void VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
std::unordered_set<AnfNodePtr> *visited_nodes);

@ -266,6 +266,41 @@ std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape) {
return device_shape;
}
std::vector<size_t> Ndc1hwc0DeviceShape(const std::vector<size_t> &shape) {
// NCDHW
if (shape.size() != 5) {
MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
}
std::vector<size_t> device_shape;
const size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
const size_t C0 = kCubeSize;
device_shape.push_back(shape[0]);
device_shape.push_back(shape[2]);
device_shape.push_back(C1);
device_shape.push_back(shape[3]);
device_shape.push_back(shape[4]);
device_shape.push_back(C0);
return device_shape;
}
std::vector<size_t> Fracz3DDeviceShape(const std::vector<size_t> &shape) {
// NCDHW -> Frac_Z_3D
if (shape.size() != 5) {
MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
}
std::vector<size_t> device_shape;
const size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
const size_t N1 = (shape[0] + kCubeSize - 1) / kCubeSize;
device_shape.push_back(shape[2]);
device_shape.push_back(C1);
device_shape.push_back(shape[3]);
device_shape.push_back(shape[4]);
device_shape.push_back(N1);
device_shape.push_back(kCubeSize);
device_shape.push_back(kCubeSize);
return device_shape;
}
std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) {
if (!CheckDims(shape)) {
MS_LOG(EXCEPTION) << "Check dims failed.";
@ -310,7 +345,7 @@ std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) {
return device_shape;
}
std::vector<size_t> NdhwcDeviceShape(const std::vector<size_t> &shape) {
std::vector<size_t> NcdhwDeviceShape(const std::vector<size_t> &shape) {
if (shape.size() < kNdhwc) {
MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc.";
}
@ -405,7 +440,9 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
{kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape},
{kOpFormat_FRACTAL_Z_C04, FracZc04DeviceShape},
{kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape},
{kOpFormat_NDHWC, NdhwcDeviceShape}};
{kOpFormat_NCDHW, NcdhwDeviceShape},
{kOpFormat_NDC1HWC0, Ndc1hwc0DeviceShape},
{kOpFormat_FRACTAL_Z_3D, Fracz3DDeviceShape}};
if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
return shape;
@ -441,7 +478,7 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
device_shape.push_back(kCubeSize);
return device_shape;
}
if (shape.size() != kNchwDims) {
if (shape.size() != kNchwDims && shape.size() != 5) {
MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
temp_shape = PaddingShapeTo4dByDefault(shape);
}
@ -496,7 +533,9 @@ bool TransFormat(const FormatArgs &args, void *result) {
const std::map<std::string, FormatTransfer> format_trans_map{
{kOpFormat_FRAC_Z, NchwToFracZ}, {kOpFormat_FRAC_NZ, NchwToFracNz},
{kOpFormat_NC1HWC0, NchwToNc1hwc0}, {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0},
{kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04}};
{kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04},
{kOpFormat_NDC1HWC0, NcdhwToNdc1hwc0}};
MS_LOG(DEBUG) << "Start trans format.";
if (abstract::TypeIdSize(args.src_data_type) < 1) {
MS_LOG(ERROR) << "Invalid datatype..";
@ -514,11 +553,11 @@ bool TransFormat(const FormatArgs &args, void *result) {
bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
using FormatTransfer = std::function<bool(const FormatArgs &, void *)>;
const std::map<std::string, FormatTransfer> format_trans_map{{kOpFormat_FRAC_Z, FracZToNchw},
{kOpFormat_FRAC_NZ, FracNzToNchw},
{kOpFormat_NC1HWC0, Nc1hwc0ToNchw},
{kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw},
{kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}};
const std::map<std::string, FormatTransfer> format_trans_map{
{kOpFormat_FRAC_Z, FracZToNchw}, {kOpFormat_FRAC_NZ, FracNzToNchw},
{kOpFormat_NC1HWC0, Nc1hwc0ToNchw}, {kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw},
{kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}, {kOpFormat_NDC1HWC0, Ndc1hwc0ToNcdhw}};
MS_LOG(DEBUG) << "Start trans format.";
if (abstract::TypeIdSize(args.src_data_type) < 1) {
MS_LOG(ERROR) << "Invalid datatype..";
@ -1106,5 +1145,119 @@ bool C1hwncoc0ToNchw(const FormatArgs &args, void *result) {
}
return true;
}
// Device-to-host layout transform: NDC1HWC0 -> NCDHW.
// args.host_shape is the 5-D NCDHW host shape; args.device_shape is the 6-D
// NDC1HWC0 device shape (channels split into c1 blocks of c0, c0 innermost).
// Each host element (n,c,d,h,w) is copied from device index
// n*d*c1*h*w*c0 + d_i*c1*h*w*c0 + (c/c0)*h*w*c0 + h_i*w*c0 + w_i*c0 + (c%c0)
// via SetData. Returns false (with an error log) on bad rank/dtype/buffer size.
bool Ndc1hwc0ToNcdhw(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans from ndc1hwc0 to ncdhw";
  MS_EXCEPTION_IF_NULL(result);
  // Host side must be rank-5 (NCDHW).
  if (args.host_shape.size() != 5) {
    MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
    return false;
  }
  auto size = abstract::TypeIdSize(args.src_data_type);
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype.";
    return false;
  }
  // The caller-provided device buffer must match the device shape exactly.
  auto total_size = abstract::ShapeSize(args.device_shape) * size;
  if (total_size != args.device_size) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
  }
  auto n = args.host_shape[0];
  auto c = args.host_shape[1];
  auto d = args.host_shape[2];
  auto h = args.host_shape[3];
  auto w = args.host_shape[4];
  auto c1 = args.device_shape[2];  // number of c0-sized channel blocks
  auto c0 = args.device_shape[5];  // channel block size
  // Precomputed strides: host (NCDHW) and device (NDC1HWC0) layouts.
  const size_t cdhw = c * d * h * w;
  const size_t dhw = d * h * w;
  const size_t hw = h * w;
  const size_t dc1hwc0 = d * c1 * h * w * c0;
  const size_t c1hwc0 = c1 * h * w * c0;
  const size_t hwc0 = h * w * c0;
  const size_t wc0 = w * c0;
  for (size_t n_i = 0; n_i < n; n_i++) {
    size_t n_head = n_i * cdhw;
    for (size_t c_i = 0; c_i < c; c_i++) {
      size_t c_head = n_head + c_i * dhw;
      for (size_t d_i = 0; d_i < d; d_i++) {
        size_t d_head = c_head + d_i * hw;
        for (size_t h_i = 0; h_i < h; h_i++) {
          size_t h_head = d_head + h_i * w;
          for (size_t w_i = 0; w_i < w; w_i++) {
            size_t dst_i = h_head + w_i;
            // Split the host channel index into (block index, offset in block).
            size_t c1_i = c_i / c0;
            size_t c0_i = c_i % c0;
            auto src_idx = n_i * dc1hwc0 + d_i * c1hwc0 + c1_i * hwc0 + h_i * wc0 + w_i * c0 + c0_i;
            SetData(size, false, src_idx, dst_i, args, result);
          }
        }
      }
    }
  }
  return true;
}
// Host-to-device layout transform: NCDHW -> NDC1HWC0.
// args.host_shape is the 5-D NCDHW host shape. The channel dimension is split
// into c1 = ceil(c / kCubeSize) blocks of c0 = kCubeSize channels; device
// channel slots beyond the real channel count c are zero-padded (pad_zero).
// Copies element-by-element via SetData; returns false (with an error log) on
// bad rank/dtype/buffer size.
bool NcdhwToNdc1hwc0(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans from ncdhw to ndc1hwc0";
  MS_EXCEPTION_IF_NULL(result);
  // Host side must be rank-5 (NCDHW).
  if (args.host_shape.size() != 5) {
    MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
    return false;
  }
  auto size = abstract::TypeIdSize(args.src_data_type);
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype.";
    return false;
  }
  // The destination device buffer must match the device shape exactly.
  auto total_size = abstract::ShapeSize(args.device_shape) * size;
  if (total_size != args.device_size) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
  }
  auto n = args.host_shape[0];
  auto c = args.host_shape[1];
  auto d = args.host_shape[2];
  auto h = args.host_shape[3];
  auto w = args.host_shape[4];
  auto c0 = kCubeSize;         // channel block size on device
  auto c1 = DivCeil(c, c0);    // number of channel blocks
  // Precomputed strides: host (NCDHW) and device (NDC1HWC0) layouts.
  const size_t cdhw = c * d * h * w;
  const size_t dhw = d * h * w;
  const size_t hw = h * w;
  const size_t dc1hwc0 = d * c1 * h * w * c0;
  const size_t c1hwc0 = c1 * h * w * c0;
  const size_t hwc0 = h * w * c0;
  const size_t wc0 = w * c0;
  for (size_t n_i = 0; n_i < n; n_i++) {
    size_t n_head = n_i * dc1hwc0;
    for (size_t d_i = 0; d_i < d; d_i++) {
      size_t d_head = n_head + d_i * c1hwc0;
      for (size_t c1_i = 0; c1_i < c1; c1_i++) {
        size_t c1_head = d_head + c1_i * hwc0;
        for (size_t h_i = 0; h_i < h; h_i++) {
          size_t h_head = c1_head + h_i * wc0;
          for (size_t w_i = 0; w_i < w; w_i++) {
            size_t w_head = h_head + w_i * c0;
            for (size_t c0_i = 0; c0_i < c0; c0_i++) {
              size_t dst_i = c0_i + w_head;
              // Reassemble the host channel index from (block, offset).
              size_t c_i = c0_i + c1_i * c0;
              size_t src_i = n_i * cdhw + c_i * dhw + d_i * hw + h_i * w + w_i;
              // Channels past the real channel count are zero-filled padding.
              auto pad_zero = c_i >= c;
              SetData(size, pad_zero, src_i, dst_i, args, result);
            }
          }
        }
      }
    }
  }
  return true;
}
} // namespace trans
} // namespace mindspore

@ -66,6 +66,8 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result);
bool NchwToFracZc04(const FormatArgs &args, void *result);
bool NchwToNc1hwc04(const FormatArgs &args, void *result);
bool NchwToC1hwncoc0(const FormatArgs &args, void *result);
bool NcdhwToNdc1hwc0(const FormatArgs &args, void *result);
// device to host
bool ToNchw(const FormatArgs &args, void *result);
bool FracZToNchw(const FormatArgs &args, void *result);
@ -73,6 +75,7 @@ bool FracNzToNchw(const FormatArgs &args, void *result);
bool Nc1hwc0ToNchw(const FormatArgs &args, void *result);
bool Nc1hwc04ToNchw(const FormatArgs &args, void *result);
bool C1hwncoc0ToNchw(const FormatArgs &args, void *result);
bool Ndc1hwc0ToNcdhw(const FormatArgs &args, void *result);
} // namespace trans
} // namespace mindspore

@ -292,7 +292,7 @@ bool AscendDeviceAddress::SyncDeviceToHost(const ShapeVector &shape, size_t size
if (host_shape.empty()) {
host_shape.emplace_back(1);
}
if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NDHWC) {
if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NCDHW) {
if (type_id_ == type) {
SyncMemory(host_ptr, ptr_, size, RT_MEMCPY_DEVICE_TO_HOST);
sync_ok = true;
@ -454,7 +454,7 @@ std::vector<size_t> AscendDeviceAddress::GetWorkspaceSizeList(const nlohmann::js
std::vector<size_t> AscendDeviceAddress::GetDeviceShape(std::vector<size_t> *host_shape) const {
std::vector<size_t> device_shape;
if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NDHWC) {
if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NCDHW) {
device_shape = trans::TransShapeToDevice(*host_shape, format_);
} else {
if (host_shape_.empty()) {
@ -531,7 +531,7 @@ bool AscendDeviceAddress::SyncHostToDevice(const ShapeVector &shape, size_t size
if (host_shape.empty()) {
host_shape.emplace_back(1);
}
if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NDHWC) {
if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NCDHW) {
if (type_id_ == type) {
SyncMemory(ptr_, host_ptr, size, RT_MEMCPY_HOST_TO_DEVICE);
sync_ok = true;
@ -575,7 +575,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const ShapeVector &sh
host_shape.emplace_back(1);
}
std::vector<size_t> device_shape;
if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NDHWC) {
if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NCDHW) {
device_shape = trans::TransShapeToDevice(host_shape, format_);
} else {
host_shape = trans::PaddingShapeTo4d(host_shape);

@ -81,6 +81,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) {
string priority_matched_format = kOpFormat_NC1HWC0;
bool is_init = false;
bool need_change_nd = false;
bool is_5d_input = false;
for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode); ++index) {
auto pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index);
if (AnfAlgo::IsFeatureMapInput(cnode, index) &&
@ -93,14 +94,21 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) {
priority_matched_format = kOpFormat_DEFAULT;
}
auto input_shape_size = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index).size();
if (input_shape_size == 5) {
is_5d_input = true;
}
need_change_nd = (need_change_nd || (input_shape_size != 4 && input_shape_size > 1));
}
if (need_change_nd && priority_matched_format != kOpFormat_FRAC_NZ) {
priority_matched_format = kOpFormat_DEFAULT;
}
if (is_5d_input && priority_matched_format != kOpFormat_FRAC_NZ) {
priority_matched_format = kOpFormat_NDC1HWC0;
}
AnfAlgo::SetNodeAttr(kPriChoosenFormat, MakeValue(priority_matched_format), cnode);
return priority_matched_format;
}
/**
* Compare two vector by priority, select a better vector, like compare two num, first compare highest num location,
* if equal then next num location
@ -157,7 +165,8 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
if (kernel_build_info.GetInputFormat(input_index) == pri_match_format) {
(*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT] += base_score;
}
if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_DEFAULT) {
if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_DEFAULT ||
kernel_build_info.GetInputFormat(input_index) == kOpFormat_NCDHW) {
(*cur_kernelinfo_match_counts)[MATCH_DEFAULT_FORMAT_COUNT] += base_score;
}
}
@ -376,7 +385,9 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
std::vector<std::string> output_format = {AnfAlgo::GetOutputFormat(real_input_node, 0)};
if (IsValueNode<tensor::Tensor>(input_kernel_node) &&
AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) {
if (selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_ZN_LSTM) {
if (selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_ZN_LSTM ||
selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_Z_3D ||
selected_kernel_info->GetInputFormat(input_index) != kOpFormat_NDC1HWC0) {
output_format = {selected_kernel_info->GetInputFormat(input_index)};
}
builder->SetOutputsFormat(output_format);
@ -386,7 +397,9 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
continue;
}
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
if (selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_ZN_LSTM) {
if (selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_ZN_LSTM ||
selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_Z_3D ||
selected_kernel_info->GetInputFormat(input_index) != kOpFormat_NDC1HWC0) {
output_format = {selected_kernel_info->GetInputFormat(input_index)};
}
builder->SetOutputsFormat(output_format);

@ -386,11 +386,23 @@ constexpr auto kOpFormat_C1HWNCoC0 = "C1HWNCoC0";
constexpr auto kOpFormat_NC1HWC0_C04 = "NC1HWC0_C04";
constexpr auto kOpFormat_FRACTAL_Z_C04 = "FRACTAL_Z_C04";
constexpr auto kOpFormat_NDHWC = "NDHWC";
constexpr auto kOpFormat_NCDHW = "NCDHW";
constexpr auto kOpFormat_DHWNC = "DHWNC";
constexpr auto kOpFormat_DHWCN = "DHWCN";
constexpr auto kOpFormat_NDC1HWC0 = "NDC1HWC0";
constexpr auto kOpFormat_FRACTAL_Z_3D = "FRACTAL_Z_3D";
constexpr auto kOpFormat_FRACTAL_ZN_LSTM = "FRACTAL_ZN_LSTM";
const std::set<std::string> kOpFormatList = {
kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC,
kOpFormat_HWCN, kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ,
kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_NDHWC, kOpFormat_FRACTAL_ZN_LSTM};
const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0,
kOpFormat_ND, kOpFormat_NCHW,
kOpFormat_NHWC, kOpFormat_HWCN,
kOpFormat_NC1HWC0, kOpFormat_FRAC_Z,
kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ,
kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04,
kOpFormat_NDHWC, kOpFormat_FRACTAL_ZN_LSTM,
kOpFormat_NDC1HWC0, kOpFormat_NCDHW,
kOpFormat_FRACTAL_Z_3D, kOpFormat_DHWNC,
kOpFormat_DHWCN};
const std::set<std::string> kDefaultCompatibleFormat = {kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN};
const std::set<std::string> kOptOperatorSet = {kMomentumOpName,
kApplyMomentumOpName,
@ -427,8 +439,8 @@ const std::set<std::string> kOptOperatorSet = {kMomentumOpName,
kSparseApplyProximalAdagradOpName};
const std::set<std::string> kHWSpecialFormatSet = {
kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ,
kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_FRACTAL_ZN_LSTM};
kOpFormat_FRACTAL_Z_3D, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0,
kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_FRACTAL_ZN_LSTM, kOpFormat_NDC1HWC0, kOpFormat_FRAC_Z};
const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32};

Loading…
Cancel
Save