|
|
|
@ -176,7 +176,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
|
|
|
|
|
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
|
|
|
|
|
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
|
|
|
|
|
builder->SetOutputsFormat(output_format);
|
|
|
|
|
std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(real_input_node, 0)};
|
|
|
|
|
std::vector<TypeId> output_type = {AnfAlgo::GetInputDeviceDataType(kernel_node, input_index)};
|
|
|
|
|
builder->SetOutputsDeviceType(output_type);
|
|
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
|
|
|
|
|
}
|
|
|
|
|