|
|
|
@ -713,7 +713,7 @@ def max_pool2d(x,
|
|
|
|
|
'data_format', data_format)
|
|
|
|
|
return output
|
|
|
|
|
|
|
|
|
|
op_type = 'max_pool2d_with_index' if data_format == "NCHW" else "max_pool2d"
|
|
|
|
|
op_type = 'max_pool2d_with_index' if data_format == "NCHW" else "pool2d"
|
|
|
|
|
helper = LayerHelper(op_type, **locals())
|
|
|
|
|
dtype = helper.input_dtype()
|
|
|
|
|
pool_out = helper.create_variable_for_type_inference(dtype)
|
|
|
|
@ -839,7 +839,7 @@ def max_pool3d(x,
|
|
|
|
|
'data_format', data_format)
|
|
|
|
|
return output
|
|
|
|
|
|
|
|
|
|
op_type = "max_pool3d_with_index" if data_format == "NCDHW" else "max_pool3d"
|
|
|
|
|
op_type = "max_pool3d_with_index" if data_format == "NCDHW" else "pool3d"
|
|
|
|
|
helper = LayerHelper(op_type, **locals())
|
|
|
|
|
dtype = helper.input_dtype()
|
|
|
|
|
pool_out = helper.create_variable_for_type_inference(dtype)
|
|
|
|
|