|
|
|
@ -160,26 +160,13 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
|
|
|
|
|
if (real_input_node->isa<CNode>()) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder =
|
|
|
|
|
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
|
|
|
|
// we set special device info of a input tensor.
|
|
|
|
|
bool is_ref = false;
|
|
|
|
|
auto op_info = mindspore::kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE);
|
|
|
|
|
if (op_info != nullptr) {
|
|
|
|
|
is_ref = op_info->is_ref();
|
|
|
|
|
}
|
|
|
|
|
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
|
|
|
|
|
if (MsContext::GetInstance()->execution_mode() == kPynativeMode &&
|
|
|
|
|
AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
|
|
|
|
|
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown) {
|
|
|
|
|
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
|
|
|
|
|
builder->SetOutputsFormat(output_format);
|
|
|
|
|
std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
|
|
|
|
|
std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(real_input_node, 0)};
|
|
|
|
|
builder->SetOutputsDeviceType(output_type);
|
|
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
|
|
|
|
|
}
|
|
|
|
|