|
|
|
@ -15,6 +15,7 @@ limitations under the License. */
|
|
|
|
|
#pragma once
|
|
|
|
|
#include "paddle/framework/eigen.h"
|
|
|
|
|
#include "paddle/framework/op_registry.h"
|
|
|
|
|
#include "paddle/operators/detail/safe_ref.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -26,12 +27,16 @@ class ActivationKernel
|
|
|
|
|
using T = typename Functor::ELEMENT_TYPE;
|
|
|
|
|
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
auto* X = context.Input<framework::Tensor>("X");
|
|
|
|
|
auto* Out = context.Output<framework::Tensor>("Out");
|
|
|
|
|
Out->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto x = framework::EigenVector<T>::Flatten(*X);
|
|
|
|
|
auto out = framework::EigenVector<T>::Flatten(*Out);
|
|
|
|
|
auto& X = detail::Ref(context.Input<framework::Tensor>("X"),
|
|
|
|
|
"Cannot get input tensor X, variable name = %s",
|
|
|
|
|
context.op().Input("X"));
|
|
|
|
|
|
|
|
|
|
auto& Out = detail::Ref(context.Output<framework::Tensor>("Out"),
|
|
|
|
|
"Cannot get output tensor Out, variable name = %s",
|
|
|
|
|
context.op().Output("Out"));
|
|
|
|
|
Out.mutable_data<T>(context.GetPlace());
|
|
|
|
|
auto x = framework::EigenVector<T>::Flatten(X);
|
|
|
|
|
auto out = framework::EigenVector<T>::Flatten(Out);
|
|
|
|
|
auto* place =
|
|
|
|
|
context.template device_context<DeviceContext>().eigen_device();
|
|
|
|
|
Functor functor;
|
|
|
|
|