|
|
|
@ -111,7 +111,8 @@ class SequencePoolGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
in_g->mutable_data<T>(context.GetPlace());
|
|
|
|
|
if (strategy == LAST || strategy == FIRST) {
|
|
|
|
|
// set X@Grad be zero at first when strategy is LAST/FIRST
|
|
|
|
|
math::SetConstant<Place, T>(context.device_context(), in_g, 0);
|
|
|
|
|
math::SetConstant<Place, T> functor;
|
|
|
|
|
functor(context.device_context(), in_g, 0);
|
|
|
|
|
}
|
|
|
|
|
auto place = context.GetEigenDevice<Place>();
|
|
|
|
|
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
|
|
|
|
|