diff --git a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc index 1e1daf6245..5b0f02bed1 100644 --- a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc +++ b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc @@ -160,26 +160,13 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co if (real_input_node->isa()) { continue; } - if (real_input_node->isa() && !AnfAlgo::IsParameterWeight(real_input_node->cast())) { - continue; - } std::shared_ptr builder = std::make_shared(); // we set special device info of a input tensor. - bool is_ref = false; - auto op_info = mindspore::kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE); - if (op_info != nullptr) { - is_ref = op_info->is_ref(); - } - MS_EXCEPTION_IF_NULL(MsContext::GetInstance()); - if (MsContext::GetInstance()->execution_mode() == kPynativeMode && - AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) { - continue; - } - if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) { + if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown) { std::vector output_format = {selected_kernel_info.GetInputFormat(input_index)}; builder->SetOutputsFormat(output_format); - std::vector output_type = {selected_kernel_info.GetInputDeviceType(input_index)}; + std::vector output_type = {AnfAlgo::GetOutputInferDataType(real_input_node, 0)}; builder->SetOutputsDeviceType(output_type); AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get()); } diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc index f95882994e..b573cb33bb 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc @@ -298,9 +298,7 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod auto cur_input = AnfAlgo::GetInputNode(cnode, input_index); auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0); auto is_weight_boundary = [](const AnfNodePtr &node) -> bool { - if (node->isa()) { - return true; - } else if (node->isa() && AnfAlgo::IsParameterWeight(node->cast())) { + if (node->isa() || node->isa()) { return true; } return false;