|
|
|
@ -10,6 +10,9 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
#include <glog/logging.h>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <unordered_set>
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
@ -25,6 +28,16 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
/* NOTE: an ugly global variable is used here so that the Python layer side
   can query which activation ops support in-place computation.
   See layer_helper.py for the details.
*/
|
|
|
|
|
static std::unordered_set<std::string> InplaceOpSet = {
|
|
|
|
|
"sigmoid", "exp", "relu", "tanh", "sqrt", "ceil",
|
|
|
|
|
"floor", "reciprocal", "relu6", "soft_relu", "hard_sigmoid",
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
static bool IsInplace(std::string op) { return InplaceOpSet.count(op); }
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename Functor>
|
|
|
|
|
class ActivationKernel
|
|
|
|
|
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
|
|
|
|
@ -60,7 +73,6 @@ class ActivationGradKernel
|
|
|
|
|
public:
|
|
|
|
|
using T = typename Functor::ELEMENT_TYPE;
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
auto* X = context.Input<framework::Tensor>("X");
|
|
|
|
|
auto* Out = context.Input<framework::Tensor>("Out");
|
|
|
|
|
auto* dOut =
|
|
|
|
|
context.Input<framework::Tensor>(framework::GradVarName("Out"));
|
|
|
|
@ -68,7 +80,6 @@ class ActivationGradKernel
|
|
|
|
|
dX->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto dout = framework::EigenVector<T>::Flatten(*dOut);
|
|
|
|
|
auto x = framework::EigenVector<T>::Flatten(*X);
|
|
|
|
|
auto out = framework::EigenVector<T>::Flatten(*Out);
|
|
|
|
|
auto dx = framework::EigenVector<T>::Flatten(*dX);
|
|
|
|
|
auto* place =
|
|
|
|
@ -78,7 +89,16 @@ class ActivationGradKernel
|
|
|
|
|
for (auto& attr : attrs) {
|
|
|
|
|
*attr.second = context.Attr<float>(attr.first);
|
|
|
|
|
}
|
|
|
|
|
functor(*place, x, out, dout, dx);
|
|
|
|
|
bool inplace = functor.Inplace();
|
|
|
|
|
if (!inplace) {
|
|
|
|
|
auto* X = context.Input<framework::Tensor>("X");
|
|
|
|
|
auto x = framework::EigenVector<T>::Flatten(*X);
|
|
|
|
|
functor(*place, x, out, dout, dx);
|
|
|
|
|
} else {
|
|
|
|
|
VLOG(10) << " Inplace activation ";
|
|
|
|
|
auto x = framework::EigenVector<T>::Flatten(*dX);
|
|
|
|
|
functor(*place, x, out, dout, dx);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -89,6 +109,14 @@ struct BaseActivationFunctor {
|
|
|
|
|
using AttrPair = std::vector<std::pair<const char*, float*>>;
|
|
|
|
|
|
|
|
|
|
AttrPair GetAttrs() { return AttrPair(); }
|
|
|
|
|
|
|
|
|
|
/* NOTE(*): Out may reuse X's memory only when the gradient does not depend
   on X. For example, sigmoid's gradient involves only Out, so Out can reuse
   the input memory in-place. But abs' gradient uses X, so it cannot be
   computed in-place.
*/
|
|
|
|
|
bool Inplace() const { return false; }
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// sigmoid(x) = 1 / (1 + exp(-x))
|
|
|
|
@ -102,6 +130,7 @@ struct SigmoidFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
bool Inplace() const { return IsInplace("sigmoid"); }
|
|
|
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
|
|
|
typename dX>
|
|
|
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
|
|
@ -156,6 +185,7 @@ struct ExpFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct ExpGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
bool Inplace() const { return IsInplace("exp"); }
|
|
|
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
|
|
|
typename dX>
|
|
|
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
|
|
@ -174,10 +204,11 @@ struct ReluFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct ReluGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
bool Inplace() const { return IsInplace("relu"); }
|
|
|
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
|
|
|
typename dX>
|
|
|
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
|
|
|
dx.device(d) = dout * (x > static_cast<T>(0)).template cast<T>();
|
|
|
|
|
dx.device(d) = dout * (out > static_cast<T>(0)).template cast<T>();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -192,6 +223,7 @@ struct TanhFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct TanhGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
bool Inplace() const { return IsInplace("tanh"); }
|
|
|
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
|
|
|
typename dX>
|
|
|
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
|
|
@ -297,6 +329,7 @@ struct SqrtFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct SqrtGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
bool Inplace() const { return IsInplace("sqrt"); }
|
|
|
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
|
|
|
typename dX>
|
|
|
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
|
|
@ -316,10 +349,11 @@ struct CeilFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct ZeroGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
bool Inplace() const { return IsInplace("ceil"); }
|
|
|
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
|
|
|
typename dX>
|
|
|
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
|
|
|
dx.device(d) = static_cast<T>(0) / x;
|
|
|
|
|
dx.device(d) = static_cast<T>(0) / out;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -432,6 +466,7 @@ struct ReciprocalFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct ReciprocalGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
bool Inplace() const { return IsInplace("reciprocal"); }
|
|
|
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
|
|
|
typename dX>
|
|
|
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
|
|
@ -531,12 +566,14 @@ struct Relu6GradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
|
|
|
return {{"threshold", &threshold}};
|
|
|
|
|
}
|
|
|
|
|
bool Inplace() const { return IsInplace("relu6"); }
|
|
|
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
|
|
|
typename dX>
|
|
|
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
|
|
|
dx.device(d) = dout *
|
|
|
|
|
((x > static_cast<T>(0)) * (x < static_cast<T>(threshold)))
|
|
|
|
|
.template cast<T>();
|
|
|
|
|
dx.device(d) =
|
|
|
|
|
dout *
|
|
|
|
|
((out > static_cast<T>(0)) * (out < static_cast<T>(threshold)))
|
|
|
|
|
.template cast<T>();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -611,11 +648,12 @@ struct SoftReluGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
|
|
|
return {{"threshold", &threshold}};
|
|
|
|
|
}
|
|
|
|
|
bool Inplace() const { return IsInplace("softrelu"); }
|
|
|
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
|
|
|
typename dX>
|
|
|
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
|
|
|
auto tmp = static_cast<T>(threshold);
|
|
|
|
|
auto temp = ((x > -tmp) * (x < tmp)).template cast<T>().eval();
|
|
|
|
|
auto temp = ((out > -tmp) * (out < tmp)).template cast<T>().eval();
|
|
|
|
|
dx.device(d) = dout * (static_cast<T>(1) - (-out).exp()) * temp;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -791,7 +829,7 @@ struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
|
|
|
return {{"slope", &slope}, {"offset", &offset}};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool Inplace() { return IsInplace("hard_sigmoid"); }
|
|
|
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
|
|
|
typename dX>
|
|
|
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
|
|
|