|
|
|
@ -321,9 +321,11 @@ void ReplaceByDynamicFormatDtype(const CNodePtr &kernel_node, const std::shared_
|
|
|
|
|
MS_LOG(INFO) << "Dynamic select format response successful, use dynamic format.";
|
|
|
|
|
for (size_t i = 0; i < inputs_static.size(); i++) {
|
|
|
|
|
inputs_dyn[i]->set_param_type(inputs_static[i]->param_type());
|
|
|
|
|
inputs_dyn[i]->set_reshape_type(inputs_static[i]->reshape_type());
|
|
|
|
|
}
|
|
|
|
|
for (size_t j = 0; j < outputs_static.size(); j++) {
|
|
|
|
|
outputs_dyn[j]->set_param_type(outputs_static[j]->param_type());
|
|
|
|
|
outputs_dyn[j]->set_reshape_type(outputs_static[j]->reshape_type());
|
|
|
|
|
}
|
|
|
|
|
op_info_new_ptr->set_inputs_ptr(inputs_dyn);
|
|
|
|
|
op_info_new_ptr->set_outputs_ptr(outputs_dyn);
|
|
|
|
@ -335,6 +337,29 @@ void ReplaceByDynamicFormatDtype(const CNodePtr &kernel_node, const std::shared_
|
|
|
|
|
op_info_new_ptr->set_fusion_type(op_info_ptr->fusion_type());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool StringToAxisVector(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec) {
|
|
|
|
|
for (const auto &c : reshape_type_str) {
|
|
|
|
|
switch (c) {
|
|
|
|
|
case 'N':
|
|
|
|
|
reshape_type_vec->push_back(kernel::N);
|
|
|
|
|
break;
|
|
|
|
|
case 'C':
|
|
|
|
|
reshape_type_vec->push_back(kernel::C);
|
|
|
|
|
break;
|
|
|
|
|
case 'H':
|
|
|
|
|
reshape_type_vec->push_back(kernel::H);
|
|
|
|
|
break;
|
|
|
|
|
case 'W':
|
|
|
|
|
reshape_type_vec->push_back(kernel::W);
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
MS_LOG(ERROR) << "Unknown axis " << c << "in reshape type.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inputs, size_t real_input_num,
|
|
|
|
|
size_t builder_idex, const std::vector<int> &dyn_input_sizes,
|
|
|
|
|
const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
|
|
|
|
@ -347,6 +372,7 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp
|
|
|
|
|
MS_EXCEPTION_IF_NULL(inputs[0]);
|
|
|
|
|
size_t kernel_info_cnt = inputs[0]->dtypes().size();
|
|
|
|
|
|
|
|
|
|
std::vector<std::vector<Axis>> reshape_types;
|
|
|
|
|
for (const auto &input : inputs) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input);
|
|
|
|
|
std::string param_type = input->param_type();
|
|
|
|
@ -384,8 +410,14 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp
|
|
|
|
|
inputs_format.push_back(formats[builder_idex]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
std::vector<Axis> reshape_type;
|
|
|
|
|
if (!StringToAxisVector(input->reshape_type(), &reshape_type)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
reshape_types.push_back(reshape_type);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
builder->SetInputReshapeType(reshape_types);
|
|
|
|
|
builder->SetInputsDeviceType(inputs_device_type);
|
|
|
|
|
builder->SetInputsFormat(inputs_format);
|
|
|
|
|
return true;
|
|
|
|
@ -403,6 +435,7 @@ bool SetKernelBuilderOutputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &ou
|
|
|
|
|
MS_EXCEPTION_IF_NULL(outputs[0]);
|
|
|
|
|
size_t kernel_info_cnt = outputs[0]->dtypes().size();
|
|
|
|
|
|
|
|
|
|
std::vector<std::vector<Axis>> reshape_types;
|
|
|
|
|
for (const auto &output : outputs) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(output);
|
|
|
|
|
if (output_idx >= real_output_num) {
|
|
|
|
@ -436,8 +469,14 @@ bool SetKernelBuilderOutputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &ou
|
|
|
|
|
outputs_format.push_back(formats[builder_idex]);
|
|
|
|
|
output_idx++;
|
|
|
|
|
}
|
|
|
|
|
std::vector<Axis> reshape_type;
|
|
|
|
|
if (!StringToAxisVector(output->reshape_type(), &reshape_type)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
reshape_types.push_back(reshape_type);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
builder->SetOutputReshapeType(reshape_types);
|
|
|
|
|
builder->SetOutputsFormat(outputs_format);
|
|
|
|
|
builder->SetOutputsDeviceType(outputs_device_type);
|
|
|
|
|
return true;
|
|
|
|
@ -515,7 +554,7 @@ bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &for
|
|
|
|
|
const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND,
|
|
|
|
|
kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN,
|
|
|
|
|
kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0,
|
|
|
|
|
kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04};
|
|
|
|
|
kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04};
|
|
|
|
|
|
|
|
|
|
// if format is default, it remarkes support all format
|
|
|
|
|
if (kOpFormatList.find(format) == kOpFormatList.end()) {
|
|
|
|
@ -528,13 +567,13 @@ bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &for
|
|
|
|
|
if (shape.empty()) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
if (shape.size() > kShapeSupportFormatMap.size()) {
|
|
|
|
|
if (shape.size() > kShape4dDims) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
if (format == kOpFormat_FRAC_NZ && shape.size() >= 2) {
|
|
|
|
|
return true;
|
|
|
|
|
if (format == kOpFormat_FRAC_NZ && shape.size() < 2) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
return !(kShapeSupportFormatMap[shape.size() - 1].find(format) == kShapeSupportFormatMap[shape.size() - 1].end());
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) {
|
|
|
|
|