Fix img_pool_layer bug.

enforce_failed
wanghaoshuang 8 years ago
parent 06fad3fe9d
commit 818a64f41f

@ -2607,15 +2607,15 @@ def img_pool_layer(input,
assert input.num_filters is not None
num_channels = input.num_filters
assert type(pool_type) in [AvgPooling, MaxPooling, CudnnAvgPooling,
CudnnMaxPooling], \
"only (Cudnn)AvgPooling, (Cudnn)MaxPooling are supported"
if pool_type is None:
pool_type = MaxPooling()
elif isinstance(pool_type, AvgPooling):
pool_type.name = 'avg'
assert type(pool_type) in [AvgPooling, MaxPooling, CudnnAvgPooling,
CudnnMaxPooling], \
"only (Cudnn)AvgPooling, (Cudnn)MaxPooling are supported"
type_name = pool_type.name + '-projection' \
if (
isinstance(pool_type, AvgPooling) or isinstance(pool_type, MaxPooling)) \

Loading…
Cancel
Save