!8157 Gpu support flexiable inputs kernel

Merge pull request !8157 from chenweifeng/gpu-variable-kernel-input
pull/8157/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 26165f863a

@ -36,19 +36,20 @@ void GpuKernelFactory::Register(const std::string &kernel_name, const KernelAttr
map_kernel_name_to_creater_[kernel_name].emplace_back(kernel_attr, creater);
}
void GpuKernelFactory::CheckIOParam(const std::string &kernel_name, const KernelBuildInfo *kernel_info,
bool GpuKernelFactory::CheckIOParam(const std::string &kernel_name, const KernelBuildInfo *kernel_info,
std::vector<std::pair<KernelAttr, GpuKernelCreater>> *iter_second,
size_t attr_index) {
if (kernel_info->GetInputNum() != iter_second->at(attr_index).first.GetInputSize()) {
if (!iter_second->at(attr_index).first.GetAllSame()) {
MS_LOG(EXCEPTION) << "op[" << kernel_name << "] Input size is mismatching!";
return false;
}
}
if (kernel_info->GetOutputNum() != iter_second->at(attr_index).first.GetOutputSize()) {
if (!iter_second->at(attr_index).first.GetAllSame()) {
MS_LOG(EXCEPTION) << "op[" << kernel_name << "] Output size is mismatching!";
return false;
}
}
return true;
}
std::string GpuKernelFactory::SupportedTypeList(const std::string &kernel_name) {
@ -119,7 +120,9 @@ std::pair<bool, size_t> GpuKernelFactory::GpuKernelAttrCheck(const std::string &
}
for (size_t attr_index = 0; attr_index < (iter->second).size(); ++attr_index) {
CheckIOParam(kernel_name, kernel_info, &(iter->second), attr_index);
if (!CheckIOParam(kernel_name, kernel_info, &(iter->second), attr_index)) {
continue;
}
bool flag = true;
auto attr_size = (&(iter->second))->at(attr_index).first.GetInputSize();
// data type matching check of all input parameters of kernel

@ -57,7 +57,7 @@ class GpuKernelFactory {
GpuKernelFactory &operator=(const GpuKernelFactory &);
std::pair<bool, size_t> GpuKernelAttrCheck(const std::string &kernel_name, const KernelBuildInfo *kernel_info);
void CheckIOParam(const std::string &kernel_name, const KernelBuildInfo *kernel_info,
bool CheckIOParam(const std::string &kernel_name, const KernelBuildInfo *kernel_info,
std::vector<std::pair<KernelAttr, GpuKernelCreater>> *iter_second, size_t attr_index);
// map to maintain kernel and creater, KernelAttr object and creater must be registered as a pair.
std::map<std::string, std::vector<std::pair<KernelAttr, GpuKernelCreater>>> map_kernel_name_to_creater_;

Loading…
Cancel
Save