|
|
@ -759,6 +759,11 @@ def sequence_pool(input, pool_type, **kwargs):
|
|
|
|
"MaxIndex": max_index},
|
|
|
|
"MaxIndex": max_index},
|
|
|
|
attrs={"pooltype": pool_type.upper()})
|
|
|
|
attrs={"pooltype": pool_type.upper()})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# when pool_type is max, variable max_index is initialized,
|
|
|
|
|
|
|
|
# so we stop the gradient explicitly here
|
|
|
|
|
|
|
|
if pool_type == 'max':
|
|
|
|
|
|
|
|
max_index.stop_gradient = True
|
|
|
|
|
|
|
|
|
|
|
|
return pool_out
|
|
|
|
return pool_out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|