|
|
|
@ -474,6 +474,20 @@ void OperatorWithKernel::Run(const Scope& scope,
|
|
|
|
|
ExecutionContext ctx(*this, scope, *dev_ctx);
|
|
|
|
|
auto expected_kernel_key = this->GetExpectedKernelType(ctx);
|
|
|
|
|
|
|
|
|
|
OpKernelMap& kernels = kernels_iter->second;
|
|
|
|
|
|
|
|
|
|
for (auto& candidate : kKernelPriority) {
|
|
|
|
|
auto candidate_key =
|
|
|
|
|
OpKernelType(expected_kernel_key.data_type_, std::get<0>(candidate),
|
|
|
|
|
expected_kernel_key.data_layout_, std::get<1>(candidate));
|
|
|
|
|
|
|
|
|
|
if ((candidate_key == expected_kernel_key) ||
|
|
|
|
|
(kernels.count(candidate_key))) {
|
|
|
|
|
expected_kernel_key = candidate_key;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Scope& new_scope = scope.NewScope();
|
|
|
|
|
|
|
|
|
|
for (auto& var_name_item : this->Inputs()) {
|
|
|
|
@ -504,7 +518,6 @@ void OperatorWithKernel::Run(const Scope& scope,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
OpKernelMap& kernels = kernels_iter->second;
|
|
|
|
|
auto kernel_iter = kernels.find(expected_kernel_key);
|
|
|
|
|
|
|
|
|
|
kernel_iter->second->Compute(ExecutionContext(*this, new_scope, *dev_ctx));
|
|
|
|
|