modify datanorm op test=develop (#23030)

Branch: revert-22778-infer_var_type
yaoxuefeng authored 5 years ago, committed by GitHub
parent 3e1676fa9a
commit 5b69242fab

@@ -53,7 +53,9 @@ const std::unordered_set<std::string> op_has_unsed_vars_white_list = {
     "precision_recall",           // 1
     "fusion_seqpool_cvm_concat",  // 2
     "fused_batch_norm_act",       // 2
-    "fused_batch_norm_act_grad"   // 2
+    "fused_batch_norm_act_grad",  // 2
+    "data_norm",                  // 0
+    "data_norm_grad",             // 0
 };
 
 namespace paddle {

File diff suppressed because it is too large.

@@ -3157,7 +3157,8 @@ def data_norm(input,
               do_model_average_for_mean_and_var=True,
               slot_dim=-1,
               sync_stats=False,
-              summary_decay_rate=0.9999999):
+              summary_decay_rate=0.9999999,
+              enable_scale_and_shift=False):
     """
     **Data Normalization Layer**
@@ -3206,6 +3207,7 @@ def data_norm(input,
         sync_stats(bool, Default False): When running with multiple GPU cards, using allreduce to sync the
             summary messages.
         summary_decay_rate(float, Default 0.9999999): The decay rate when updating summary.
+        enable_scale_and_shift(bool, Default False): Whether to do scale and shift after normalization.
 
     Returns:
         Variable: A tensor variable which is the result after applying data normalization on the input.
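For context, a minimal usage sketch of the extended Python API (the feature size and initializer values here are illustrative, not taken from the patch):

import paddle.fluid as fluid

# Hypothetical example: a 50-feature dense input normalized by data_norm,
# with the new learnable scale/shift applied on top of the normalized values.
x = fluid.layers.data(name='x', shape=[50], dtype='float32')
y = fluid.layers.data_norm(
    input=x,
    param_attr={"scale_w": 1.0, "bias": 0.0},
    enable_scale_and_shift=True)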
@@ -3236,12 +3238,35 @@ def data_norm(input,
     batch_size_default = 1e4
     batch_sum_default = 0.0
     batch_square_sum_default = 1e4
+    scale_w_default = 1.0
+    bias_default = 0.0
 
     if param_attr and isinstance(param_attr, dict):
         batch_size_default = param_attr.get("batch_size", 1e4)
         batch_sum_default = param_attr.get("batch_sum", 0.0)
         batch_square_sum_default = param_attr.get("batch_square", 1e4)
+        if enable_scale_and_shift:
+            scale_w_default = param_attr.get("scale_w", 1.0)
+            bias_default = param_attr.get("bias", 0.0)
+
+    # create scale and shift (bias) when enable_scale_and_shift is True
+    if name is None:
+        name = "dn"
+    if enable_scale_and_shift:
+        scale_w = helper.create_parameter(
+            attr=ParamAttr(
+                name=name + '.scale_w',
+                initializer=Constant(value=float(scale_w_default)),
+                trainable=True),
+            shape=param_shape,
+            dtype=input.dtype)
+        bias = helper.create_parameter(
+            attr=ParamAttr(
+                name=name + '.bias',
+                initializer=Constant(value=float(bias_default)),
+                trainable=True),
+            shape=param_shape,
+            dtype=input.dtype)
 
     # create parameter
     batch_size = helper.create_parameter(
         attr=ParamAttr(
@@ -3272,14 +3297,18 @@
     data_norm_out = input if in_place else helper.create_variable(dtype=dtype)
 
-    helper.append_op(
-        type="data_norm",
-        inputs={
-            "X": input,
-            "BatchSize": batch_size,
-            "BatchSum": batch_sum,
-            "BatchSquareSum": batch_square_sum
-        },
+    inputs = {
+        "X": input,
+        "BatchSize": batch_size,
+        "BatchSum": batch_sum,
+        "BatchSquareSum": batch_square_sum
+    }
+    if enable_scale_and_shift:
+        inputs["scale_w"] = scale_w
+        inputs["bias"] = bias
+    helper.append_op(
+        type="data_norm",
+        inputs=inputs,
         outputs={
             "Y": data_norm_out,
             "Means": means,
@@ -3292,7 +3321,8 @@
             "epsilon": epsilon,
             "slot_dim": slot_dim,
             "sync_stats": sync_stats,
-            "summary_decay_rate": summary_decay_rate
+            "summary_decay_rate": summary_decay_rate,
+            "enable_scale_and_shift": enable_scale_and_shift
         })
     return helper.append_activation(data_norm_out)
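Numerically, the op keeps per-feature batch statistics and, with the new flag, applies a learnable affine transform after normalization. A NumPy sketch of the intended forward computation, assuming the statistics combine as mean = BatchSum / BatchSize and scale = sqrt(BatchSize / BatchSquareSum) (the convention the tests below rely on):

import numpy as np

def data_norm_forward(x, batch_size, batch_sum, batch_square_sum,
                      scale_w=None, bias=None):
    # Per-feature running statistics maintained by the op.
    mean = batch_sum / batch_size
    scale = np.sqrt(batch_size / batch_square_sum)
    y = (x - mean) * scale
    # The enable_scale_and_shift=True path adds a learnable affine step.
    if scale_w is not None:
        y = y * scale_w + bias
    return y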

@@ -24,6 +24,7 @@ import paddle.fluid.layers as layers
 import os
 from op_test import OpTest
 from paddle.fluid.framework import grad_var_name
+from paddle.fluid import Program, program_guard
 
 
 def _reference_testing(x, batch_size, batch_sum, batch_square_sum, slot_dim=-1):
@@ -72,7 +73,13 @@ class TestDataNormOpInference(unittest.TestCase):
     def __assert_close(self, tensor, np_array, msg, atol=1e-4):
         self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
 
-    def check_with_place(self, place, data_layout, dtype, shape, slot_dim=-1):
+    def check_with_place(self,
+                         place,
+                         data_layout,
+                         dtype,
+                         shape,
+                         slot_dim=-1,
+                         enable_scale_and_shift=False):
         """
         do forward and check
@@ -82,7 +89,7 @@ class TestDataNormOpInference(unittest.TestCase):
             dtype(dtype): np.float32
             shape(list): input shape
             slot_dim(int): dimension of one slot. Refer to data_norm api.
+            enable_scale_and_shift(bool): whether to apply scale and shift after normalization.
         """
 
         epsilon = 0.00001
@@ -127,6 +134,7 @@ class TestDataNormOpInference(unittest.TestCase):
         mean_tensor = create_or_get_tensor(scope, "mean", None, place)
         scales_tensor = create_or_get_tensor(scope, "scales", None, place)
 
+        if not enable_scale_and_shift:
             data_norm_op = Operator(
                 "data_norm",
                 # inputs
@@ -141,7 +149,34 @@ class TestDataNormOpInference(unittest.TestCase):
                 # attrs
                 epsilon=epsilon,
                 use_mkldnn=self.use_mkldnn,
-                slot_dim=slot_dim)
+                slot_dim=slot_dim,
+                enable_scale_and_shift=False)
+        else:
+            scale_w = np.ones(scale_shape).astype(np.float32)
+            bias = np.zeros(scale_shape).astype(np.float32)
+            scale_w_tensor = create_or_get_tensor(
+                scope, "scale_w",
+                OpTest.np_dtype_to_fluid_dtype(scale_w), place)
+            bias_tensor = create_or_get_tensor(
+                scope, "bias", OpTest.np_dtype_to_fluid_dtype(bias), place)
+            data_norm_op = Operator(
+                "data_norm",
+                # inputs
+                X="x_val",
+                BatchSize="batch_size",
+                BatchSum="batch_sum",
+                BatchSquareSum="batch_square_sum",
+                scale_w="scale_w",
+                bias="bias",
+                # outputs
+                Y="y_out",
+                Means="mean",
+                Scales="scales",
+                # attrs
+                epsilon=epsilon,
+                use_mkldnn=self.use_mkldnn,
+                slot_dim=slot_dim,
+                enable_scale_and_shift=True)
 
         data_norm_op.run(scope, place)
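Because the inference test initializes scale_w to ones and bias to zeros, the scale-and-shift branch is numerically the identity, so both code paths can be validated against the same _reference_testing output. A quick sanity check of that identity:

import numpy as np

y_ref = np.random.uniform(-1, 1, (2, 3)).astype(np.float32)
scale_w = np.ones(3, dtype=np.float32)
bias = np.zeros(3, dtype=np.float32)
# y * 1 + 0 == y, so the affine step does not change the expected output.
assert np.allclose(y_ref * scale_w + bias, y_ref)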
@@ -162,11 +197,13 @@ class TestDataNormOpInference(unittest.TestCase):
         for place in places:
             for data_format in ["NCHW", "NHWC"]:
                 for slot_dim in [-1, 1]:
+                    for enable_scale_and_shift in [False, True]:
                         self.check_with_place(
                             place,
                             data_format,
                             self.dtype, [2, 3],
-                            slot_dim=slot_dim)
+                            slot_dim=slot_dim,
+                            enable_scale_and_shift=enable_scale_and_shift)
 
 
 class TestDataNormOp(OpTest):
@@ -220,6 +257,130 @@ class TestDataNormOp(OpTest):
         self.check_grad(['X'], 'Y', no_grad_set=set([]))
 
 
+class TestDataNormOpWithEnableScaleAndShift(OpTest):
+    """
+    test class for data norm op
+    test forward and backward
+    """
+
+    def setUp(self):
+        """
+        init data norm op test env
+        """
+        self.op_type = 'data_norm'
+        self.use_mkldnn = False
+        epsilon = 0.00001
+        slot_dim = -1
+        enable_scale_and_shift = True
+        x_shape = [2, 50]
+        scale_shape = [50]
+        tp = np.float32
+        x_val = np.random.uniform(-1, 1, x_shape).astype(tp)
+        batch_size = np.ones(scale_shape).astype(tp)
+        batch_size *= 1e4
+        batch_sum = np.zeros(scale_shape).astype(tp)
+        batch_square_sum = np.ones(scale_shape).astype(tp)
+        batch_square_sum *= 1e4
+        scale_w = np.ones(scale_shape).astype(tp)
+        bias = np.zeros(scale_shape).astype(tp)
+
+        y = np.array(x_val)
+
+        mean = np.zeros(x_shape).astype(tp)
+        scale = np.ones(x_shape).astype(tp)
+
+        self.inputs = {
+            "X": x_val,
+            "BatchSize": batch_size,
+            "BatchSum": batch_sum,
+            "BatchSquareSum": batch_square_sum,
+            "scale_w": scale_w,
+            "bias": bias
+        }
+        self.outputs = {"Y": y, "Means": mean, "Scales": scale}
+        self.attrs = {
+            "epsilon": epsilon,
+            "use_mkldnn": self.use_mkldnn,
+            "slot_dim": slot_dim,
+            "enable_scale_and_shift": True
+        }
+
+    def test_check_output(self):
+        """
+        test check forward, check output
+        """
+        self.check_output()
+
+    def test_check_grad(self):
+        """
+        test check backward, check grad
+        """
+        self.check_grad(['X'], 'Y', no_grad_set=set([]))
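The fixture values above make the expected outputs easy to verify by hand: with BatchSize = 1e4, BatchSum = 0 and BatchSquareSum = 1e4, the running statistics reduce to mean 0 and scale 1, so Y equals X, Means is all zeros and Scales is all ones (under the statistics convention sketched earlier):

import numpy as np

batch_size = np.full(50, 1e4, dtype=np.float32)
batch_sum = np.zeros(50, dtype=np.float32)
batch_square_sum = np.full(50, 1e4, dtype=np.float32)

mean = batch_sum / batch_size                    # all zeros
scale = np.sqrt(batch_size / batch_square_sum)   # all ones
assert np.all(mean == 0.0) and np.all(scale == 1.0)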
+class TestDataNormOpWithEnableScaleAndShift_1(OpTest):
+    """
+    test class for data norm op
+    test forward and backward
+    """
+
+    def setUp(self):
+        """
+        init data norm op test env
+        """
+        self.op_type = 'data_norm'
+        self.use_mkldnn = False
+        epsilon = 0.00001
+        slot_dim = 1
+        enable_scale_and_shift = True
+        x_shape = [2, 50]
+        scale_shape = [50]
+        tp = np.float32
+        x_val = np.random.uniform(-1, 1, x_shape).astype(tp)
+        batch_size = np.ones(scale_shape).astype(tp)
+        batch_size *= 1e4
+        batch_sum = np.zeros(scale_shape).astype(tp)
+        batch_square_sum = np.ones(scale_shape).astype(tp)
+        batch_square_sum *= 1e4
+        scale_w = np.ones(scale_shape).astype(tp)
+        bias = np.zeros(scale_shape).astype(tp)
+
+        y = np.array(x_val)
+
+        mean = np.zeros(x_shape).astype(tp)
+        scale = np.ones(x_shape).astype(tp)
+
+        self.inputs = {
+            "X": x_val,
+            "BatchSize": batch_size,
+            "BatchSum": batch_sum,
+            "BatchSquareSum": batch_square_sum,
+            "scale_w": scale_w,
+            "bias": bias
+        }
+        self.outputs = {"Y": y, "Means": mean, "Scales": scale}
+        self.attrs = {
+            "epsilon": epsilon,
+            "use_mkldnn": self.use_mkldnn,
+            "slot_dim": slot_dim,
+            "enable_scale_and_shift": True
+        }
+
+    def test_check_output(self):
+        """
+        test check forward, check output
+        """
+        self.check_output()
+
+    def test_check_grad(self):
+        """
+        test check backward, check grad
+        """
+        self.check_grad(['X'], 'Y', no_grad_set=set([]))
 
 
 class TestDataNormOpWithSlotDim(OpTest):
     """
     test class for data norm op
@@ -399,5 +560,14 @@ class TestDataNormOpWithSyncStats(unittest.TestCase):
             os.remove(f)
 
 
+class TestDataNormOpError(unittest.TestCase):
+    def test_errors(self):
+        with program_guard(Program(), Program()):
+            x2 = fluid.layers.data(name='x2', shape=[3, 4], dtype="int32")
+            #self.assertRaises(TypeError, fluid.data_norm, x2)
+            fluid.layers.data_norm(
+                input=x2, param_attr={}, enable_scale_and_shift=True)
+
+
 if __name__ == '__main__':
     unittest.main()
