|
|
|
|
@ -36,19 +36,20 @@ void GpuKernelFactory::Register(const std::string &kernel_name, const KernelAttr
|
|
|
|
|
map_kernel_name_to_creater_[kernel_name].emplace_back(kernel_attr, creater);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void GpuKernelFactory::CheckIOParam(const std::string &kernel_name, const KernelBuildInfo *kernel_info,
|
|
|
|
|
bool GpuKernelFactory::CheckIOParam(const std::string &kernel_name, const KernelBuildInfo *kernel_info,
|
|
|
|
|
std::vector<std::pair<KernelAttr, GpuKernelCreater>> *iter_second,
|
|
|
|
|
size_t attr_index) {
|
|
|
|
|
if (kernel_info->GetInputNum() != iter_second->at(attr_index).first.GetInputSize()) {
|
|
|
|
|
if (!iter_second->at(attr_index).first.GetAllSame()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "op[" << kernel_name << "] Input size is mismatching!";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (kernel_info->GetOutputNum() != iter_second->at(attr_index).first.GetOutputSize()) {
|
|
|
|
|
if (!iter_second->at(attr_index).first.GetAllSame()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "op[" << kernel_name << "] Output size is mismatching!";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string GpuKernelFactory::SupportedTypeList(const std::string &kernel_name) {
|
|
|
|
|
@ -119,7 +120,9 @@ std::pair<bool, size_t> GpuKernelFactory::GpuKernelAttrCheck(const std::string &
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (size_t attr_index = 0; attr_index < (iter->second).size(); ++attr_index) {
|
|
|
|
|
CheckIOParam(kernel_name, kernel_info, &(iter->second), attr_index);
|
|
|
|
|
if (!CheckIOParam(kernel_name, kernel_info, &(iter->second), attr_index)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
bool flag = true;
|
|
|
|
|
auto attr_size = (&(iter->second))->at(attr_index).first.GetInputSize();
|
|
|
|
|
// data type matching check of all input parameters of kernel
|
|
|
|
|
|