@@ -77,8 +77,7 @@ class MKLDNNActivationGradKernel

 template <typename T>
 void eltwise_forward(const framework::ExecutionContext &ctx,
-                     mkldnn::algorithm algorithm, const T alpha = 0,
-                     const T beta = 0) {
+                     mkldnn::algorithm algorithm) {
   PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                  "It must use CPUPlace.");
   auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
@@ -90,6 +89,9 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
   const T *x_data = x->data<T>();
   T *y_data = y->mutable_data<T>(ctx.GetPlace());

+  const T alpha = ctx.op().HasAttr("alpha") ? ctx.Attr<T>("alpha") : 0;
+  const T beta = ctx.op().HasAttr("beta") ? ctx.Attr<T>("beta") : 0;
+
   PADDLE_ENFORCE(
       x->dims().size() == 2 || x->dims().size() == 3 || x->dims().size() == 4,
       "Input dim must be with 2, 3 or 4");
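With the defaulted alpha/beta parameters removed from the signature, eltwise_forward now reads both values from the op's attributes, falling back to 0 when an activation does not declare them; this is what lets one kernel template serve plain relu (no such attributes) and leaky_relu (where "alpha" carries the negative slope). A standalone sketch of that fallback pattern, using a hypothetical FakeOp stand-in rather than Paddle's real operator interface:

    // Toy illustration only: FakeOp mimics just enough of the op interface
    // (HasAttr/Attr) to demonstrate the "attribute or zero" fallback above.
    #include <iostream>
    #include <map>
    #include <string>

    struct FakeOp {
      std::map<std::string, float> attrs;
      bool HasAttr(const std::string &n) const { return attrs.count(n) > 0; }
      float Attr(const std::string &n) const { return attrs.at(n); }
    };

    float AttrOrZero(const FakeOp &op, const std::string &name) {
      return op.HasAttr(name) ? op.Attr(name) : 0.0f;
    }

    int main() {
      FakeOp relu{};                     // plain relu declares no alpha/beta
      FakeOp leaky{{{"alpha", 0.02f}}};  // leaky_relu: alpha = negative slope
      std::cout << AttrOrZero(relu, "alpha") << " "     // prints 0
                << AttrOrZero(leaky, "alpha") << "\n";  // prints 0.02
    }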
@@ -101,10 +103,9 @@ void eltwise_forward(const framework::ExecutionContext &ctx,

   bool is_test = ctx.Attr<bool>("is_test");

-  // TODO(jczaja): When adding leaky-relu , swish , elu make sure to extend key
-  // with alpha, beta
   std::string key = platform::MKLDNNHandler::GetHash(
-      src_tz, std::to_string(algorithm) + ctx.op().Output("Out"));
+      src_tz, std::to_string(algorithm) + std::to_string(alpha) +
+                  std::to_string(beta) + ctx.op().Input("X"));

   // TODO(jczaja): Make it Thread safe
   // save input data and layout to be referred in backward path
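This resolves the old TODO about extending the key: alpha and beta now feed the primitive-cache hash, so two activations with the same shape and algorithm but different slopes (relu versus leaky_relu, say) can no longer collide on a cached primitive. Switching the key from Output("Out") to Input("X") also lines up with the grad hunks below, which hash Input("X") as well, so the backward pass can rebuild the identical key and locate the forward data saved under it. A minimal sketch of the collision the extended key prevents, with a hypothetical MakeKey helper standing in for platform::MKLDNNHandler::GetHash:

    #include <string>

    // Concatenates the same ingredients the kernel hashes above.
    std::string MakeKey(const std::string &dims, int algorithm, float alpha,
                        float beta, const std::string &input_name) {
      return dims + std::to_string(algorithm) + std::to_string(alpha) +
             std::to_string(beta) + input_name;
    }

    // Without alpha/beta in the key these two would be identical:
    //   MakeKey("2x3x4x5", 1, 0.0f,  0.0f, "X")   // relu
    //   MakeKey("2x3x4x5", 1, 0.02f, 0.0f, "X")   // leaky_relu: distinct key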
@@ -153,8 +154,7 @@ void eltwise_forward(const framework::ExecutionContext &ctx,

 template <typename T>
 void eltwise_grad(const framework::ExecutionContext &ctx,
-                  mkldnn::algorithm algorithm, const T alpha = 0,
-                  const T beta = 0) {
+                  mkldnn::algorithm algorithm) {
   auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
   const auto &mkldnn_engine = dev_ctx.GetEngine();

@@ -164,6 +164,9 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
   const T *diff_y_data = diff_y->data<T>();
   T *diff_x_data = diff_x->mutable_data<T>(ctx.GetPlace());

+  const T alpha = ctx.op().HasAttr("alpha") ? ctx.Attr<T>("alpha") : 0;
+  const T beta = ctx.op().HasAttr("beta") ? ctx.Attr<T>("beta") : 0;
+
   std::vector<int> diff_dst_tz = framework::vectorize2int(diff_y->dims());

   auto diff_y_format =
@@ -173,7 +176,8 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
       diff_dst_tz, platform::MKLDNNGetDataType<T>(), diff_y_format);

   std::string key = platform::MKLDNNHandler::GetHash(
-      diff_dst_tz, std::to_string(algorithm) + ctx.op().Input("Out"));
+      diff_dst_tz, std::to_string(algorithm) + std::to_string(alpha) +
+                       std::to_string(beta) + ctx.op().Input("X"));

   const std::string key_src_data = key + "@eltwise_fwd_src_data";
   const std::string key_src_layout = key + "@eltwise_fwd_src_layout";
@@ -273,10 +277,11 @@ namespace ops = paddle::operators;
       act_type##_grad, MKLDNN, ::paddle::platform::CPUPlace,      \
       ops::MKLDNNActivationGradKernel<ops::grad_functor<float>>);

-#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(__macro)            \
-  __macro(relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \
-  __macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradFunctor); \
-  __macro(sqrt, SqrtMKLDNNFunctor, SqrtMKLDNNGradFunctor); \
+#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(__macro)                  \
+  __macro(relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor);       \
+  __macro(leaky_relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \
+  __macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradFunctor);       \
+  __macro(sqrt, SqrtMKLDNNFunctor, SqrtMKLDNNGradFunctor);       \
   __macro(abs, AbsMKLDNNFunctor, AbsMKLDNNGradFunctor);

 FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL);
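Registering leaky_relu with the existing ReluMKLDNNFunctor/ReluMKLDNNGradFunctor works because MKL-DNN's eltwise relu already accepts a negative-slope parameter: alpha == 0 gives the standard relu, and a nonzero alpha gives exactly leaky relu, so no new functor is needed. A scalar reference for what the shared functor computes (illustrative helper, not part of the patch):

    // Reference semantics of mkldnn eltwise relu with negative slope alpha:
    // alpha == 0 -> standard relu; alpha != 0 -> leaky relu.
    float eltwise_relu_ref(float x, float alpha) {
      return x >= 0.0f ? x : alpha * x;
    }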