|
|
|
@ -42,9 +42,9 @@ class SigmoidKernel : public framework::OpKernel {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T>
|
|
|
|
|
class SigmoidGradKernel : public OpKernel {
|
|
|
|
|
class SigmoidGradKernel : public framework::OpKernel {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const ExecutionContext& context) const override {
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
auto Y_t = context.Input<Tensor>("Y");
|
|
|
|
|
auto dY_t = context.Input<Tensor>(framework::GradVarName("Y"));
|
|
|
|
|
auto dX_t = context.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|