|
|
|
@ -21,6 +21,7 @@
|
|
|
|
|
#include "abstract/utils.h"
|
|
|
|
|
#include "backend/session/anf_runtime_algorithm.h"
|
|
|
|
|
#include "backend/kernel_compiler/kernel.h"
|
|
|
|
|
#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
|
|
|
|
|
#include "runtime/device/convert_tensor_utils.h"
|
|
|
|
|
#include "utils/convert_utils.h"
|
|
|
|
|
#include "utils/log_adapter.h"
|
|
|
|
@ -28,7 +29,7 @@
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace trans {
|
|
|
|
|
enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims, kNdhwc };
|
|
|
|
|
enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims, kNcdhw };
|
|
|
|
|
inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx, const FormatArgs &args, void *result) {
|
|
|
|
|
switch (size) {
|
|
|
|
|
case 1:
|
|
|
|
@ -343,7 +344,7 @@ std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> NcdhwDeviceShape(const std::vector<size_t> &shape) {
|
|
|
|
|
if (shape.size() < kNdhwc) {
|
|
|
|
|
if (shape.size() < kNcdhw) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc.";
|
|
|
|
|
}
|
|
|
|
|
return shape;
|
|
|
|
@ -388,6 +389,20 @@ bool IsNeedPadding(const std::string &format, const size_t shape_size) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> PaddingShape(const std::vector<size_t> &shape, const std::string &format,
|
|
|
|
|
const std::string &pad_index) {
|
|
|
|
|
std::vector<size_t> host_shape;
|
|
|
|
|
if (k3DFormatSet.find(format) != k3DFormatSet.end()) {
|
|
|
|
|
if (shape.size() >= kNcdhw) {
|
|
|
|
|
return shape;
|
|
|
|
|
}
|
|
|
|
|
host_shape = trans::PaddingShapeTo5d(shape, pad_index);
|
|
|
|
|
} else {
|
|
|
|
|
host_shape = trans::PaddingShapeTo4d(shape, pad_index);
|
|
|
|
|
}
|
|
|
|
|
return host_shape;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
ShapeVector shape;
|
|
|
|
@ -409,14 +424,84 @@ ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
|
|
|
|
|
} else {
|
|
|
|
|
host_shape = AnfAlgo::GetOutputInferShape(node, index);
|
|
|
|
|
}
|
|
|
|
|
if (trans::IsNeedPadding(AnfAlgo::GetOutputFormat(node, index), host_shape.size())) {
|
|
|
|
|
host_shape = trans::PaddingShapeTo4d(host_shape, AnfAlgo::GetOutputReshapeType(node, index));
|
|
|
|
|
auto format = AnfAlgo::GetOutputFormat(node, index);
|
|
|
|
|
if (trans::IsNeedPadding(format, host_shape.size())) {
|
|
|
|
|
host_shape = trans::PaddingShape(host_shape, format, AnfAlgo::GetOutputReshapeType(node, index));
|
|
|
|
|
}
|
|
|
|
|
std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToLong);
|
|
|
|
|
return shape;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<Axis> &padding_axis) {
|
|
|
|
|
void StringToAxisVector4D(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(reshape_type_vec);
|
|
|
|
|
if (reshape_type_str.empty()) {
|
|
|
|
|
MS_LOG(DEBUG) << "Reshape type str is empty, no need padding.";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
for (const auto &c : reshape_type_str) {
|
|
|
|
|
switch (c) {
|
|
|
|
|
case 'N':
|
|
|
|
|
reshape_type_vec->push_back(N);
|
|
|
|
|
break;
|
|
|
|
|
case 'C':
|
|
|
|
|
reshape_type_vec->push_back(C);
|
|
|
|
|
break;
|
|
|
|
|
case 'H':
|
|
|
|
|
reshape_type_vec->push_back(H);
|
|
|
|
|
break;
|
|
|
|
|
case 'W':
|
|
|
|
|
reshape_type_vec->push_back(W);
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
MS_LOG(EXCEPTION) << "Unknown axis " << c << "in reshape type.";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5D> *reshape_type_vec) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(reshape_type_vec);
|
|
|
|
|
if (reshape_type_str.empty()) {
|
|
|
|
|
MS_LOG(DEBUG) << "Reshape type str is empty, no need padding.";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
for (const auto &c : reshape_type_str) {
|
|
|
|
|
switch (c) {
|
|
|
|
|
case 'N':
|
|
|
|
|
reshape_type_vec->push_back(N_ncdhw);
|
|
|
|
|
break;
|
|
|
|
|
case 'C':
|
|
|
|
|
reshape_type_vec->push_back(C_ncdhw);
|
|
|
|
|
break;
|
|
|
|
|
case 'D':
|
|
|
|
|
reshape_type_vec->push_back(D_ncdhw);
|
|
|
|
|
break;
|
|
|
|
|
case 'H':
|
|
|
|
|
reshape_type_vec->push_back(H_ncdhw);
|
|
|
|
|
break;
|
|
|
|
|
case 'W':
|
|
|
|
|
reshape_type_vec->push_back(W_ncdhw);
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
MS_LOG(EXCEPTION) << "Unknown axis " << c << "in reshape type.";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> PaddingShapeTo5d(const std::vector<size_t> &shape, const std::string &padding_str) {
|
|
|
|
|
std::vector<Axis5D> padding_axis;
|
|
|
|
|
StringToAxisVector5D(padding_str, &padding_axis);
|
|
|
|
|
if (padding_axis.empty() || shape.size() != padding_axis.size()) {
|
|
|
|
|
return PaddingShapeTo5dDefault(shape);
|
|
|
|
|
}
|
|
|
|
|
std::vector<size_t> shape_5d(kNcdhw, 1);
|
|
|
|
|
for (size_t index = 0; index < padding_axis.size(); index++) {
|
|
|
|
|
shape_5d[padding_axis[index]] = shape[index];
|
|
|
|
|
}
|
|
|
|
|
return shape_5d;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::string &padding_str) {
|
|
|
|
|
std::vector<Axis> padding_axis;
|
|
|
|
|
StringToAxisVector4D(padding_str, &padding_axis);
|
|
|
|
|
if (padding_axis.empty() || shape.size() != padding_axis.size()) {
|
|
|
|
|
return PaddingShapeTo4dByDefault(shape);
|
|
|
|
|
}
|
|
|
|
@ -427,6 +512,38 @@ std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std
|
|
|
|
|
return shape_4d;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> PaddingShapeTo5dDefault(const std::vector<size_t> &shape) {
|
|
|
|
|
if (shape.size() >= kNcdhw) {
|
|
|
|
|
return shape;
|
|
|
|
|
}
|
|
|
|
|
std::vector<size_t> shape_5d(kNcdhw, 1);
|
|
|
|
|
switch (shape.size()) {
|
|
|
|
|
case 0:
|
|
|
|
|
return shape_5d;
|
|
|
|
|
case 1:
|
|
|
|
|
shape_5d[1] = shape[0];
|
|
|
|
|
break;
|
|
|
|
|
case 2:
|
|
|
|
|
shape_5d[1] = shape[0];
|
|
|
|
|
shape_5d[2] = shape[1];
|
|
|
|
|
break;
|
|
|
|
|
case 3:
|
|
|
|
|
shape_5d[1] = shape[0];
|
|
|
|
|
shape_5d[2] = shape[1];
|
|
|
|
|
shape_5d[3] = shape[2];
|
|
|
|
|
break;
|
|
|
|
|
case 4:
|
|
|
|
|
shape_5d[1] = shape[0];
|
|
|
|
|
shape_5d[2] = shape[1];
|
|
|
|
|
shape_5d[3] = shape[2];
|
|
|
|
|
shape_5d[4] = shape[3];
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size();
|
|
|
|
|
}
|
|
|
|
|
return shape_5d;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) {
|
|
|
|
|
using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>;
|
|
|
|
|
const std::map<std::string, DeviceShapeTransfer> device_shape_map{{kOpFormat_NCHW, NchwDeviceShape},
|
|
|
|
@ -475,10 +592,13 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
|
|
|
|
|
device_shape.push_back(kCubeSize);
|
|
|
|
|
return device_shape;
|
|
|
|
|
}
|
|
|
|
|
if (shape.size() != kNchwDims && shape.size() != 5) {
|
|
|
|
|
if (shape.size() != kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) {
|
|
|
|
|
MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
|
|
|
|
|
temp_shape = PaddingShapeTo4dByDefault(shape);
|
|
|
|
|
}
|
|
|
|
|
if (shape.size() != kNcdhw && k3DFormatSet.find(format) != k3DFormatSet.end()) {
|
|
|
|
|
temp_shape = PaddingShapeTo5dDefault(shape);
|
|
|
|
|
}
|
|
|
|
|
auto iter = device_shape_map.find(format);
|
|
|
|
|
if (iter == device_shape_map.end()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Unexpected format[" << format << "]";
|
|
|
|
|