Big data op_test benchmark, for checking output consistency across different runs. (#10646)
* "init benchmark ops" * "untrack outputs" * "delete some usused code" * "benchmark" * "fix ci" * "fix op test" * "fix uint16 missing" * "fix ci" * "follow comments" * "fix ci" * "follow comments" * "conficts. merge develop branch" * repick * "merge develop branch"wangkuiyi-patch-1
parent 3ff9ba0e6b
commit f7c96f079b
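At its core, the change checks that repeated runs of an operator on different devices produce the same result: the CPU output is taken as the baseline and the GPU output must match it element-wise within a small absolute tolerance (1e-8 in the demo test below). A standalone sketch of that comparison, using made-up arrays in place of real operator outputs:

    import numpy as np

    cpu_out = np.random.random((300, 400)).astype('float32')   # baseline result
    gpu_out = cpu_out.copy()                                    # result under test

    # the same element-wise check BenchmarkSuite applies to each fetched variable
    assert np.allclose(gpu_out, cpu_out, atol=1e-8), "Output has diff"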
@@ -0,0 +1,113 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import unittest
import time
import itertools

import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from op_test import OpTest


class BenchmarkSuite(OpTest):
    def timeit_function(self, callback, iters, *args, **kwargs):
        assert iters >= 1, "iters must be >= 1"
        start = time.time()
        for i in range(iters):
            callback(*args, **kwargs)
        elapse = time.time() - start
        return elapse / iters

    def _assert_cpu_gpu_same(self, cpu_outs, gpu_outs, fetch_list, atol):
        for item_cpu_out, item_gpu_out, variable in zip(cpu_outs, gpu_outs,
                                                        fetch_list):
            # The CPU output is the baseline; the GPU output is expected to
            # stay consistent with it.
            expect = item_cpu_out
            expect_t = np.array(item_cpu_out)
            actual = item_gpu_out
            actual_t = np.array(item_gpu_out)
            var_name = variable if isinstance(variable,
                                              basestring) else variable.name
            self.assertTrue(
                np.allclose(
                    actual_t, expect_t, atol=atol),
                "Output (" + var_name + ") has diff " + str(actual_t) + "\n" +
                str(expect_t))
            self.assertListEqual(actual.lod(),
                                 expect.lod(),
                                 "Output (" + var_name + ") has different lod")

    def _get_input_names(self):
        inputs = []
        for name, value in self.inputs.iteritems():
            if isinstance(value, list):
                inputs.extend([sub_name for sub_name, _ in value])
            inputs.append(name)
        return inputs

    def _get_output_names(self):
        outputs = []
        for var_name, var in self.outputs.iteritems():
            if isinstance(var, list):
                for sub_var_name, sub_var in var:
                    outputs.append(sub_var_name)
            else:
                outputs.append(var_name)
        if len(outputs) == 0:
            for out_name, out_dup in Operator.get_op_outputs(self.op_type):
                outputs.append(str(out_name))
        return outputs

    def check_output_stability(self, atol=1e-8):
        places = self._get_places()
        if len(places) < 2:
            return
        cpu_outs, fetch_list = self._calc_output(places[0])
        gpu_outs, _ = self._calc_output(places[1])
        self._assert_cpu_gpu_same(cpu_outs, gpu_outs, fetch_list, atol)

    def timeit_output_with_place(self, place, iters):
        return self.timeit_function(self.calc_output, iters, place)

    def timeit_output(self, iters=100):
        places = self._get_places()
        elapses = []
        for place in places:
            elapses.append(self.timeit_output_with_place(place, iters))
        for place, elapse in zip(places, elapses):
            print("One pass of ({2}_op) at {0} cost {1}".format(
                str(place), elapse, self.op_type))

    def timeit_grad_with_place(self, place, iters=100):
        inputs_to_check = self._get_input_names()
        output_names = self._get_output_names()
        return self.timeit_function(
            self._get_gradient,
            iters,
            inputs_to_check,
            place,
            output_names,
            no_grad_set=None)

    def timeit_grad(self, iters=100):
        places = self._get_places()
        elapses = []
        for place in places:
            elapses.append(self.timeit_grad_with_place(place, iters))
        for place, elapse in zip(places, elapses):
            print("One pass of ({2}_grad_op) at {0} cost {1}".format(
                str(place), elapse, self.op_type))
@@ -0,0 +1,82 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np

import paddle.fluid as fluid
from benchmark import BenchmarkSuite
from op_test import OpTest

# This is a demo op test case for operator benchmarking and high-resolution
# numerical stability alignment.


class TestSumOp(BenchmarkSuite):
    def setUp(self):
        self.op_type = "sum"
        self.customize_testcase()
        self.customize_fetch_list()

    def customize_fetch_list(self):
        """
        Customize the fetch list to select the variables to compare.
        >>> self.fetch_list = ["Out"]
        """
        self.fetch_list = ["Out"]

    def customize_testcase(self):
        # a test case
        x0 = np.random.random((300, 400)).astype('float32')
        x1 = np.random.random((300, 400)).astype('float32')
        x2 = np.random.random((300, 400)).astype('float32')

        # NOTE: if the output is left empty, it will be auto-filled by
        # BenchmarkSuite. Only the output dtype is used; the shape, lod and
        # data are computed from the input.
        self.inputs = {"X": [("x0", x0), ("x1", x1), ("x2", x2)]}
        self.outputs = {"Out": x0 + x1 + x2}

    def test_check_output(self):
        """
        Compare the output with the customized output. In this case,
        the correct output has to be set by hand.
        >>> self.outputs = {"Out": x0 + x1 + x2}
        """
        self.check_output(atol=1e-8)

    def test_output_stability(self):
        # compare the CPU and GPU outputs at high resolution.
        self.check_output_stability()

    def test_timeit_output(self):
        """
        Benchmark the forward op; the time cost is averaged over `iters`.
        Output example:
        >>> One pass of (sum_op) at CPUPlace cost 0.000461330413818
        >>> One pass of (sum_op) at CUDAPlace(0) cost 0.000556070804596
        """
        self.timeit_output(iters=100)

    def test_timeit_grad(self):
        """
        Benchmark the op gradient; the time cost is averaged over `iters`.
        Output example:
        >>> One pass of (sum_grad_op) at CPUPlace cost 0.00279935121536
        >>> One pass of (sum_grad_op) at CUDAPlace(0) cost 0.00500632047653
        """
        self.timeit_grad(iters=100)


if __name__ == "__main__":
    unittest.main()
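The demo runs as an ordinary unittest module; the suite can also be exercised programmatically without the test runner. A minimal sketch (the method name and iteration count are only illustrative):

    suite = TestSumOp("test_timeit_output")   # pick any test method by name
    suite.setUp()
    suite.timeit_output(iters=10)             # a quick, lower-iteration run
    suite.timeit_grad(iters=10)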
File diff suppressed because it is too large.
@@ -0,0 +1,182 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import paddle.fluid.core as core
from paddle.fluid.op import Operator


def as_lodtensor(np_array, lod, place):
    tensor = core.LoDTensor()
    tensor.set(np_array, place)
    if lod is not None:
        tensor.set_lod(lod)
    return tensor


def create_op(scope, op_type, inputs, outputs, attrs):
    kwargs = dict()

    op_maker = core.op_proto_and_checker_maker
    op_role_attr_name = op_maker.kOpRoleAttrName()

    if op_role_attr_name not in attrs:
        attrs[op_role_attr_name] = int(op_maker.OpRole.Forward)

    def __create_var__(name, var_name):
        scope.var(var_name).get_tensor()
        kwargs[name].append(var_name)

    for in_name, in_dup in Operator.get_op_inputs(op_type):
        if in_name in inputs:
            kwargs[in_name] = []
            if in_dup:
                sub_in = inputs[in_name]
                for item in sub_in:
                    sub_in_name, _ = item[0], item[1]
                    __create_var__(in_name, sub_in_name)
            else:
                __create_var__(in_name, in_name)

    for out_name, out_dup in Operator.get_op_outputs(op_type):
        if out_name in outputs:
            kwargs[out_name] = []
            if out_dup:
                sub_out = outputs[out_name]
                for item in sub_out:
                    sub_out_name, _ = item[0], item[1]
                    __create_var__(out_name, sub_out_name)
            else:
                __create_var__(out_name, out_name)

    for attr_name in Operator.get_op_attr_names(op_type):
        if attr_name in attrs:
            kwargs[attr_name] = attrs[attr_name]

    return Operator(op_type, **kwargs)


def set_input(scope, op, inputs, place):
    def __set_input__(var_name, var):
        if isinstance(var, tuple) or isinstance(var, np.ndarray):
            tensor = scope.find_var(var_name).get_tensor()
            if isinstance(var, tuple):
                tensor.set_lod(var[1])
                var = var[0]
            tensor.set_dims(var.shape)
            tensor.set(var, place)
        elif isinstance(var, float):
            scope.find_var(var_name).set_float(var)
        elif isinstance(var, int):
            scope.find_var(var_name).set_int(var)

    for in_name, in_dup in Operator.get_op_inputs(op.type()):
        if in_name in inputs:
            if in_dup:
                sub_in = inputs[in_name]
                for item in sub_in:
                    sub_in_name, sub_in_val = item[0], item[1]
                    __set_input__(sub_in_name, sub_in_val)
            else:
                __set_input__(in_name, inputs[in_name])


def append_input_output(block, op_proto, np_list, is_input, dtype):
    '''Insert VarDesc and generate Python variable instance'''
    proto_list = op_proto.inputs if is_input else op_proto.outputs

    def create_var(block, name, np_list, var_proto):
        dtype = None
        shape = None
        lod_level = None
        if name not in np_list:
            assert var_proto.intermediate, "{} not found".format(name)
        else:
            np_value = np_list[name]
            if isinstance(np_value, tuple):
                dtype = np_value[0].dtype
                # the output shape and lod should be inferred from the input.
                if is_input:
                    shape = list(np_value[0].shape)
                    lod_level = len(np_value[1])
            else:
                dtype = np_value.dtype
                if is_input:
                    shape = list(np_value.shape)
                    lod_level = 0
        return block.create_var(
            dtype=dtype, shape=shape, lod_level=lod_level, name=name)

    var_dict = {}
    for var_proto in proto_list:
        var_name = str(var_proto.name)
        if is_input:
            if (var_name not in np_list) and var_proto.dispensable:
                continue
            assert (var_name in np_list) or (var_proto.dispensable), \
                "Missing {} as input".format(var_name)
        if var_proto.duplicable:
            assert isinstance(np_list[var_name], list), \
                "Duplicable {} should be set as list".format(var_name)
            var_list = []
            for (name, np_value) in np_list[var_name]:
                var_list.append(
                    create_var(block, name, {name: np_value}, var_proto))
            var_dict[var_name] = var_list
        else:
            var_dict[var_name] = create_var(block, var_name, np_list, var_proto)

    return var_dict


def append_loss_ops(block, output_names):
    mean_inputs = map(block.var, output_names)

    if len(mean_inputs) == 1:
        loss = block.create_var(dtype=mean_inputs[0].dtype, shape=[1])
        op = block.append_op(
            inputs={"X": mean_inputs}, outputs={"Out": loss}, type='mean')
        op.desc.infer_var_type(block.desc)
        op.desc.infer_shape(block.desc)
    else:
        avg_sum = []
        for cur_loss in mean_inputs:
            cur_avg_loss = block.create_var(dtype=cur_loss.dtype, shape=[1])
            op = block.append_op(
                inputs={"X": [cur_loss]},
                outputs={"Out": [cur_avg_loss]},
                type="mean")
            op.desc.infer_var_type(block.desc)
            op.desc.infer_shape(block.desc)
            avg_sum.append(cur_avg_loss)

        loss_sum = block.create_var(dtype=avg_sum[0].dtype, shape=[1])
        op_sum = block.append_op(
            inputs={"X": avg_sum}, outputs={"Out": loss_sum}, type='sum')
        op_sum.desc.infer_var_type(block.desc)
        op_sum.desc.infer_shape(block.desc)

        loss = block.create_var(dtype=loss_sum.dtype, shape=[1])
        op_loss = block.append_op(
            inputs={"X": loss_sum},
            outputs={"Out": loss},
            type='scale',
            attrs={'scale': 1.0 / float(len(avg_sum))})
        op_loss.desc.infer_var_type(block.desc)
        op_loss.desc.infer_shape(block.desc)
    return loss