|
|
|
@ -74,9 +74,6 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
|
|
|
|
|
platform::SetDeviceId(dev_id);
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
// profile
|
|
|
|
|
auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
|
|
|
|
|
platform::RecordEvent record_event(Type(), dev_ctx);
|
|
|
|
|
RunImpl(scope, place);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -485,6 +482,10 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
|
|
|
|
|
this->InferShape(&infer_shape_ctx);
|
|
|
|
|
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
|
|
|
|
|
auto* dev_ctx = pool.Get(place);
|
|
|
|
|
|
|
|
|
|
// For profiling, don't move out of this function because that will result
|
|
|
|
|
// in the failure of multi-GPU profiling.
|
|
|
|
|
platform::RecordEvent record_event(Type(), dev_ctx);
|
|
|
|
|
// check if op[type] has kernel registered.
|
|
|
|
|
auto& all_op_kernels = AllOpKernels();
|
|
|
|
|
auto kernels_iter = all_op_kernels.find(type_);
|
|
|
|
|