|
|
|
@ -383,6 +383,11 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<Axis> reshape_type;
|
|
|
|
|
if (!StringToAxisVector(input->reshape_type(), &reshape_type)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (param_type == "dynamic") {
|
|
|
|
|
if (dyn_input_sizes.empty()) {
|
|
|
|
|
MS_LOG(ERROR) << "Set input kernel builder info, dyn_input_sizes's size is 0 when param_type is dynamic";
|
|
|
|
@ -394,6 +399,7 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp
|
|
|
|
|
auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
|
|
|
|
|
inputs_device_type.push_back(type_id);
|
|
|
|
|
inputs_format.push_back(formats[builder_idex]);
|
|
|
|
|
reshape_types.push_back(reshape_type);
|
|
|
|
|
}
|
|
|
|
|
dyn_input_idx++;
|
|
|
|
|
} else if (param_type == "required") {
|
|
|
|
@ -401,6 +407,7 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp
|
|
|
|
|
auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
|
|
|
|
|
inputs_device_type.push_back(type_id);
|
|
|
|
|
inputs_format.push_back(formats[builder_idex]);
|
|
|
|
|
reshape_types.push_back(reshape_type);
|
|
|
|
|
} else {
|
|
|
|
|
if (kernel_info_index < real_input_num) {
|
|
|
|
|
MS_LOG(INFO) << "Set input kernel builder info, input type is optional, input index is " << kernel_info_index;
|
|
|
|
@ -408,13 +415,9 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp
|
|
|
|
|
auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
|
|
|
|
|
inputs_device_type.push_back(type_id);
|
|
|
|
|
inputs_format.push_back(formats[builder_idex]);
|
|
|
|
|
reshape_types.push_back(reshape_type);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
std::vector<Axis> reshape_type;
|
|
|
|
|
if (!StringToAxisVector(input->reshape_type(), &reshape_type)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
reshape_types.push_back(reshape_type);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
builder->SetInputReshapeType(reshape_types);
|
|
|
|
@ -442,6 +445,11 @@ bool SetKernelBuilderOutputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &ou
|
|
|
|
|
MS_LOG(WARNING) << "real_output_num: " << real_output_num << ", output_idx: " << output_idx << "is out of limit!";
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
std::vector<Axis> reshape_type;
|
|
|
|
|
if (!StringToAxisVector(output->reshape_type(), &reshape_type)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t output_num = 0;
|
|
|
|
|
if (output->param_type() == "dynamic") {
|
|
|
|
|
if (outputs.size() > 1) {
|
|
|
|
@ -467,13 +475,9 @@ bool SetKernelBuilderOutputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &ou
|
|
|
|
|
auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
|
|
|
|
|
outputs_device_type.push_back(type_id);
|
|
|
|
|
outputs_format.push_back(formats[builder_idex]);
|
|
|
|
|
reshape_types.push_back(reshape_type);
|
|
|
|
|
output_idx++;
|
|
|
|
|
}
|
|
|
|
|
std::vector<Axis> reshape_type;
|
|
|
|
|
if (!StringToAxisVector(output->reshape_type(), &reshape_type)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
reshape_types.push_back(reshape_type);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
builder->SetOutputReshapeType(reshape_types);
|
|
|
|
|