|
|
|
@ -884,6 +884,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
|
|
|
|
|
// result of HasAttr.
|
|
|
|
|
if (!enable_cache_runtime_context && HasAttr(kEnableCacheRuntimeContext))
|
|
|
|
|
enable_cache_runtime_context = true;
|
|
|
|
|
if (!enable_cache_expected_kernel && HasAttr(kEnableCacheExpectedKernel))
|
|
|
|
|
enable_cache_expected_kernel = true;
|
|
|
|
|
if (!all_kernels_must_compute_runtime_shape &&
|
|
|
|
|
HasAttr(kAllKernelsMustComputeRuntimeShape))
|
|
|
|
|
all_kernels_must_compute_runtime_shape = true;
|
|
|
|
@ -906,7 +908,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
|
|
|
|
|
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
|
|
|
|
|
auto* dev_ctx = pool.Get(place);
|
|
|
|
|
|
|
|
|
|
if (!HasAttr(kEnableCacheExpectedKernel) || !kernel_type_) {
|
|
|
|
|
if (!enable_cache_expected_kernel || !kernel_type_) {
|
|
|
|
|
ChooseKernel(*runtime_ctx, scope, place);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|