Merge pull request #5555 from emailweixu/fix_sequence_pool

Fix sequence_pool_op in debug mode
mobile_baidu
emailweixu 7 years ago committed by GitHub
commit 05c0908492
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -126,6 +126,7 @@ class SequencePoolGradKernel : public framework::OpKernel<T> {
int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
auto in_g_e = EigenMatrix<T>::From(in_g_t, {h, w});
auto out_g_e = EigenMatrix<T>::From(out_g_t, {1, w});
auto out_g_e_v = EigenVector<T>::Flatten(out_g_t);
Eigen::DSizes<int, 2> bcast(h, 1);
if (pooltype == "AVERAGE") {
@ -136,9 +137,9 @@ class SequencePoolGradKernel : public framework::OpKernel<T> {
in_g_e.device(place) =
(out_g_e / std::sqrt(static_cast<T>(h))).broadcast(bcast);
} else if (pooltype == "LAST") {
in_g_e.chip(h - 1, 0).device(place) = out_g_e;
in_g_e.chip(h - 1, 0).device(place) = out_g_e_v;
} else if (pooltype == "FIRST") {
in_g_e.chip(0, 0).device(place) = out_g_e;
in_g_e.chip(0, 0).device(place) = out_g_e_v;
} else {
PADDLE_THROW("unsupported pooling pooltype");
}

Loading…
Cancel
Save