|
|
|
@ -136,6 +136,12 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
|
|
|
|
|
platform::SetDeviceId(dev_id);
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
|
|
|
|
|
auto* dev_ctx = pool.Get(place);
|
|
|
|
|
|
|
|
|
|
// For profiling, don't move out of this function because that will result
|
|
|
|
|
// in the failure of multi-GPU profiling.
|
|
|
|
|
platform::RecordEvent record_event(Type(), dev_ctx);
|
|
|
|
|
RunImpl(scope, place);
|
|
|
|
|
VLOG(10) << "+ " << DebugStringEx(&scope);
|
|
|
|
|
}
|
|
|
|
@ -639,9 +645,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
|
|
|
|
|
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
|
|
|
|
|
auto* dev_ctx = pool.Get(place);
|
|
|
|
|
|
|
|
|
|
// For profiling, don't move out of this function because that will result
|
|
|
|
|
// in the failure of multi-GPU profiling.
|
|
|
|
|
platform::RecordEvent record_event(Type(), dev_ctx);
|
|
|
|
|
// check if op[type] has kernel registered.
|
|
|
|
|
auto& all_op_kernels = AllOpKernels();
|
|
|
|
|
auto kernels_iter = all_op_kernels.find(type_);
|
|
|
|
|