|
|
@ -36,7 +36,6 @@ class PreparedOp {
|
|
|
|
static PreparedOp Prepare(const framework::RuntimeContext& ctx,
|
|
|
|
static PreparedOp Prepare(const framework::RuntimeContext& ctx,
|
|
|
|
const framework::OperatorWithKernel& op,
|
|
|
|
const framework::OperatorWithKernel& op,
|
|
|
|
const platform::Place& place) {
|
|
|
|
const platform::Place& place) {
|
|
|
|
framework::Scope dummy_scope;
|
|
|
|
|
|
|
|
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
|
|
|
|
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
|
|
|
|
auto* dev_ctx = pool.Get(place);
|
|
|
|
auto* dev_ctx = pool.Get(place);
|
|
|
|
|
|
|
|
|
|
|
@ -52,7 +51,7 @@ class PreparedOp {
|
|
|
|
framework::OperatorWithKernel::OpKernelMap& kernels = kernels_iter->second;
|
|
|
|
framework::OperatorWithKernel::OpKernelMap& kernels = kernels_iter->second;
|
|
|
|
|
|
|
|
|
|
|
|
auto expected_kernel_key = op.GetExpectedKernelType(
|
|
|
|
auto expected_kernel_key = op.GetExpectedKernelType(
|
|
|
|
framework::ExecutionContext(op, dummy_scope, *dev_ctx, ctx));
|
|
|
|
framework::ExecutionContext(op, framework::Scope(), *dev_ctx, ctx));
|
|
|
|
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
|
|
|
|
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
|
|
|
|
|
|
|
|
|
|
|
|
auto kernel_iter = kernels.find(expected_kernel_key);
|
|
|
|
auto kernel_iter = kernels.find(expected_kernel_key);
|
|
|
|