@@ -276,7 +276,8 @@ static void RunFunctors(const framework::ExecutionContext &ctx,
         ctx, paddle::operators::math::MulFunctor<T>(),
         paddle::operators::math::SigmoidFunctor<T>(), in_x, in_y, outputs);
   } else {
-    PADDLE_THROW("%s has not been implemented.", funcs_str);
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "%s has not been implemented.", funcs_str));
   }
 }
@@ -374,7 +375,8 @@ static void RunGradFunctors(
         paddle::operators::math::SigmoidGradFunctor<T>(), in_x, in_y, in_out,
         in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
   } else {
-    PADDLE_THROW("%s has not been implemented.", funcs_str);
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "%s has not been implemented.", funcs_str));
   }
 }
@@ -386,16 +388,21 @@ class FusedElemwiseActivationKernel : public framework::OpKernel<T> {
                                  "X", "FusedElemwiseActivation");
     auto &in_y = GET_DATA_SAFELY(ctx.Input<framework::Tensor>("Y"), "Input",
                                  "Y", "FusedElemwiseActivation");
-    PADDLE_ENFORCE(ctx.HasOutput("Out"), "The output(Out) should not be empty");
+    PADDLE_ENFORCE_EQ(ctx.HasOutput("Out"), true,
+                      platform::errors::InvalidArgument(
+                          "The output(Out) should not be empty"));
     auto output = ctx.Output<framework::Tensor>("Out");
 
     std::vector<framework::Tensor *> outputs;
     outputs.emplace_back(output);
 
     if (ctx.Attr<bool>("save_intermediate_out")) {
-      PADDLE_ENFORCE(ctx.HasOutput("IntermediateOut"),
-                     "The save_intermediate_out is enable, so the "
-                     "IntermediateOut should not be empty.");
+      PADDLE_ENFORCE_EQ(ctx.HasOutput("IntermediateOut"), true,
+                        platform::errors::InvalidArgument(
+                            "The save_intermediate_out is enable, so the "
+                            "IntermediateOut should not be empty."));
 
       auto intermediate_out = ctx.Output<framework::Tensor>("IntermediateOut");
       outputs.emplace_back(intermediate_out);
     } else {
@@ -411,13 +418,18 @@ class FusedElemwiseActivationGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
     auto in_y = ctx.Input<framework::Tensor>("Y");
-    PADDLE_ENFORCE(in_y != nullptr, "Input(Y) should not be nullptr.");
+    PADDLE_ENFORCE_NE(in_y, nullptr, platform::errors::InvalidArgument(
+                                         "Input(Y) should not be nullptr."));
     auto in_out = ctx.Input<framework::Tensor>("Out");
-    PADDLE_ENFORCE(in_out != nullptr, "Input(Out) should not be nullptr.");
+    PADDLE_ENFORCE_NE(
+        in_out, nullptr,
+        platform::errors::InvalidArgument("Input(Out) should not be nullptr."));
     auto in_out_grad =
         ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
-    PADDLE_ENFORCE(in_out_grad != nullptr,
-                   "Input(Out@Grad) should not be nullptr.");
+    PADDLE_ENFORCE_NE(in_out_grad, nullptr,
+                      platform::errors::InvalidArgument(
+                          "Input(Out@Grad) should not be nullptr."));
 
     framework::Tensor *in_x =
         const_cast<framework::Tensor *>(ctx.Input<framework::Tensor>("X"));
     framework::Tensor *x_grad =
@@ -437,24 +449,28 @@ class FusedElemwiseActivationGradKernel : public framework::OpKernel<T> {
       // recompute.
       in_intermediate_out = const_cast<framework::Tensor *>(
           ctx.Input<framework::Tensor>("IntermediateOut"));
-      PADDLE_ENFORCE(in_intermediate_out != nullptr,
-                     "The option of 'save_intermediate_out' is opened, "
-                     "so the number of 'Out' should be two.");
+      PADDLE_ENFORCE_NE(in_intermediate_out, nullptr,
+                        platform::errors::InvalidArgument(
+                            "The option of 'save_intermediate_out' is opened,"
+                            " so the number of 'Out' should be two."));
     } else {
       if (!InputXCanBeAbsent(functor_list)) {
-        PADDLE_ENFORCE(in_x != nullptr, "Input(X) should not be null.");
+        PADDLE_ENFORCE_NE(in_x, nullptr, platform::errors::InvalidArgument(
+                                             "Input(X) should not be null."));
       }
     }
 
     // Get in_x
     if (ctx.HasInput("X")) {
-      PADDLE_ENFORCE(in_x != nullptr, "Input(X) should not be nullptr.");
+      PADDLE_ENFORCE_NE(in_x, nullptr, platform::errors::InvalidArgument(
+                                           "Input(X) should not be null."));
     } else {
       // If functor_list contains elementwise_add, the backward doesn't use
       // in_x, in_y and in_out.
-      PADDLE_ENFORCE(InputXCanBeAbsent(functor_list),
-                     "Only when the compoundfunctor contains "
-                     "elementwise_add_grad, the 'X' could be absent.");
+      PADDLE_ENFORCE_EQ(InputXCanBeAbsent(functor_list), true,
+                        platform::errors::InvalidArgument(
+                            "Only when the compoundfunctor contains "
+                            "elementwise_add_grad, the 'X' could be absent."));
       in_x = const_cast<framework::Tensor *>(in_out_grad);
     }
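
Note: every hunk above applies the same migration. The deprecated string-only forms PADDLE_THROW(msg, ...) and PADDLE_ENFORCE(cond, msg, ...) are replaced by the typed forms PADDLE_THROW(platform::errors::InvalidArgument(...)), PADDLE_ENFORCE_EQ(cond, true, ...), and PADDLE_ENFORCE_NE(ptr, nullptr, ...), so each failure carries a structured error category rather than a bare format string. The listing below is a minimal self-contained sketch of that shape; the InvalidArgument struct and ENFORCE_EQ macro are simplified stand-ins written for this note, not Paddle's actual implementations.

#include <cstdio>
#include <stdexcept>
#include <string>

// Simplified stand-in for platform::errors::InvalidArgument: a typed error
// object that formats a printf-style message on construction.
struct InvalidArgument {
  std::string msg;
  template <typename... Args>
  explicit InvalidArgument(const char *fmt, Args... args) {
    char buf[512];
    std::snprintf(buf, sizeof(buf), fmt, args...);
    msg = buf;
  }
};

// Simplified stand-in for PADDLE_ENFORCE_EQ(lhs, rhs, error): compare the
// two values and throw the typed error's message on mismatch.
#define ENFORCE_EQ(lhs, rhs, error)                            \
  do {                                                         \
    if ((lhs) != (rhs)) throw std::runtime_error((error).msg); \
  } while (0)

int main() {
  const char *funcs_str = "elementwise_mul,tanh";
  try {
    // Mirrors the shape of the checks introduced in the hunks above.
    ENFORCE_EQ(false, true,
               InvalidArgument("%s has not been implemented.", funcs_str));
  } catch (const std::runtime_error &e) {
    std::printf("caught: %s\n", e.what());
  }
  return 0;
}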