@ -759,6 +759,11 @@ def sequence_pool(input, pool_type, **kwargs):
"MaxIndex": max_index},
attrs={"pooltype": pool_type.upper()})
# when pool_type is max, variable max_index is initialized,
# so we stop the gradient explicitly here
if pool_type == 'max':
max_index.stop_gradient = True
return pool_out