[Dy2stat] Fix PaddleGan Deoldify Model Dy2stat Problems (#29226)
This PR fixes several problems in dy2stat for Deoldify model in PaddleGan. In model, software engineer wrote if x.shape == y.shape, the Tenser shape is a tuple in dygraph so the == returns True/False, but in static graph the == becomes element-wise comparison, which is a different behavior. In this PR we reduce the element-wise comparison result. If software engineer write computations which uses parameters in hooks, the static graph can loss the parameter variable because we put param_guard at forward of a Layer. In this PR we made param_guard cover pre-hook and post-hook. In PaddleGan, software engineer calculated some parameter values in __init__ by running some dygraph code. Those code also run during dy2stat. So some variables may be assign as a VarBase (Tensor) first and then Variable, which raised an error. We fixed the bug in this PR by handling the case. TODO: We just added testcase for the 1. shape comparison. Should add test case for 2. and 3. But since we are chasing 2.0RC, I will do it in the near future PRrevert-31562-mean
parent
fc80d2e09c
commit
aec05d811c
@ -0,0 +1,107 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import unittest
|
||||
|
||||
|
||||
class TestConvertShapeCompare(unittest.TestCase):
|
||||
def test_non_variable(self):
|
||||
self.assertEqual(
|
||||
paddle.jit.dy2static.convert_shape_compare(1, "<", 2), True)
|
||||
self.assertEqual(
|
||||
paddle.jit.dy2static.convert_shape_compare(1, "<", 2, "<=", 3),
|
||||
True)
|
||||
self.assertEqual(
|
||||
paddle.jit.dy2static.convert_shape_compare(1, ">", 2, "<=", 3),
|
||||
False)
|
||||
|
||||
def error_func():
|
||||
"""
|
||||
Function used to test that comparison doesn't run after first False
|
||||
"""
|
||||
raise ValueError("Used for test")
|
||||
|
||||
self.assertEqual(
|
||||
paddle.jit.dy2static.convert_shape_compare(
|
||||
1, ">", 2, "<=", lambda: error_func()), False)
|
||||
|
||||
self.assertEqual(
|
||||
paddle.jit.dy2static.convert_shape_compare(1, "<", 2, "in",
|
||||
[1, 2, 3]), True)
|
||||
self.assertEqual(
|
||||
paddle.jit.dy2static.convert_shape_compare(1, "<", 2, "not in",
|
||||
[1, 2, 3]), False)
|
||||
self.assertEqual(
|
||||
paddle.jit.dy2static.convert_shape_compare(1, "<", 2, "is", 3),
|
||||
False)
|
||||
self.assertEqual(
|
||||
paddle.jit.dy2static.convert_shape_compare(1, "<", 2, "is not",
|
||||
[1, 2, 3]), True)
|
||||
|
||||
self.assertEqual(
|
||||
paddle.jit.dy2static.convert_shape_compare([1, 2], "==", [1, 2],
|
||||
"!=", [1, 2, 3]), True)
|
||||
self.assertEqual(
|
||||
paddle.jit.dy2static.convert_shape_compare([1, 2], "!=", [1, 2, 3],
|
||||
"==", [1, 2]), False)
|
||||
|
||||
def test_variable(self):
|
||||
paddle.enable_static()
|
||||
with paddle.static.program_guard(paddle.static.Program(),
|
||||
paddle.static.Program()):
|
||||
x = paddle.static.data(name='x', shape=[3, 2], dtype='float32')
|
||||
y = paddle.static.data(name='y', shape=[3, 2], dtype='float32')
|
||||
self.assertEqual(
|
||||
paddle.jit.dy2static.convert_shape_compare(x, "is", x, "is not",
|
||||
y), True)
|
||||
self.assertEqual(
|
||||
paddle.jit.dy2static.convert_shape_compare(x, "is not", x,
|
||||
"is not", y), False)
|
||||
self.assertEqual(
|
||||
paddle.jit.dy2static.convert_shape_compare(x, "is", x, "is", y),
|
||||
False)
|
||||
|
||||
eq_out = paddle.jit.dy2static.convert_shape_compare(x, "==", y)
|
||||
not_eq_out = paddle.jit.dy2static.convert_shape_compare(x, "!=", y)
|
||||
long_eq_out = paddle.jit.dy2static.convert_shape_compare(x, "==", x,
|
||||
"!=", y)
|
||||
|
||||
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
|
||||
) else paddle.CPUPlace()
|
||||
exe = paddle.static.Executor(place)
|
||||
x_y_eq_out = exe.run(feed={
|
||||
"x": np.ones([3, 2]).astype(np.float32),
|
||||
"y": np.ones([3, 2]).astype(np.float32)
|
||||
},
|
||||
fetch_list=[eq_out, not_eq_out, long_eq_out])
|
||||
np.testing.assert_array_equal(
|
||||
np.array(x_y_eq_out), np.array([[True], [False], [False]]))
|
||||
|
||||
set_a_zero = np.ones([3, 2]).astype(np.float32)
|
||||
set_a_zero[0][0] = 0.0
|
||||
x_y_not_eq_out = exe.run(
|
||||
feed={
|
||||
"x": np.ones([3, 2]).astype(np.float32),
|
||||
"y": set_a_zero
|
||||
},
|
||||
fetch_list=[eq_out, not_eq_out, long_eq_out])
|
||||
np.testing.assert_array_equal(
|
||||
np.array(x_y_not_eq_out), np.array([[False], [True], [True]]))
|
||||
paddle.disable_static()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue