|
|
|
|
@ -874,17 +874,23 @@ std::vector<KernelConfig>* OperatorWithKernel::GetKernelConfig(
|
|
|
|
|
return kernel_configs;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void OperatorWithKernel::RunImpl(const Scope& scope,
|
|
|
|
|
const platform::Place& place) const {
|
|
|
|
|
// Returns the RuntimeContext to use when running this operator on `scope`.
//
// The context is built from Inputs(), Outputs() and `scope`, and is always
// stored in runtime_ctx_. The returned pointer is therefore non-owning on
// every path: callers must not delete it, and it is replaced by the next call
// that rebuilds the context.
//
// - Without the kEnableRuntimeContext attribute, the context is rebuilt on
//   every call.
// - With the attribute, the stored context is reused as long as one exists
//   and `scope` is the same Scope object (compared by address) as in the
//   call that built it.
//
// Previously the first path stored one context in runtime_ctx_ and then
// returned a second, separately heap-allocated one that nothing owned, which
// leaked an object per call and made ownership of the return value depend on
// the branch taken.
//
// NOTE(review): this const method mutates runtime_ctx_ / pre_scope_ and no
// locking is visible here - confirm the same operator instance is never run
// from several threads at once.
RuntimeContext* OperatorWithKernel::GetRuntimeContext(
    const Scope& scope) const {
  if (!HasAttr(kEnableRuntimeContext)) {
    // Attribute absent: always start from a freshly built context.
    runtime_ctx_.reset(new RuntimeContext(Inputs(), Outputs(), scope));
    return runtime_ctx_.get();
  }

  const Scope* cur_scope = &scope;
  if (!runtime_ctx_ || pre_scope_ != cur_scope) {
    // Nothing stored yet, or the stored context was built for another scope.
    runtime_ctx_.reset(new RuntimeContext(Inputs(), Outputs(), scope));
    pre_scope_ = cur_scope;
  }
  return runtime_ctx_.get();
}
|
|
|
|
|
|
|
|
|
|
void OperatorWithKernel::RunImpl(const Scope& scope,
|
|
|
|
|
const platform::Place& place) const {
|
|
|
|
|
auto runtime_ctx = GetRuntimeContext(scope);
|
|
|
|
|
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
|
|
|
|
|
auto* dev_ctx = pool.Get(place);
|
|
|
|
|
|
|
|
|
|
@ -899,7 +905,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
|
|
|
|
|
OpKernelMap& kernels = kernels_iter->second;
|
|
|
|
|
|
|
|
|
|
auto expected_kernel_key = this->GetExpectedKernelType(
|
|
|
|
|
ExecutionContext(*this, scope, *dev_ctx, *runtime_ctx_, nullptr));
|
|
|
|
|
ExecutionContext(*this, scope, *dev_ctx, *runtime_ctx, nullptr));
|
|
|
|
|
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
|
|
|
|
|
|
|
|
|
|
auto kernel_iter = kernels.find(expected_kernel_key);
|
|
|
|
|
@ -923,8 +929,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
|
|
|
|
|
|
|
|
|
|
// do data transformScope &transfer_scope;
|
|
|
|
|
std::vector<std::string> transfered_inplace_vars;
|
|
|
|
|
auto* transfer_scope = PrepareData(
|
|
|
|
|
scope, expected_kernel_key, &transfered_inplace_vars, runtime_ctx_.get());
|
|
|
|
|
auto* transfer_scope = PrepareData(scope, expected_kernel_key,
|
|
|
|
|
&transfered_inplace_vars, runtime_ctx);
|
|
|
|
|
|
|
|
|
|
// exec scope is the scope that kernel actually executed on.
|
|
|
|
|
const Scope& exec_scope =
|
|
|
|
|
@ -935,13 +941,13 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!HasAttr(kAllKernelsMustComputeRuntimeShape)) {
|
|
|
|
|
RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, *runtime_ctx_);
|
|
|
|
|
RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, *runtime_ctx);
|
|
|
|
|
this->InferShape(&infer_shape_ctx);
|
|
|
|
|
}
|
|
|
|
|
// TODO(panyx0718): ExecutionContext should only depend on RuntimeContext
|
|
|
|
|
// not Scope. Imperative mode only pass inputs and get outputs.
|
|
|
|
|
kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx,
|
|
|
|
|
*runtime_ctx_, kernel_configs));
|
|
|
|
|
*runtime_ctx, kernel_configs));
|
|
|
|
|
|
|
|
|
|
if (!transfered_inplace_vars.empty()) {
|
|
|
|
|
// there is inplace variable has been transfered.
|
|
|
|
|
|