|
|
|
@ -25,6 +25,59 @@
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace imperative {
|
|
|
|
|
|
|
|
|
|
class PreparedOp {
|
|
|
|
|
public:
|
|
|
|
|
PreparedOp(const framework::OperatorBase& op,
|
|
|
|
|
const framework::RuntimeContext& ctx,
|
|
|
|
|
framework::OperatorWithKernel::OpKernelFunc func,
|
|
|
|
|
platform::DeviceContext* dev_ctx)
|
|
|
|
|
: op(op), ctx(ctx), func(func), dev_ctx(dev_ctx) {}
|
|
|
|
|
|
|
|
|
|
static PreparedOp Prepare(const framework::RuntimeContext& ctx,
|
|
|
|
|
const framework::OperatorWithKernel& op,
|
|
|
|
|
const platform::Place& place) {
|
|
|
|
|
framework::Scope dummy_scope;
|
|
|
|
|
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
|
|
|
|
|
auto* dev_ctx = pool.Get(place);
|
|
|
|
|
|
|
|
|
|
// check if op[type] has kernel registered.
|
|
|
|
|
auto& all_op_kernels = op.AllOpKernels();
|
|
|
|
|
auto kernels_iter = all_op_kernels.find(op.Type());
|
|
|
|
|
if (kernels_iter == all_op_kernels.end()) {
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
"There are no kernels which are registered in the %s operator.",
|
|
|
|
|
op.Type());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::OperatorWithKernel::OpKernelMap& kernels = kernels_iter->second;
|
|
|
|
|
|
|
|
|
|
auto expected_kernel_key = op.GetExpectedKernelType(
|
|
|
|
|
framework::ExecutionContext(op, dummy_scope, *dev_ctx, ctx));
|
|
|
|
|
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
|
|
|
|
|
|
|
|
|
|
auto kernel_iter = kernels.find(expected_kernel_key);
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
// workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set
|
|
|
|
|
if (kernel_iter == kernels.end() &&
|
|
|
|
|
expected_kernel_key.library_type_ == framework::LibraryType::kMKLDNN) {
|
|
|
|
|
VLOG(3) << "missing MKLDNN kernel: fallbacking to PLAIN one";
|
|
|
|
|
expected_kernel_key.library_type_ = framework::LibraryType::kPlain;
|
|
|
|
|
expected_kernel_key.data_layout_ = framework::DataLayout::kAnyLayout;
|
|
|
|
|
kernel_iter = kernels.find(expected_kernel_key);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
if (kernel_iter == kernels.end()) {
|
|
|
|
|
PADDLE_THROW("op %s does not have kernel for %s", op.Type(),
|
|
|
|
|
KernelTypeToString(expected_kernel_key));
|
|
|
|
|
}
|
|
|
|
|
return PreparedOp(op, ctx, kernel_iter->second, dev_ctx);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const framework::OperatorBase& op;
|
|
|
|
|
const framework::RuntimeContext& ctx;
|
|
|
|
|
framework::OperatorWithKernel::OpKernelFunc func;
|
|
|
|
|
platform::DeviceContext* dev_ctx;
|
|
|
|
|
};
|
|
|
|
|
class OpBase;
|
|
|
|
|
|
|
|
|
|
class VarBase {
|
|
|
|
@ -62,30 +115,22 @@ class VarBase {
|
|
|
|
|
|
|
|
|
|
class OpBase {
|
|
|
|
|
public:
|
|
|
|
|
OpBase()
|
|
|
|
|
: pre_ops_(new std::map<std::string, std::vector<OpBase*>>()),
|
|
|
|
|
pre_ops_out_idx_(new std::map<std::string, std::vector<int>>()),
|
|
|
|
|
op_desc_(nullptr),
|
|
|
|
|
grad_op_desc_(nullptr) {}
|
|
|
|
|
OpBase() : op_desc_(nullptr), grad_op_desc_(nullptr) {}
|
|
|
|
|
|
|
|
|
|
virtual ~OpBase() {
|
|
|
|
|
delete pre_ops_;
|
|
|
|
|
delete pre_ops_out_idx_;
|
|
|
|
|
|
|
|
|
|
if (grad_op_desc_) delete grad_op_desc_;
|
|
|
|
|
if (grad_to_var_) delete grad_to_var_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::map<std::string, std::vector<VarBase*>> ApplyGrad();
|
|
|
|
|
|
|
|
|
|
framework::OpDesc* op_desc_;
|
|
|
|
|
framework::OpDesc* grad_op_desc_;
|
|
|
|
|
|
|
|
|
|
std::map<std::string, std::vector<VarBase*>> input_vars_;
|
|
|
|
|
std::map<std::string, std::vector<VarBase*>> output_vars_;
|
|
|
|
|
std::map<std::string, std::vector<OpBase*>>* pre_ops_;
|
|
|
|
|
std::map<std::string, std::vector<int>>* pre_ops_out_idx_;
|
|
|
|
|
framework::OpDesc* op_desc_;
|
|
|
|
|
std::map<std::string, std::vector<OpBase*>> pre_ops_;
|
|
|
|
|
std::map<std::string, std::vector<int>> pre_ops_out_idx_;
|
|
|
|
|
|
|
|
|
|
framework::OpDesc* grad_op_desc_;
|
|
|
|
|
std::unordered_map<std::string, std::string>* grad_to_var_;
|
|
|
|
|
std::map<std::string, std::vector<framework::Variable*>> grad_input_vars_;
|
|
|
|
|
std::map<std::string, std::vector<framework::Variable*>> grad_output_vars_;
|
|
|
|
|
framework::BlockDesc* block_;
|
|
|
|
|