|
|
@ -485,9 +485,15 @@ void OperatorWithKernel::Run(const Scope& scope,
|
|
|
|
// }
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
|
|
auto expected_kernel_key = this->GetExpectedKernelType(ctx);
|
|
|
|
auto expected_kernel_key = this->GetExpectedKernelType(ctx);
|
|
|
|
|
|
|
|
|
|
|
|
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
|
|
|
|
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto kernel_iter = kernels.find(expected_kernel_key);
|
|
|
|
|
|
|
|
if (kernel_iter == kernels.end()) {
|
|
|
|
|
|
|
|
PADDLE_THROW("op %s does not have kernel for %s", type_,
|
|
|
|
|
|
|
|
KernelTypeToString(expected_kernel_key));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// do data transform
|
|
|
|
Scope& new_scope = scope.NewScope();
|
|
|
|
Scope& new_scope = scope.NewScope();
|
|
|
|
|
|
|
|
|
|
|
|
for (auto& var_name_item : this->Inputs()) {
|
|
|
|
for (auto& var_name_item : this->Inputs()) {
|
|
|
@ -520,8 +526,6 @@ void OperatorWithKernel::Run(const Scope& scope,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
auto kernel_iter = kernels.find(expected_kernel_key);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto* new_dev_ctx = pool.Get(expected_kernel_key.place_);
|
|
|
|
auto* new_dev_ctx = pool.Get(expected_kernel_key.place_);
|
|
|
|
kernel_iter->second->Compute(
|
|
|
|
kernel_iter->second->Compute(
|
|
|
|
ExecutionContext(*this, new_scope, *new_dev_ctx));
|
|
|
|
ExecutionContext(*this, new_scope, *new_dev_ctx));
|
|
|
|