|
|
|
@ -880,50 +880,23 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
|
|
|
|
|
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
|
|
|
|
|
auto* dev_ctx = pool.Get(place);
|
|
|
|
|
|
|
|
|
|
// check if op[type] has kernel registered.
|
|
|
|
|
auto& all_op_kernels = AllOpKernels();
|
|
|
|
|
auto kernels_iter = all_op_kernels.find(type_);
|
|
|
|
|
if (kernels_iter == all_op_kernels.end()) {
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
"There are no kernels which are registered in the %s operator.", type_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
OpKernelMap& kernels = kernels_iter->second;
|
|
|
|
|
|
|
|
|
|
auto expected_kernel_key = this->GetExpectedKernelType(
|
|
|
|
|
ExecutionContext(*this, scope, *dev_ctx, ctx, nullptr));
|
|
|
|
|
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
|
|
|
|
|
|
|
|
|
|
auto kernel_iter = kernels.find(expected_kernel_key);
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
// workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set
|
|
|
|
|
if (kernel_iter == kernels.end() &&
|
|
|
|
|
expected_kernel_key.library_type_ == LibraryType::kMKLDNN) {
|
|
|
|
|
VLOG(3) << "missing MKLDNN kernel: fallbacking to PLAIN one";
|
|
|
|
|
expected_kernel_key.library_type_ = LibraryType::kPlain;
|
|
|
|
|
expected_kernel_key.data_layout_ = DataLayout::kAnyLayout;
|
|
|
|
|
kernel_iter = kernels.find(expected_kernel_key);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
if (kernel_iter == kernels.end()) {
|
|
|
|
|
PADDLE_THROW("op %s does not have kernel for %s", type_,
|
|
|
|
|
KernelTypeToString(expected_kernel_key));
|
|
|
|
|
if (!kernel_type_) {
|
|
|
|
|
ChooseKernel(ctx, scope, place);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<KernelConfig>* kernel_configs =
|
|
|
|
|
GetKernelConfig(expected_kernel_key);
|
|
|
|
|
std::vector<KernelConfig>* kernel_configs = GetKernelConfig(*kernel_type_);
|
|
|
|
|
|
|
|
|
|
// do data transformScope &transfer_scope;
|
|
|
|
|
std::vector<std::string> transfered_inplace_vars;
|
|
|
|
|
auto* transfer_scope =
|
|
|
|
|
PrepareData(scope, expected_kernel_key, &transfered_inplace_vars, &ctx);
|
|
|
|
|
PrepareData(scope, *kernel_type_, &transfered_inplace_vars, &ctx);
|
|
|
|
|
|
|
|
|
|
// exec scope is the scope that kernel actually executed on.
|
|
|
|
|
const Scope& exec_scope =
|
|
|
|
|
(transfer_scope == nullptr ? scope : *transfer_scope);
|
|
|
|
|
|
|
|
|
|
if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) {
|
|
|
|
|
dev_ctx = pool.Get(expected_kernel_key.place_);
|
|
|
|
|
if (!(kernel_type_->place_ == dev_ctx->GetPlace())) {
|
|
|
|
|
dev_ctx = pool.Get(kernel_type_->place_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!HasAttr(kAllKernelsMustComputeRuntimeShape)) {
|
|
|
|
@ -932,7 +905,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
|
|
|
|
|
}
|
|
|
|
|
// TODO(panyx0718): ExecutionContext should only depend on RuntimeContext
|
|
|
|
|
// not Scope. Imperative mode only pass inputs and get outputs.
|
|
|
|
|
kernel_iter->second(
|
|
|
|
|
(*kernel_func_)(
|
|
|
|
|
ExecutionContext(*this, exec_scope, *dev_ctx, ctx, kernel_configs));
|
|
|
|
|
|
|
|
|
|
if (!transfered_inplace_vars.empty()) {
|
|
|
|
@ -959,6 +932,46 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
|
|
|
|
|
const Scope& scope,
|
|
|
|
|
const platform::Place& place) const {
|
|
|
|
|
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
|
|
|
|
|
auto* dev_ctx = pool.Get(place);
|
|
|
|
|
|
|
|
|
|
// check if op[type] has kernel registered.
|
|
|
|
|
auto& all_op_kernels = AllOpKernels();
|
|
|
|
|
auto kernels_iter = all_op_kernels.find(type_);
|
|
|
|
|
if (kernels_iter == all_op_kernels.end()) {
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
"There are no kernels which are registered in the %s operator.", type_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
OpKernelMap& kernels = kernels_iter->second;
|
|
|
|
|
|
|
|
|
|
auto expected_kernel_key = this->GetExpectedKernelType(
|
|
|
|
|
ExecutionContext(*this, scope, *dev_ctx, ctx, nullptr));
|
|
|
|
|
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
|
|
|
|
|
|
|
|
|
|
auto kernel_iter = kernels.find(expected_kernel_key);
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
// workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set
|
|
|
|
|
if (kernel_iter == kernels.end() &&
|
|
|
|
|
expected_kernel_key.library_type_ == LibraryType::kMKLDNN) {
|
|
|
|
|
VLOG(3) << "missing MKLDNN kernel: fallbacking to PLAIN one";
|
|
|
|
|
expected_kernel_key.library_type_ = LibraryType::kPlain;
|
|
|
|
|
expected_kernel_key.data_layout_ = DataLayout::kAnyLayout;
|
|
|
|
|
kernel_iter = kernels.find(expected_kernel_key);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
if (kernel_iter == kernels.end()) {
|
|
|
|
|
PADDLE_THROW("op %s does not have kernel for %s", type_,
|
|
|
|
|
KernelTypeToString(expected_kernel_key));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
kernel_type_.reset(new OpKernelType(expected_kernel_key));
|
|
|
|
|
kernel_func_.reset(new OpKernelFunc(kernel_iter->second));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void OperatorWithKernel::TransferInplaceVarsBack(
|
|
|
|
|
const Scope& scope, const std::vector<std::string>& inplace_vars,
|
|
|
|
|
const Scope& transfer_scope) const {
|
|
|
|
|