del_some_in_makelist
Yu Yang 7 years ago committed by GitHub
parent 9592468609
commit 9f44af9d7c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -762,7 +762,7 @@ def sequence_conv(input,
helper = LayerHelper('sequence_conv', **locals())
dtype = helper.input_dtype()
filter_shape = [filter_size * input.shape[1], num_filters]
filter = helper.create_parameter(
filter_param = helper.create_parameter(
attr=helper.param_attr, shape=filter_shape, dtype=dtype)
pre_bias = helper.create_tmp_variable(dtype)
@ -770,7 +770,7 @@ def sequence_conv(input,
type='sequence_conv',
inputs={
'X': [input],
'Filter': [filter],
'Filter': [filter_param],
},
outputs={"Out": pre_bias},
attrs={
@ -785,7 +785,7 @@ def sequence_conv(input,
def conv2d(input,
num_filters,
filter_size,
stride=[1, 1],
stride=None,
padding=None,
groups=None,
param_attr=None,
@ -802,6 +802,8 @@ def conv2d(input,
conv-2d output, if mentioned in the input parameters.
"""
if stride is None:
stride = [1, 1]
helper = LayerHelper('conv2d', **locals())
dtype = helper.input_dtype()
@ -827,7 +829,7 @@ def conv2d(input,
std = (2.0 / (filter_size[0]**2 * num_channels))**0.5
return Normal(0.0, std, 0)
filter = helper.create_parameter(
filter_param = helper.create_parameter(
attr=helper.param_attr,
shape=filter_shape,
dtype=dtype,
@ -839,7 +841,7 @@ def conv2d(input,
type='conv2d_cudnn',
inputs={
'Input': input,
'Filter': filter,
'Filter': filter_param,
},
outputs={"Output": pre_bias},
attrs={'strides': stride,
@ -875,8 +877,8 @@ def sequence_pool(input, pool_type, **kwargs):
def pool2d(input,
pool_size,
pool_type,
pool_stride=[1, 1],
pool_padding=[0, 0],
pool_stride=None,
pool_padding=None,
global_pooling=False,
main_program=None,
startup_program=None):
@ -884,6 +886,10 @@ def pool2d(input,
This function adds the operator for pooling in 2 dimensions, using the
pooling configurations mentioned in input parameters.
"""
if pool_padding is None:
pool_padding = [0, 0]
if pool_stride is None:
pool_stride = [1, 1]
if pool_type not in ["max", "avg"]:
raise ValueError(
"Unknown pool_type: '%s'. It can only be 'max' or 'avg'.",

Loading…
Cancel
Save