From 1a6f62bd252e094cb8f0de972114c06d3720524a Mon Sep 17 00:00:00 2001 From: VectorSL Date: Mon, 27 Apr 2020 16:09:00 +0800 Subject: [PATCH] gpu update type check --- mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.cc | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.cc b/mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.cc index e38cc02e23..b00b5c263d 100644 --- a/mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.cc +++ b/mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.cc @@ -68,13 +68,18 @@ std::string GpuKernelFactory::SupportedTypeList(const std::string &kernel_name) return type_lists; } for (size_t attr_index = 0; attr_index < (iter->second).size(); ++attr_index) { - std::string type_list = "["; + std::string type_list = "in["; auto attr = (iter->second)[attr_index].first; for (size_t input_index = 0; input_index < attr.GetInputSize(); ++input_index) { type_list = type_list + TypeId2String(attr.GetInputAttr(input_index).first) + ((input_index == (attr.GetInputSize() - 1)) ? "" : " "); } - type_lists = type_lists + type_list + "] "; + type_list = type_list + "], out["; + for (size_t input_index = 0; input_index < attr.GetOutputSize(); ++input_index) { + type_list = type_list + TypeId2String(attr.GetOutputAttr(input_index).first) + + ((input_index == (attr.GetOutputSize() - 1)) ? "" : " "); + } + type_lists = type_lists + type_list + "]; "; } return type_lists; }