|
|
|
@ -20,6 +20,8 @@
|
|
|
|
|
#include "paddle/fluid/imperative/infer_shape_context.h"
|
|
|
|
|
#include "paddle/fluid/imperative/infer_var_type_context.h"
|
|
|
|
|
|
|
|
|
|
DECLARE_bool(use_mkldnn);
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace imperative {
|
|
|
|
|
|
|
|
|
@ -91,8 +93,10 @@ PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins,
|
|
|
|
|
// MKLDNN variant of code reads attributes in some of GetKernelTypeForVar and
|
|
|
|
|
// GetKernelType functions, so we need to copy the attributes there.
|
|
|
|
|
// Const qualifier of Attrs had to be discarded to overwrite it.
|
|
|
|
|
auto& mutable_op_attrs = const_cast<framework::AttributeMap&>(op.Attrs());
|
|
|
|
|
mutable_op_attrs = attrs;
|
|
|
|
|
if (FLAGS_use_mkldnn) {
|
|
|
|
|
auto& mutable_op_attrs = const_cast<framework::AttributeMap&>(op.Attrs());
|
|
|
|
|
mutable_op_attrs = attrs;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
auto expected_kernel_key =
|
|
|
|
|
op.GetExpectedKernelType(DygraphExecutionContext<VarType>(
|
|
|
|
|