|
|
|
@ -880,7 +880,14 @@ std::vector<KernelConfig>* OperatorWithKernel::GetKernelConfig(
|
|
|
|
|
|
|
|
|
|
void OperatorWithKernel::RunImpl(const Scope& scope,
|
|
|
|
|
const platform::Place& place) const {
|
|
|
|
|
if (!HasAttr(kEnableCacheRuntimeContext)) {
|
|
|
|
|
// To reduce the elapsed time of HasAttr, we use bool variable to record the
|
|
|
|
|
// result of HasAttr.
|
|
|
|
|
if (!enable_cache_runtime_context && HasAttr(kEnableCacheRuntimeContext))
|
|
|
|
|
enable_cache_runtime_context = true;
|
|
|
|
|
if (!all_kernels_must_compute_runtime_shape &&
|
|
|
|
|
HasAttr(kAllKernelsMustComputeRuntimeShape))
|
|
|
|
|
all_kernels_must_compute_runtime_shape = true;
|
|
|
|
|
if (!enable_cache_runtime_context) {
|
|
|
|
|
RuntimeContext ctx(Inputs(), Outputs(), scope);
|
|
|
|
|
RunImpl(scope, place, &ctx);
|
|
|
|
|
} else {
|
|
|
|
@ -945,7 +952,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
|
|
|
|
|
dev_ctx = pool.Get(expected_kernel_key.place_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!HasAttr(kAllKernelsMustComputeRuntimeShape)) {
|
|
|
|
|
if (!all_kernels_must_compute_runtime_shape) {
|
|
|
|
|
RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, *runtime_ctx);
|
|
|
|
|
this->InferShape(&infer_shape_ctx);
|
|
|
|
|
}
|
|
|
|
|