|
|
@ -20,7 +20,7 @@ namespace paddle {
|
|
|
|
namespace operators {
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T, typename Functor>
|
|
|
|
template <typename Place, typename T, typename Functor>
|
|
|
|
class ActivationKernel : public framework::OpKernel {
|
|
|
|
class ActivationKernel : public framework::OpKernel<T> {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
auto* X = context.Input<framework::Tensor>("X");
|
|
|
|
auto* X = context.Input<framework::Tensor>("X");
|
|
|
@ -36,7 +36,7 @@ class ActivationKernel : public framework::OpKernel {
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T, typename Functor>
|
|
|
|
template <typename Place, typename T, typename Functor>
|
|
|
|
class ActivationGradKernel : public framework::OpKernel {
|
|
|
|
class ActivationGradKernel : public framework::OpKernel<T> {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
auto* X = context.Input<framework::Tensor>("X");
|
|
|
|
auto* X = context.Input<framework::Tensor>("X");
|
|
|
@ -202,7 +202,7 @@ struct SquareGradFunctor {
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T, typename AttrType = T>
|
|
|
|
template <typename Place, typename T, typename AttrType = T>
|
|
|
|
class BReluKernel : public framework::OpKernel {
|
|
|
|
class BReluKernel : public framework::OpKernel<T> {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
auto* X = context.Input<framework::Tensor>("X");
|
|
|
|
auto* X = context.Input<framework::Tensor>("X");
|
|
|
@ -219,7 +219,7 @@ class BReluKernel : public framework::OpKernel {
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T, typename AttrType = T>
|
|
|
|
template <typename Place, typename T, typename AttrType = T>
|
|
|
|
class BReluGradKernel : public framework::OpKernel {
|
|
|
|
class BReluGradKernel : public framework::OpKernel<T> {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
auto* X = context.Input<framework::Tensor>("X");
|
|
|
|
auto* X = context.Input<framework::Tensor>("X");
|
|
|
@ -239,7 +239,7 @@ class BReluGradKernel : public framework::OpKernel {
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T, typename AttrType = T>
|
|
|
|
template <typename Place, typename T, typename AttrType = T>
|
|
|
|
class SoftReluKernel : public framework::OpKernel {
|
|
|
|
class SoftReluKernel : public framework::OpKernel<T> {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
auto* X = context.Input<framework::Tensor>("X");
|
|
|
|
auto* X = context.Input<framework::Tensor>("X");
|
|
|
@ -256,7 +256,7 @@ class SoftReluKernel : public framework::OpKernel {
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T, typename AttrType = T>
|
|
|
|
template <typename Place, typename T, typename AttrType = T>
|
|
|
|
class SoftReluGradKernel : public framework::OpKernel {
|
|
|
|
class SoftReluGradKernel : public framework::OpKernel<T> {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
auto* X = context.Input<framework::Tensor>("X");
|
|
|
|
auto* X = context.Input<framework::Tensor>("X");
|
|
|
@ -277,7 +277,7 @@ class SoftReluGradKernel : public framework::OpKernel {
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T, typename AttrType = T>
|
|
|
|
template <typename Place, typename T, typename AttrType = T>
|
|
|
|
class PowKernel : public framework::OpKernel {
|
|
|
|
class PowKernel : public framework::OpKernel<T> {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
auto* X = context.Input<framework::Tensor>("X");
|
|
|
|
auto* X = context.Input<framework::Tensor>("X");
|
|
|
@ -293,7 +293,7 @@ class PowKernel : public framework::OpKernel {
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T, typename AttrType = T>
|
|
|
|
template <typename Place, typename T, typename AttrType = T>
|
|
|
|
class PowGradKernel : public framework::OpKernel {
|
|
|
|
class PowGradKernel : public framework::OpKernel<T> {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
auto* X = context.Input<framework::Tensor>("X");
|
|
|
|
auto* X = context.Input<framework::Tensor>("X");
|
|
|
@ -312,7 +312,7 @@ class PowGradKernel : public framework::OpKernel {
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T, typename AttrType = T>
|
|
|
|
template <typename Place, typename T, typename AttrType = T>
|
|
|
|
class STanhKernel : public framework::OpKernel {
|
|
|
|
class STanhKernel : public framework::OpKernel<T> {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
auto* X = context.Input<framework::Tensor>("X");
|
|
|
|
auto* X = context.Input<framework::Tensor>("X");
|
|
|
@ -329,7 +329,7 @@ class STanhKernel : public framework::OpKernel {
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T, typename AttrType = T>
|
|
|
|
template <typename Place, typename T, typename AttrType = T>
|
|
|
|
class STanhGradKernel : public framework::OpKernel {
|
|
|
|
class STanhGradKernel : public framework::OpKernel<T> {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
auto* X = context.Input<framework::Tensor>("X");
|
|
|
|
auto* X = context.Input<framework::Tensor>("X");
|
|
|
|