|
|
|
@ -884,8 +884,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
|
|
|
|
|
// result of HasAttr.
|
|
|
|
|
if (!enable_cache_runtime_context && HasAttr(kEnableCacheRuntimeContext))
|
|
|
|
|
enable_cache_runtime_context = true;
|
|
|
|
|
if (!enable_cache_expected_kernel && HasAttr(kEnableCacheExpectedKernel))
|
|
|
|
|
enable_cache_expected_kernel = true;
|
|
|
|
|
if (!all_kernels_must_compute_runtime_shape &&
|
|
|
|
|
HasAttr(kAllKernelsMustComputeRuntimeShape))
|
|
|
|
|
all_kernels_must_compute_runtime_shape = true;
|
|
|
|
@ -894,9 +892,12 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
|
|
|
|
|
RunImpl(scope, place, &ctx);
|
|
|
|
|
} else {
|
|
|
|
|
const Scope* cur_scope = &scope;
|
|
|
|
|
if (!runtime_ctx_ || pre_scope_ != cur_scope) {
|
|
|
|
|
runtime_ctx_.reset(new RuntimeContext(Inputs(), Outputs(), scope));
|
|
|
|
|
pre_scope_ = cur_scope;
|
|
|
|
|
if (runtime_ctx_.get() == nullptr || pre_scope_ != cur_scope) {
|
|
|
|
|
std::lock_guard<std::mutex> lock(cache_update_mutex_);
|
|
|
|
|
if (runtime_ctx_.get() == nullptr || pre_scope_ != cur_scope) {
|
|
|
|
|
runtime_ctx_.reset(new RuntimeContext(Inputs(), Outputs(), scope));
|
|
|
|
|
pre_scope_ = cur_scope;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
RunImpl(scope, place, runtime_ctx_.get());
|
|
|
|
|
}
|
|
|
|
@ -908,7 +909,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
|
|
|
|
|
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
|
|
|
|
|
auto* dev_ctx = pool.Get(place);
|
|
|
|
|
|
|
|
|
|
if (!enable_cache_expected_kernel || !kernel_type_) {
|
|
|
|
|
if (kernel_type_.get() == nullptr || kernel_func_.get() == nullptr) {
|
|
|
|
|
ChooseKernel(*runtime_ctx, scope, place);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -996,8 +997,11 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
|
|
|
|
|
KernelTypeToString(expected_kernel_key));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
kernel_type_.reset(new OpKernelType(expected_kernel_key));
|
|
|
|
|
kernel_func_.reset(new OpKernelFunc(kernel_iter->second));
|
|
|
|
|
std::lock_guard<std::mutex> lock(cache_update_mutex_);
|
|
|
|
|
if (kernel_type_.get() == nullptr || kernel_func_.get() == nullptr) {
|
|
|
|
|
kernel_type_.reset(new OpKernelType(expected_kernel_key));
|
|
|
|
|
kernel_func_.reset(new OpKernelFunc(kernel_iter->second));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void OperatorWithKernel::TransferInplaceVarsBack(
|
|
|
|
|