You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/python/paddle/fluid/tests/unittests/utils.py

191 lines
5.9 KiB

# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid.framework import _dygraph_guard
import paddle.fluid as fluid
from paddle.fluid.framework import Variable
import paddle.fluid.dygraph.jit as jit
from paddle.fluid.dygraph.jit import extract_vars
import numpy as np
import os
import time
__all__ = ['DyGraphProgramDescTracerTestHelper', ]
def is_equal_program(prog1, prog2):
with _dygraph_guard(None):
return _is_equal_program(prog1, prog2)
def _is_equal_program(prog1, prog2):
block_num = prog1.num_blocks
if block_num != prog2.num_blocks:
return False
for block_id in range(block_num):
block1 = prog1.block(block_id)
block2 = prog2.block(block_id)
if len(block1.ops) != len(block2.ops):
return False
if len(block1.vars) != len(block2.vars):
return False
for op1, op2 in zip(block1.ops, block2.ops):
if op1.input_arg_names != op2.input_arg_names:
return False
if op1.output_arg_names != op2.output_arg_names:
return False
attr1 = op1.all_attrs()
attr2 = op2.all_attrs()
if len(attr1) != len(attr2):
return False
for key1, value1 in attr1.items():
if key1 not in attr2:
return False
if value1 != attr2.get(key1):
return False
for var1 in block1.vars.values():
if var1.name not in block2.vars:
return False
var2 = block2.vars.get(var1.name)
if var1.name != var2.name:
return False
if var1.type != var2.type:
return False
if var1.dtype != var2.dtype:
return False
if var1.lod_level != var2.lod_level:
return False
if var1.persistable != var2.persistable:
return False
return True
def load_dygraph_vars_to_scope(model_path, scope, place):
def load_dict_to_scope(scope, dictionary):
if scope is None:
scope = fluid.global_scope()
for k, v in dictionary.items():
dst_t = scope.var(k).get_tensor()
src_t = v.value().get_tensor()
dst_t.set(np.array(src_t), place)
dst_t.set_lod(src_t.lod())
param_dict, opti_dict = fluid.load_dygraph(model_path)
if param_dict:
load_dict_to_scope(scope, param_dict)
if opti_dict:
load_dict_to_scope(scope, opti_dict)
class DyGraphProgramDescTracerTestHelper(object):
def __init__(self,
module,
unittest_obj,
model_path=None,
scope=None,
place=None):
self.module = module
self.unittest_obj = unittest_obj
self.scope = fluid.Scope() if scope is None else scope
self.model_path = model_path
if model_path is None:
millis = int(round(time.time() * 1000))
self.model_path = "id_{}_{}".format(id(module), millis)
self.place = place
if place is None:
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
self.program = None
self.executor = fluid.Executor(self.place)
def _remove_model_path(self):
if os.path.exists(self.model_path + ".pdparams"):
os.remove(self.model_path + ".pdparams")
if os.path.exists(self.model_path + ".pdopt"):
os.remove(self.model_path + ".pdopt")
def _run_static_graph(self, inputs, feed_names, fetch_names):
var_list = extract_vars(inputs)
assert len(var_list) == len(feed_names)
feed_dict = {}
for name, var in zip(feed_names, var_list):
feed_dict[name] = np.array(var.value().get_tensor())
with fluid.scope_guard(self.scope):
with _dygraph_guard(None):
return self.executor.run(self.program,
feed=feed_dict,
fetch_list=fetch_names)
def run(self, inputs, feed_names, fetch_names):
out_dygraph, program = jit.trace(
self.module, inputs, feed_names=feed_names, fetch_names=fetch_names)
if self.program is not None:
self.unittest_obj.assertTrue(
is_equal_program(self.program, program))
self.program = program
fluid.save_dygraph(self.module.state_dict(), self.model_path)
load_dygraph_vars_to_scope(self.model_path, self.scope, self.place)
self._remove_model_path()
out_static_graph = self._run_static_graph(inputs, feed_names,
fetch_names)
if not isinstance(out_dygraph, (list, tuple)):
assert len(out_static_graph) == 1
out_static_graph = out_static_graph[0]
return out_dygraph, out_static_graph
def assertEachVar(self, out_dygraph, out_static_graph, func=None):
if func is None:
func = lambda x, y: np.array_equal(x, y)
if not isinstance(out_dygraph, (list, tuple)):
out_dygraph = [out_dygraph]
if not isinstance(out_static_graph, (list, tuple)):
out_static_graph = [out_static_graph]
for v1, v2 in zip(out_dygraph, out_static_graph):
self.unittest_obj.assertTrue(func(v1.numpy(), v2))