reshape type for 3d nodes

pull/12968/head
liubuyu 4 years ago
parent 659b912f6d
commit 518818fbef

@@ -66,7 +66,7 @@ size_t KernelBuildInfo::GetInputNum() const { return inputs_format_.size(); }
size_t KernelBuildInfo::GetOutputNum() const { return outputs_format_.size(); }
-std::vector<Axis> KernelBuildInfo::GetInputReshapeType(size_t input_index) const {
+std::string KernelBuildInfo::GetInputReshapeType(size_t input_index) const {
if (input_reshape_type_.empty()) {
return {};
}
@@ -77,7 +77,7 @@ std::vector<Axis> KernelBuildInfo::GetInputReshapeType(size_t input_index) const
return input_reshape_type_[input_index];
}
-std::vector<Axis> KernelBuildInfo::GetOutputReshapeType(size_t output_index) const {
+std::string KernelBuildInfo::GetOutputReshapeType(size_t output_index) const {
if (output_reshape_type_.empty()) {
return {};
}
@@ -175,14 +175,13 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetProcessor(Processor processor)
std::shared_ptr<KernelBuildInfo> KernelBuildInfo::KernelBuildInfoBuilder::Build() { return kernel_build_info_; }
-void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsReshapeType(
-const std::vector<std::vector<Axis>> &input_reshape_type) {
+void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsReshapeType(const std::vector<std::string> &input_reshape_type) {
MS_EXCEPTION_IF_NULL(kernel_build_info_);
kernel_build_info_->input_reshape_type_ = input_reshape_type;
}
void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsReshapeType(
-const std::vector<std::vector<Axis>> &output_reshape_type) {
+const std::vector<std::string> &output_reshape_type) {
MS_EXCEPTION_IF_NULL(kernel_build_info_);
kernel_build_info_->output_reshape_type_ = output_reshape_type;
}
@@ -206,8 +205,7 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputFormat(const std::string
}
kernel_build_info_->outputs_format_[index] = format;
}
-void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType(const std::vector<Axis> &input_reshape_type,
-size_t index) {
+void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType(const std::string &input_reshape_type, size_t index) {
if (index >= kernel_build_info_->input_reshape_type_.size()) {
MS_LOG(EXCEPTION) << "index outof range!";
}
@@ -215,7 +213,7 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType(const std::vec
std::back_inserter(kernel_build_info_->input_reshape_type_[index]));
}
-void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType(const std::vector<Axis> &output_reshape_type,
+void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType(const std::string &output_reshape_type,
size_t index) {
if (index >= kernel_build_info_->output_reshape_type_.size()) {
MS_LOG(EXCEPTION) << "index outof range!";

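With the setters and getters above switched over, reshape types travel through KernelBuildInfo as plain axis strings rather than pre-parsed std::vector<Axis> values. A hedged usage sketch against the builder API shown in this diff (the surrounding setup is illustrative, not part of the commit):

    // "NCH" records which NCHW/NCDHW axes a low-rank shape occupies; "" means
    // default padding. One entry per input/output.
    kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
    builder.SetInputsReshapeType({"NCH", ""});   // previously {{N, C, H}, {}}
    builder.SetOutputsReshapeType({""});
    auto build_info = builder.Build();
    std::string pad = build_info->GetInputReshapeType(0);  // "NCH"
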
@@ -57,13 +57,13 @@ class KernelBuildInfo {
TypeId GetOutputDeviceType(size_t output_index) const;
-std::vector<Axis> GetInputReshapeType(size_t input_index) const;
+std::string GetInputReshapeType(size_t input_index) const;
bool IsInputDefaultPadding() const;
bool IsOutputDefaultPadding() const;
-std::vector<Axis> GetOutputReshapeType(size_t input_index) const;
+std::string GetOutputReshapeType(size_t input_index) const;
const std::string &GetOriginDataFormat() const;
@@ -75,9 +75,9 @@ class KernelBuildInfo {
const std::vector<TypeId> &GetAllOutputDeviceTypes() const;
-std::vector<std::vector<Axis>> GetAllOutputReshapeType() const;
+std::vector<std::string> GetAllOutputReshapeType() const;
-std::vector<std::vector<Axis>> GetAllInputReshapeType() const;
+std::vector<std::string> GetAllInputReshapeType() const;
OpPattern op_pattern() const { return op_pattern_; }
@@ -106,8 +106,8 @@ class KernelBuildInfo {
std::vector<std::string> inputs_format_;
OpPattern op_pattern_;
std::vector<std::string> outputs_format_;
-std::vector<std::vector<Axis>> input_reshape_type_;
-std::vector<std::vector<Axis>> output_reshape_type_;
+std::vector<std::string> input_reshape_type_;
+std::vector<std::string> output_reshape_type_;
std::vector<TypeId> inputs_device_type_;
std::vector<TypeId> outputs_device_type_;
FusionType fusion_type_;
@@ -151,9 +151,9 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
void SetOutputsDeviceType(const std::vector<TypeId> &outputs_device_type);
-void SetInputsReshapeType(const std::vector<std::vector<Axis>> &input_reshape_type);
+void SetInputsReshapeType(const std::vector<std::string> &input_reshape_type);
-void SetOutputsReshapeType(const std::vector<std::vector<Axis>> &output_reshape_type);
+void SetOutputsReshapeType(const std::vector<std::string> &output_reshape_type);
void SetFusionType(FusionType fusion_type);
@@ -165,9 +165,9 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
void SetOutputFormat(const std::string &format, size_t index);
-void SetInputReshapeType(const std::vector<Axis> &input_reshape_type, size_t index);
+void SetInputReshapeType(const std::string &input_reshape_type, size_t index);
-void SetOutputReshapeType(const std::vector<Axis> &output_reshape_type, size_t index);
+void SetOutputReshapeType(const std::string &output_reshape_type, size_t index);
void SetInputDeviceType(const TypeId &input_device_type, size_t index);

@@ -99,7 +99,7 @@ void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {
SetTbeBuildCommonInfo(op_info, &builder);
std::vector<std::string> inputs_format;
std::vector<TypeId> inputs_device_type;
-std::vector<std::vector<Axis>> inputs_reshape_type;
+std::vector<std::string> inputs_reshape_type;
// input
if (!GenBuilderItem(true, kernel_build_info_index, real_input_tensor_num, inputs_info, dyn_input_sizes,
&inputs_format, &inputs_device_type, &inputs_reshape_type)) {
@@ -111,7 +111,7 @@ void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {
// output
std::vector<std::string> outputs_format;
std::vector<TypeId> outputs_device_type;
-std::vector<std::vector<Axis>> outputs_reshape_type;
+std::vector<std::string> outputs_reshape_type;
if (!GenBuilderItem(false, kernel_build_info_index, real_output_tensor_num, outputs_info, dyn_input_sizes,
&outputs_format, &outputs_device_type, &outputs_reshape_type)) {
break;
@@ -290,7 +290,7 @@ std::vector<int64_t> TbeKernelSelect::GetNodeDynamicInputs() {
bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num,
const std::vector<std::shared_ptr<OpIOInfo>> &ios_info,
const std::vector<int64_t> &dyn_input_sizes, std::vector<std::string> *formats,
-std::vector<TypeId> *device_types, std::vector<std::vector<Axis>> *reshape_types) {
+std::vector<TypeId> *device_types, std::vector<std::string> *reshape_types) {
MS_EXCEPTION_IF_NULL(formats);
MS_EXCEPTION_IF_NULL(device_types);
MS_EXCEPTION_IF_NULL(reshape_types);
@@ -306,8 +306,7 @@ bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_ind
kernel_build_info_format = io_info_item->formats()[kernel_build_info_index];
}
const std::string &io_param_type = io_info_item->param_type();
-std::vector<Axis> reshape_type;
-StringToAxisVector(io_info_item->reshape_type(), &reshape_type);
+auto reshape_type = io_info_item->reshape_type();
if (io_param_type == kParamTypeDynamic) {
// dynamic io
if (is_input) {
@@ -355,28 +354,6 @@ bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_ind
return true;
}
-void TbeKernelSelect::StringToAxisVector(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec) {
-MS_EXCEPTION_IF_NULL(reshape_type_vec);
-for (const auto &c : reshape_type_str) {
-switch (c) {
-case 'N':
-reshape_type_vec->push_back(N);
-break;
-case 'C':
-reshape_type_vec->push_back(C);
-break;
-case 'H':
-reshape_type_vec->push_back(H);
-break;
-case 'W':
-reshape_type_vec->push_back(W);
-break;
-default:
-MS_LOG(EXCEPTION) << "Unknown axis " << c << "in reshape type.";
-}
-}
-}
void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io_info,
const std::vector<std::vector<std::string>> &support_format_item, size_t index,
mindspore::kernel::OpIOInfo *op_io_info_new) {

@@ -52,8 +52,7 @@ class TbeKernelSelect {
bool GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num,
const std::vector<std::shared_ptr<OpIOInfo>> &ios_info,
const std::vector<int64_t> &dyn_input_sizes, std::vector<std::string> *formats,
-std::vector<TypeId> *device_types, std::vector<std::vector<Axis>> *reshape_types);
-static void StringToAxisVector(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec);
+std::vector<TypeId> *device_types, std::vector<std::string> *reshape_types);
static void CreateNewOpInfo(const OpInfo &op_info, const SupportFormat &support_format, OpInfo *op_info_new);
static void CreateNewOpIOInfo(const OpIOInfo &op_io_info,
const std::vector<std::vector<std::string>> &support_format_item, size_t index,

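With TbeKernelSelect::StringToAxisVector deleted, GenBuilderItem no longer parses the registry's reshape_type up front; the string is carried verbatim in the KernelBuildInfo and decoded only by the padding helpers added later in this commit. A hedged sketch of the new flow inside the per-IO loop (names from this diff, context abbreviated):

    // io_info_item comes from the TBE op registry; reshape_type() is a string
    // such as "NCH", or "" when the default padding applies.
    auto reshape_type = io_info_item->reshape_type();
    reshape_types->push_back(reshape_type);  // no eager char -> Axis conversion
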
@@ -187,8 +187,8 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
AnfNodePtr input_node = is_insert_input ? AnfAlgo::GetInputNode(node->cast<CNodePtr>(), insert_index) : node;
std::string input_format = is_insert_input ? default_format : AnfAlgo::GetOutputFormat(node, insert_index);
std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, insert_index) : default_format;
-std::vector<Axis> padding_axis = is_insert_input ? AnfAlgo::GetInputReshapeType(node, insert_index)
-: AnfAlgo::GetOutputReshapeType(node, insert_index);
+std::string padding_axis = is_insert_input ? AnfAlgo::GetInputReshapeType(node, insert_index)
+: AnfAlgo::GetOutputReshapeType(node, insert_index);
auto input_node_out_shape = is_insert_input ? AnfAlgo::GetPrevNodeOutputInferShape(node, insert_index)
: AnfAlgo::GetOutputInferShape(input_node, insert_index);
bool need_padding = is_insert_input ? trans::IsNeedPadding(dst_format, input_node_out_shape.size())
@@ -200,8 +200,8 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
} else if (is_insert_input) {
// if need padding & is input need insert a transdata
// reshape[padding shape] -> transdata[padding shape] -> node
-auto padding_shape =
-trans::PaddingShapeTo4d(input_node_out_shape, AnfAlgo::GetInputReshapeType(node, insert_index));
+auto padding_shape = trans::PaddingShape(input_node_out_shape, AnfAlgo::GetInputFormat(node, insert_index),
+AnfAlgo::GetInputReshapeType(node, insert_index));
auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape);
trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::kPrimTransData->name());
trans_node = trans_data;
@@ -222,8 +222,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
}
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
-const AnfNodePtr &trans_data, const std::vector<Axis> &reshape_type,
-const TypeId &type_id) {
+const AnfNodePtr &trans_data, const std::string &reshape_type, const TypeId &type_id) {
MS_EXCEPTION_IF_NULL(trans_data);
auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data);
MS_EXCEPTION_IF_NULL(ori_build_info);
@@ -249,9 +248,10 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
if (need_padding) {
// if need padding we should set the transdata node's shape to the padding shape
auto padding_axis = AnfAlgo::GetOutputReshapeType(input, 0);
-AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
-{trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0), padding_axis)},
-trans_node.get());
+AnfAlgo::SetOutputInferTypeAndShape(
+{AnfAlgo::GetOutputInferDataType(input, 0)},
+{trans::PaddingShape(AnfAlgo::GetOutputInferShape(input, 0), AnfAlgo::GetOutputFormat(input, 0), padding_axis)},
+trans_node.get());
} else {
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
{AnfAlgo::GetOutputInferShape(input, 0)}, trans_node.get());
@@ -273,7 +273,7 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
const TypeId &input_type, const TypeId &output_type,
const std::vector<size_t> &origin_shape, const TypeId &origin_type,
-const std::vector<Axis> &reshape_type) {
+const std::string &reshape_type) {
MS_EXCEPTION_IF_NULL(func_graph);
std::string input_format = format;
std::string output_format = format;

@@ -88,7 +88,7 @@ class OpFinder {
using OpFinderPtr = std::shared_ptr<OpFinder>;
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
-const AnfNodePtr &trans_data, const std::vector<Axis> &reshape_type = {},
+const AnfNodePtr &trans_data, const std::string &reshape_type = {""},
const TypeId &type_id = kTypeUnknown);
CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,
@@ -97,7 +97,7 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
const TypeId &input_type, const TypeId &output_type,
const std::vector<size_t> &origin_shape, const TypeId &origin_type,
-const std::vector<Axis> &reshape_type = std::vector<Axis>{});
+const std::string &reshape_type = std::string{});
AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const KernelSelectPtr &kernel_select);

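AddTransOpNodeToGraph and NewTransOpNode now route every padding decision through trans::PaddingShape, passing the device format so NCDHW-family formats pad to five dimensions instead of being forced through PaddingShapeTo4d. A hedged sketch of the call pattern (HostShapeForTransData is a hypothetical helper, not part of the commit; IsNeedPadding and PaddingShape are the trans functions from this diff):

    std::vector<size_t> HostShapeForTransData(const std::vector<size_t> &infer_shape,
                                              const std::string &dev_format,
                                              const std::string &reshape_type) {
      if (!trans::IsNeedPadding(dev_format, infer_shape.size())) {
        return infer_shape;  // already full rank, or a default format
      }
      // PaddingShape dispatches to 4-D or 5-D padding based on the format
      // (see the common_utils.cc hunks below).
      return trans::PaddingShape(infer_shape, dev_format, reshape_type);
    }
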
@@ -586,7 +586,7 @@ std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_n
return AnfRuntimeAlgorithm::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
}
-std::vector<Axis> AnfRuntimeAlgorithm::GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx) {
+std::string AnfRuntimeAlgorithm::GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx) {
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
return GetOutputReshapeType(kernel_with_index.first, kernel_with_index.second);
}
@@ -642,7 +642,7 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &
}
// if format is default_format or NC1KHKWHWC0,device shape = original shape
if (trans::IsNeedPadding(format, infer_shape.size())) {
-infer_shape = trans::PaddingShapeTo4d(infer_shape, GetOutputReshapeType(node, output_idx));
+infer_shape = trans::PaddingShape(infer_shape, format, GetOutputReshapeType(node, output_idx));
}
return trans::TransShapeToDevice(infer_shape, format);
}
@@ -655,12 +655,12 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &n
}
// if format is default_format or NC1KHKWHWC0,device shape = original shape
if (trans::IsNeedPadding(format, infer_shape.size())) {
-infer_shape = trans::PaddingShapeTo4d(infer_shape, GetInputReshapeType(node, input_idx));
+infer_shape = trans::PaddingShape(infer_shape, format, GetInputReshapeType(node, input_idx));
}
return trans::TransShapeToDevice(infer_shape, format);
}
-std::vector<Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) {
+std::string AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) {
MS_EXCEPTION_IF_NULL(node);
if (input_idx > GetInputTensorNum(node)) {
MS_LOG(EXCEPTION) << "The index:" << input_idx
@@ -681,7 +681,7 @@ std::vector<Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &nod
return build_info->GetInputReshapeType(input_idx);
}
-std::vector<Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) {
+std::string AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) {
MS_EXCEPTION_IF_NULL(node);
if (output_idx > GetOutputTensorNum(node)) {
MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "

@@ -122,7 +122,7 @@ class AnfRuntimeAlgorithm {
// get output format from prev node,input_index is the input index of current node related to prev node
static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx);
// get reshape_type of from the output of input node.
-static std::vector<Axis> GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx);
+static std::string GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx);
// get output shapes inferred by ME from input nodes.
static std::vector<size_t> GetOutputInferShape(const AnfNodePtr &node, size_t output_idx);
// get input shapes inferred by ME from input nodes.
@@ -132,9 +132,9 @@ class AnfRuntimeAlgorithm {
// get input shapes which will built and run in device
static std::vector<size_t> GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx);
// Get Input Padding Axis
-static std::vector<Axis> GetInputReshapeType(const AnfNodePtr &node, size_t output_idx);
+static std::string GetInputReshapeType(const AnfNodePtr &node, size_t output_idx);
// Get Output Padding Axis
-static std::vector<Axis> GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx);
+static std::string GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx);
// get output data type inferred by ME of anf node
static TypeId GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx);
// get output original data type from prev node,input_index is the input index of current node related to prev node

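GetInputDeviceShape and GetOutputDeviceShape now share one recipe: pad the inferred shape with the format-aware trans::PaddingShape, then map it with trans::TransShapeToDevice. A hedged usage sketch of the updated accessors (node and the indices are placeholders):

    std::string in_pad = AnfAlgo::GetInputReshapeType(node, 0);    // e.g. "NCH"
    std::string out_pad = AnfAlgo::GetOutputReshapeType(node, 0);  // "" = default
    std::vector<size_t> dev_shape = AnfAlgo::GetOutputDeviceShape(node, 0);
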
@@ -21,6 +21,7 @@
#include "abstract/utils.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/kernel.h"
+#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
#include "runtime/device/convert_tensor_utils.h"
#include "utils/convert_utils.h"
#include "utils/log_adapter.h"
@@ -28,7 +29,7 @@
namespace mindspore {
namespace trans {
-enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims, kNdhwc };
+enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims, kNcdhw };
inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx, const FormatArgs &args, void *result) {
switch (size) {
case 1:
@ -343,7 +344,7 @@ std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) {
}
std::vector<size_t> NcdhwDeviceShape(const std::vector<size_t> &shape) {
if (shape.size() < kNdhwc) {
if (shape.size() < kNcdhw) {
MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc.";
}
return shape;
@@ -388,6 +389,20 @@ bool IsNeedPadding(const std::string &format, const size_t shape_size) {
return false;
}
+std::vector<size_t> PaddingShape(const std::vector<size_t> &shape, const std::string &format,
+const std::string &pad_index) {
+std::vector<size_t> host_shape;
+if (k3DFormatSet.find(format) != k3DFormatSet.end()) {
+if (shape.size() >= kNcdhw) {
+return shape;
+}
+host_shape = trans::PaddingShapeTo5d(shape, pad_index);
+} else {
+host_shape = trans::PaddingShapeTo4d(shape, pad_index);
+}
+return host_shape;
+}
ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
MS_EXCEPTION_IF_NULL(node);
ShapeVector shape;
@@ -409,14 +424,84 @@ ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
} else {
host_shape = AnfAlgo::GetOutputInferShape(node, index);
}
-if (trans::IsNeedPadding(AnfAlgo::GetOutputFormat(node, index), host_shape.size())) {
-host_shape = trans::PaddingShapeTo4d(host_shape, AnfAlgo::GetOutputReshapeType(node, index));
+auto format = AnfAlgo::GetOutputFormat(node, index);
+if (trans::IsNeedPadding(format, host_shape.size())) {
+host_shape = trans::PaddingShape(host_shape, format, AnfAlgo::GetOutputReshapeType(node, index));
}
std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToLong);
return shape;
}
-std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<Axis> &padding_axis) {
+void StringToAxisVector4D(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec) {
+MS_EXCEPTION_IF_NULL(reshape_type_vec);
+if (reshape_type_str.empty()) {
+MS_LOG(DEBUG) << "Reshape type str is empty, no need padding.";
+return;
+}
+for (const auto &c : reshape_type_str) {
+switch (c) {
+case 'N':
+reshape_type_vec->push_back(N);
+break;
+case 'C':
+reshape_type_vec->push_back(C);
+break;
+case 'H':
+reshape_type_vec->push_back(H);
+break;
+case 'W':
+reshape_type_vec->push_back(W);
+break;
+default:
+MS_LOG(EXCEPTION) << "Unknown axis " << c << "in reshape type.";
+}
+}
+}
+void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5D> *reshape_type_vec) {
+MS_EXCEPTION_IF_NULL(reshape_type_vec);
+if (reshape_type_str.empty()) {
+MS_LOG(DEBUG) << "Reshape type str is empty, no need padding.";
+return;
+}
+for (const auto &c : reshape_type_str) {
+switch (c) {
+case 'N':
+reshape_type_vec->push_back(N_ncdhw);
+break;
+case 'C':
+reshape_type_vec->push_back(C_ncdhw);
+break;
+case 'D':
+reshape_type_vec->push_back(D_ncdhw);
+break;
+case 'H':
+reshape_type_vec->push_back(H_ncdhw);
+break;
+case 'W':
+reshape_type_vec->push_back(W_ncdhw);
+break;
+default:
+MS_LOG(EXCEPTION) << "Unknown axis " << c << "in reshape type.";
+}
+}
+}
+std::vector<size_t> PaddingShapeTo5d(const std::vector<size_t> &shape, const std::string &padding_str) {
+std::vector<Axis5D> padding_axis;
+StringToAxisVector5D(padding_str, &padding_axis);
+if (padding_axis.empty() || shape.size() != padding_axis.size()) {
+return PaddingShapeTo5dDefault(shape);
+}
+std::vector<size_t> shape_5d(kNcdhw, 1);
+for (size_t index = 0; index < padding_axis.size(); index++) {
+shape_5d[padding_axis[index]] = shape[index];
+}
+return shape_5d;
+}
+std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::string &padding_str) {
+std::vector<Axis> padding_axis;
+StringToAxisVector4D(padding_str, &padding_axis);
if (padding_axis.empty() || shape.size() != padding_axis.size()) {
return PaddingShapeTo4dByDefault(shape);
}
@@ -427,6 +512,38 @@ std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std
return shape_4d;
}
+std::vector<size_t> PaddingShapeTo5dDefault(const std::vector<size_t> &shape) {
+if (shape.size() >= kNcdhw) {
+return shape;
+}
+std::vector<size_t> shape_5d(kNcdhw, 1);
+switch (shape.size()) {
+case 0:
+return shape_5d;
+case 1:
+shape_5d[1] = shape[0];
+break;
+case 2:
+shape_5d[1] = shape[0];
+shape_5d[2] = shape[1];
+break;
+case 3:
+shape_5d[1] = shape[0];
+shape_5d[2] = shape[1];
+shape_5d[3] = shape[2];
+break;
+case 4:
+shape_5d[1] = shape[0];
+shape_5d[2] = shape[1];
+shape_5d[3] = shape[2];
+shape_5d[4] = shape[3];
+break;
+default:
+MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size();
+}
+return shape_5d;
+}
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) {
using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>;
const std::map<std::string, DeviceShapeTransfer> device_shape_map{{kOpFormat_NCHW, NchwDeviceShape},
@@ -475,10 +592,13 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
device_shape.push_back(kCubeSize);
return device_shape;
}
-if (shape.size() != kNchwDims && shape.size() != 5) {
+if (shape.size() != kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) {
MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
temp_shape = PaddingShapeTo4dByDefault(shape);
}
+if (shape.size() != kNcdhw && k3DFormatSet.find(format) != k3DFormatSet.end()) {
+temp_shape = PaddingShapeTo5dDefault(shape);
+}
auto iter = device_shape_map.find(format);
if (iter == device_shape_map.end()) {
MS_LOG(EXCEPTION) << "Unexpected format[" << format << "]";

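The worked example below shows the new 5-D padding rule end to end. It is a self-contained, hedged re-statement of PaddingShapeTo5d and StringToAxisVector5D from the hunks above (stand-alone stand-ins, not the shipped code; the real functions raise MS_LOG(EXCEPTION) on unknown letters and fall back to PaddingShapeTo5dDefault where this sketch returns early):

    #include <cstdio>
    #include <string>
    #include <vector>

    enum Axis5D : int { N_ncdhw = 0, C_ncdhw, D_ncdhw, H_ncdhw, W_ncdhw };
    constexpr size_t kNcdhw = 5;

    // Pad `shape` into the NCDHW slots named by `padding_str`, e.g. "CDH".
    std::vector<size_t> PaddingShapeTo5dSketch(const std::vector<size_t> &shape,
                                               const std::string &padding_str) {
      std::vector<int> axes;
      for (char c : padding_str) {
        switch (c) {
          case 'N': axes.push_back(N_ncdhw); break;
          case 'C': axes.push_back(C_ncdhw); break;
          case 'D': axes.push_back(D_ncdhw); break;
          case 'H': axes.push_back(H_ncdhw); break;
          case 'W': axes.push_back(W_ncdhw); break;
          default: return {};  // real code: MS_LOG(EXCEPTION)
        }
      }
      std::vector<size_t> shape_5d(kNcdhw, 1);
      if (axes.empty() || axes.size() != shape.size()) {
        return shape_5d;  // real code: PaddingShapeTo5dDefault(shape)
      }
      for (size_t i = 0; i < axes.size(); ++i) {
        shape_5d[axes[i]] = shape[i];
      }
      return shape_5d;
    }

    int main() {
      // A 3-D tensor tagged "CDH" lands in the C, D and H slots of NCDHW:
      for (size_t d : PaddingShapeTo5dSketch({8, 96, 16}, "CDH")) {
        std::printf("%zu ", d);  // prints: 1 8 96 16 1
      }
      std::printf("\n");
      return 0;
    }
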
@@ -30,6 +30,13 @@
namespace mindspore {
namespace trans {
+enum Axis5D : int {
+N_ncdhw = 0,
+C_ncdhw,
+D_ncdhw,
+H_ncdhw,
+W_ncdhw,
+};
struct TypeIdArgs {
const void *data;
size_t host_shape_size; // Multiply each dimension elements. [a, b, c, d] => a*b*c*d
@@ -50,7 +57,13 @@ struct FormatArgs {
size_t CubeSizeByType(const TypeId data_type);
-std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<Axis> &padding_axis = {});
+std::vector<size_t> PaddingShape(const std::vector<size_t> &shape, const std::string &format,
+const std::string &pad_index = {""});
+std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::string &padding_axis = {""});
+std::vector<size_t> PaddingShapeTo5d(const std::vector<size_t> &shape, const std::string &padding_axis = {""});
+std::vector<size_t> PaddingShapeTo5dDefault(const std::vector<size_t> &shape);
+void StringToAxisVector4D(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec);
+void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5D> *reshape_type_vec);
ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index);
bool IsNeedPadding(const std::string &format, const size_t shape_size);
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format);

@@ -475,7 +475,7 @@ std::vector<size_t> AscendDeviceAddress::GetDeviceShape(std::vector<size_t> *hos
device_shape = trans::TransShapeToDevice(*host_shape, format_);
} else {
if (host_shape_.empty()) {
-*host_shape = trans::PaddingShapeTo4d(*host_shape);
+*host_shape = trans::PaddingShape(*host_shape, format_);
} else {
host_shape->clear();
(void)std::transform(host_shape_.begin(), host_shape_.end(), std::back_inserter(*host_shape), LongToSize);
@@ -595,11 +595,10 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const ShapeVector &sh
host_shape.emplace_back(1);
}
std::vector<size_t> device_shape;
-if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NCDHW || format_ == kOpFormat_NDC1HWC0 ||
-format_ == kOpFormat_FRACTAL_Z_3D) {
+if (format_ == kOpFormat_FRAC_NZ) {
device_shape = trans::TransShapeToDevice(host_shape, format_);
} else {
-host_shape = trans::PaddingShapeTo4d(host_shape);
+host_shape = trans::PaddingShape(host_shape, format_);
device_shape = trans::TransShapeToDevice(host_shape, format_);
}
if (type_id_ != type) {

@@ -68,7 +68,7 @@ size_t KernelRuntime::CountNodeDeviceMemorySize(const mindspore::AnfNodePtr &nod
std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(node, output_index);
auto format = AnfAlgo::GetOutputFormat(node, output_index);
if (shape.empty() && format != kOpFormat_DEFAULT) {
-shape = trans::PaddingShapeTo4d(shape, AnfAlgo::GetOutputReshapeType(node, output_index));
+shape = trans::PaddingShape(shape, format, AnfAlgo::GetOutputReshapeType(node, output_index));
shape = trans::TransShapeToDevice(shape, format);
}
// scalar's output shape is a empty vector

@@ -303,6 +303,7 @@ constexpr auto kAttrFactor = "factor";
constexpr auto kAttrIsRef = "isRef";
constexpr auto kAttrDataShape = "data_shape";
constexpr auto kAttrFormat = "format";
+constexpr auto kAttrReshapeType = "reshape_type";
constexpr auto kAttrAxis = "axis";
constexpr auto kAttrKeepDims = "keep_dims";
constexpr auto kAttrShapeGamma = "shape_gamma";

@@ -285,8 +285,8 @@ class Tensor : public MetaTensor {
DeviceSyncPtr device_address() const { return device_sync_; }
void set_device_address(const DeviceSyncPtr &device_sync) { device_sync_ = device_sync; }
-void set_padding_type(std::vector<Axis> padding_type) { padding_type_ = padding_type; }
-std::vector<Axis> padding_type() const { return padding_type_; }
+void set_padding_type(const std::string padding_type) { padding_type_ = padding_type; }
+std::string padding_type() const { return padding_type_; }
std::string id() const { return id_; }
TypePtr cast_dtype() { return cast_dtype_; }
@@ -366,7 +366,7 @@ class Tensor : public MetaTensor {
bool cache_enable_{false};
std::shared_ptr<Tensor> cache_tensor_ptr_{nullptr};
std::shared_ptr<Tensor> hashmap_tensor_ptr_{nullptr};
-std::vector<Axis> padding_type_;
+std::string padding_type_{""};
TypePtr cast_dtype_{nullptr};
std::shared_ptr<DeviceEvent> device_event_{nullptr};
};

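For tensors, padding_type follows the same convention: the raw axis string is stored, with "" meaning default padding. A hedged usage sketch (the constructor arguments are illustrative):

    auto t = std::make_shared<tensor::Tensor>(kNumberTypeFloat16, ShapeVector{8, 96, 16});
    t->set_padding_type("CDH");
    std::string pad = t->padding_type();  // "CDH", previously a std::vector<Axis>
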
@@ -50,8 +50,8 @@ class TestHWInsertTransOp : public BackendCommon {
KernelBuildInfoBuilder builder;
builder.SetInputsFormat({format, format});
builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
-builder.SetInputsReshapeType({{},{}});
-builder.SetOutputsReshapeType({});
+builder.SetInputsReshapeType({"", ""});
+builder.SetOutputsReshapeType({""});
builder.SetOutputsFormat({format});
builder.SetOutputsDeviceType({kFloat16->type_id()});
add->set_kernel_info(std::make_shared<device::KernelInfo>());
@@ -72,8 +72,8 @@ class TestHWInsertTransOp : public BackendCommon {
EXPECT_NE(ret->input(1)->cast<CNodePtr>()->input(1)->cast<CNodePtr>()->input(1), nullptr);
auto max_pool = ret->input(1)->cast<CNodePtr>()->input(1)->cast<CNodePtr>()->input(1);
KernelBuildInfoBuilder builder;
-builder.SetInputsReshapeType({{}});
-builder.SetOutputsReshapeType({{},{}});
+builder.SetInputsReshapeType({""});
+builder.SetOutputsReshapeType({"", ""});
builder.SetInputsFormat({kOpFormat_DEFAULT});
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({format, format});
@@ -92,8 +92,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
~MockInsertTransOpKernelSelectTrans4Dto5D() override = default;
void SelectKernel(const CNodePtr &cnode) override {
KernelBuildInfoBuilder builder;
-builder.SetInputsReshapeType({{}});
-builder.SetOutputsReshapeType({{}});
+builder.SetInputsReshapeType({""});
+builder.SetOutputsReshapeType({""});
builder.SetInputsFormat({"NCHW"});
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NC1HWC0"});

@@ -53,8 +53,8 @@ class TestHWRemoveInternalOutput : public BackendCommon {
KernelBuildInfoBuilder builder;
builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});
builder.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()});
-builder.SetInputsReshapeType({{}, {}});
-builder.SetOutputsReshapeType({{}});
+builder.SetInputsReshapeType({"", ""});
+builder.SetOutputsReshapeType({""});
builder.SetOutputsFormat({kOpFormat_NC1HWC0});
builder.SetOutputsDeviceType({kFloat16->type_id()});
add->set_kernel_info(std::make_shared<device::KernelInfo>());
@@ -80,8 +80,8 @@ class TestHWRemoveInternalOutput : public BackendCommon {
kg->AddInternalOutput(tuple_getitem1, max_pool, 0, true);
kg->AddInternalOutput(tuple_getitem2, max_pool, 1, true);
KernelBuildInfoBuilder builder;
-builder.SetInputsReshapeType({{}});
-builder.SetOutputsReshapeType({{}, {}});
+builder.SetInputsReshapeType({""});
+builder.SetOutputsReshapeType({"", ""});
builder.SetInputsFormat({kOpFormat_DEFAULT});
builder.SetInputsDeviceType({kFloat32->type_id()});
builder.SetOutputsFormat({kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
@@ -103,8 +103,8 @@ class MockRemoveInternalOutputTransOpKernelSelect : public KernelSelect {
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({kOpFormat_DEFAULT});
builder.SetOutputsDeviceType({kFloat32->type_id()});
-builder.SetInputsReshapeType({{}});
-builder.SetOutputsReshapeType({{}});
+builder.SetInputsReshapeType({""});
+builder.SetOutputsReshapeType({""});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
}
};

@@ -51,8 +51,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetOutputsDeviceType({kFloat16->type_id()});
-builder.SetInputsReshapeType({});
-builder.SetOutputsReshapeType({});
+builder.SetInputsReshapeType({""});
+builder.SetOutputsReshapeType({""});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
} else {
KernelBuildInfoBuilder builder;
@@ -60,8 +60,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetOutputsDeviceType({kFloat16->type_id()});
-builder.SetInputsReshapeType({});
-builder.SetOutputsReshapeType({});
+builder.SetInputsReshapeType({""});
+builder.SetOutputsReshapeType({""});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
}
@@ -79,8 +79,8 @@ class MockTransdataSplitKernelSelect : public KernelSelect {
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NCHW"});
builder.SetOutputsDeviceType({kFloat16->type_id()});
-builder.SetInputsReshapeType({{}});
-builder.SetOutputsReshapeType({{}});
+builder.SetInputsReshapeType({""});
+builder.SetOutputsReshapeType({""});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
} else {
KernelBuildInfoBuilder builder;
@@ -88,8 +88,8 @@ class MockTransdataSplitKernelSelect : public KernelSelect {
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NCHW"});
builder.SetOutputsDeviceType({kFloat16->type_id()});
-builder.SetInputsReshapeType({{}});
-builder.SetOutputsReshapeType({{}});
+builder.SetInputsReshapeType({""});
+builder.SetOutputsReshapeType({""});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
}
}
@@ -125,8 +125,8 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_fraz_nchw) {
builder.SetKernelType(KernelType::TBE_KERNEL);
builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetProcessor(kernel::Processor::AICORE);
-builder.SetInputsReshapeType({{}});
-builder.SetOutputsReshapeType({{}});
+builder.SetInputsReshapeType({""});
+builder.SetOutputsReshapeType({""});
auto kernel_info = std::make_shared<device::KernelInfo>();
kernel_info->set_select_kernel_build_info(builder.Build());
transpose->set_kernel_info(kernel_info);
@@ -173,8 +173,8 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_nchw_fraz) {
builder.SetKernelType(KernelType::TBE_KERNEL);
builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetProcessor(kernel::Processor::AICORE);
-builder.SetInputsReshapeType({{}});
-builder.SetOutputsReshapeType({{}});
+builder.SetInputsReshapeType({""});
+builder.SetOutputsReshapeType({""});
auto kernel_info = std::make_shared<device::KernelInfo>();
kernel_info->set_select_kernel_build_info(builder.Build());
transpose->set_kernel_info(kernel_info);

@@ -58,8 +58,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetOutputsDeviceType({kFloat16->type_id()});
-builder.SetInputsReshapeType({});
-builder.SetOutputsReshapeType({});
+builder.SetInputsReshapeType({""});
+builder.SetOutputsReshapeType({""});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
} else {
KernelBuildInfoBuilder builder;
@@ -67,8 +67,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetOutputsDeviceType({kFloat16->type_id()});
-builder.SetInputsReshapeType({});
-builder.SetOutputsReshapeType({});
+builder.SetInputsReshapeType({""});
+builder.SetOutputsReshapeType({""});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
}
}
@@ -97,8 +97,8 @@ TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) {
EXPECT_NE(transpose, nullptr);
KernelBuildInfoBuilder builder;
-builder.SetInputsReshapeType({});
-builder.SetOutputsReshapeType({});
+builder.SetInputsReshapeType({""});
+builder.SetOutputsReshapeType({""});
builder.SetInputsFormat({"NCHW"});
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NC1HWC0"});

@@ -56,8 +56,8 @@ class MockEliminate5To4And4To5KernelSelect : public KernelSelect {
~MockEliminate5To4And4To5KernelSelect() override = default;
void SelectKernel(const CNodePtr &cnode) override {
KernelBuildInfoBuilder builder;
-builder.SetInputsReshapeType({{}});
-builder.SetOutputsReshapeType({{}});
+builder.SetInputsReshapeType({""});
+builder.SetOutputsReshapeType({""});
builder.SetInputsFormat({"NCHW"});
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NC1HWC0"});
@@ -104,8 +104,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_5to4_4to5) {
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
builder.SetOutputsDeviceType({kFloat16->type_id()});
-builder.SetInputsReshapeType({{}, {}});
-builder.SetOutputsReshapeType({{}});
+builder.SetInputsReshapeType({"", ""});
+builder.SetOutputsReshapeType({""});
sub->set_kernel_info(std::make_shared<device::KernelInfo>());
add->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get());
@@ -171,8 +171,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_cast) {
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
builder.SetOutputsDeviceType({kFloat16->type_id()});
-builder.SetInputsReshapeType({{}, {}});
-builder.SetOutputsReshapeType({{}, {}});
+builder.SetInputsReshapeType({"", ""});
+builder.SetOutputsReshapeType({"", ""});
sub->set_kernel_info(std::make_shared<device::KernelInfo>());
add->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get());
@@ -248,8 +248,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_cast_depend_cast) {
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
builder.SetOutputsDeviceType({kFloat16->type_id()});
-builder.SetInputsReshapeType({{}, {}});
-builder.SetOutputsReshapeType({{}});
+builder.SetInputsReshapeType({"", ""});
+builder.SetOutputsReshapeType({""});
sub->set_kernel_info(std::make_shared<device::KernelInfo>());
add->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get());
