|
|
|
@ -167,7 +167,8 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
|
|
|
|
|
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type) {
|
|
|
|
|
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type,
|
|
|
|
|
const TypeId &type_id) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(trans_data);
|
|
|
|
|
auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(ori_build_info);
|
|
|
|
@ -176,6 +177,10 @@ void RefreshKernelBuildInfo(const std::string &input_format, const std::string &
|
|
|
|
|
builder->SetInputReshapeType({reshape_type});
|
|
|
|
|
builder->SetOutputReshapeType({reshape_type});
|
|
|
|
|
builder->SetOutputsFormat({output_format});
|
|
|
|
|
if (type_id != kTypeUnknown) {
|
|
|
|
|
builder->SetOutputsDeviceType({type_id});
|
|
|
|
|
builder->SetInputsDeviceType({type_id});
|
|
|
|
|
}
|
|
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), trans_data.get());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|