!1845 convert parameter & valuenode's device type

Merge pull request !1845 from lianliguang/master
pull/1845/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 10ebd81b10

@ -176,7 +176,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) { if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)}; std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
builder->SetOutputsFormat(output_format); builder->SetOutputsFormat(output_format);
std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(real_input_node, 0)}; std::vector<TypeId> output_type = {AnfAlgo::GetInputDeviceDataType(kernel_node, input_index)};
builder->SetOutputsDeviceType(output_type); builder->SetOutputsDeviceType(output_type);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get()); AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
} }

@ -298,7 +298,10 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod
auto cur_input = AnfAlgo::GetInputNode(cnode, input_index); auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);
auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0); auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0);
auto is_weight_boundary = [](const AnfNodePtr &node) -> bool { auto is_weight_boundary = [](const AnfNodePtr &node) -> bool {
if (node->isa<ValueNode>() || node->isa<Parameter>()) { if (node->isa<ValueNode>()) {
return true;
}
if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
return true; return true;
} }
return false; return false;

Loading…
Cancel
Save