support forward hook for dygraph (#22443)
* support forward hook for dygraph, test=develop * add optest for forward_hook in dygraph, test=develop * add optest, test=develop * polish code, test=develop * add sample code, test=develop * rename forwrd_hook to forward_post_hook, test=develop * fix the api description, test=develop * fix api description, test=developrevert-23830-2.0-beta
parent
a62599a888
commit
166a1ae902
@ -0,0 +1,197 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import contextlib
|
||||
import unittest
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
import paddle.fluid.core as core
|
||||
import paddle.fluid.dygraph.base as base
|
||||
|
||||
from test_imperative_lod_tensor_to_selected_rows import SimpleNet
|
||||
|
||||
call_forward_hook = False
|
||||
call_forward_pre_hook = False
|
||||
|
||||
|
||||
def forward_hook(layer, input, output):
|
||||
global call_forward_hook
|
||||
call_forward_hook = True
|
||||
|
||||
|
||||
def forward_pre_hook(layer, input):
|
||||
global call_forward_pre_hook
|
||||
call_forward_pre_hook = True
|
||||
|
||||
|
||||
def forward_hook1(layer, input, output):
|
||||
return output * 2
|
||||
|
||||
|
||||
def forward_pre_hook1(layer, input):
|
||||
input_return = (input[0] * 2, input[1])
|
||||
return input_return
|
||||
|
||||
|
||||
class Test_Forward_Hook(unittest.TestCase):
|
||||
# test forward_pre_hook and forward_hook that have return value
|
||||
def test_forward_hook_return_value(self):
|
||||
seed = 90
|
||||
|
||||
places = [fluid.CPUPlace()]
|
||||
if core.is_compiled_with_cuda():
|
||||
places.append(fluid.CUDAPlace(0))
|
||||
|
||||
for place in places:
|
||||
with fluid.dygraph.guard(place):
|
||||
fluid.default_startup_program().random_seed = seed
|
||||
fluid.default_main_program().random_seed = seed
|
||||
backward_strategy = fluid.dygraph.BackwardStrategy()
|
||||
backward_strategy.sort_sum_gradient = True
|
||||
|
||||
input_word = np.array(
|
||||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 6, 7,
|
||||
8]).reshape(6, 3).astype('int64')
|
||||
input_word1 = input_word * 2
|
||||
input_word = input_word.reshape((-1, 3, 1))
|
||||
input_word1 = input_word1.reshape((-1, 3, 1))
|
||||
y_data = np.array(
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8,
|
||||
9]).reshape(6, 3).astype('int64')
|
||||
y_data = y_data.reshape((-1, 1))
|
||||
|
||||
input = base.to_variable(input_word)
|
||||
input1 = base.to_variable(input_word1)
|
||||
y = base.to_variable(y_data)
|
||||
|
||||
simplenet = SimpleNet(
|
||||
hidden_size=20,
|
||||
vocab_size=32,
|
||||
num_steps=3,
|
||||
init_scale=0.1,
|
||||
is_sparse=False,
|
||||
dtype="float32")
|
||||
|
||||
# origin, don't register any hook
|
||||
outs_origin = simplenet(input, y)
|
||||
outs_origin1 = simplenet(input1, y)
|
||||
|
||||
# register forward_pre_hook
|
||||
forward_pre_hook_handle1 = simplenet.register_forward_pre_hook(
|
||||
forward_pre_hook1)
|
||||
outs_pre_hook = simplenet(input, y)
|
||||
self.assertTrue(
|
||||
np.array_equal(outs_pre_hook.numpy(), outs_origin1.numpy()))
|
||||
|
||||
# remove forward_pre_hook
|
||||
forward_pre_hook_handle1.remove()
|
||||
outs_pre_hook = simplenet(input, y)
|
||||
self.assertTrue(
|
||||
np.array_equal(outs_pre_hook.numpy(), outs_origin.numpy()))
|
||||
|
||||
# register forward_hook
|
||||
forward_hook_handle1 = simplenet.register_forward_post_hook(
|
||||
forward_hook1)
|
||||
outs_forward_hook = simplenet(input, y)
|
||||
self.assertTrue(
|
||||
np.array_equal(outs_forward_hook.numpy(),
|
||||
outs_origin.numpy() * 2))
|
||||
|
||||
# remove forward_hook
|
||||
forward_hook_handle1.remove()
|
||||
outs_forward_hook = simplenet(input, y)
|
||||
self.assertTrue(
|
||||
np.array_equal(outs_forward_hook.numpy(),
|
||||
outs_origin.numpy()))
|
||||
|
||||
# test forward_pre_hook and forward_hook that don't have return value
|
||||
def test_forward_hook(self):
|
||||
seed = 90
|
||||
|
||||
places = [fluid.CPUPlace()]
|
||||
if core.is_compiled_with_cuda():
|
||||
places.append(fluid.CUDAPlace(0))
|
||||
|
||||
for place in places:
|
||||
with fluid.dygraph.guard(place):
|
||||
fluid.default_startup_program().random_seed = seed
|
||||
fluid.default_main_program().random_seed = seed
|
||||
backward_strategy = fluid.dygraph.BackwardStrategy()
|
||||
backward_strategy.sort_sum_gradient = True
|
||||
|
||||
global call_forward_hook
|
||||
global call_forward_pre_hook
|
||||
|
||||
input_word = np.array(
|
||||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 6, 7,
|
||||
8]).reshape(6, 3).astype('int64')
|
||||
input_word = input_word.reshape((-1, 3, 1))
|
||||
y_data = np.array(
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8,
|
||||
9]).reshape(6, 3).astype('int64')
|
||||
y_data = y_data.reshape((-1, 1))
|
||||
|
||||
input = base.to_variable(input_word)
|
||||
y = base.to_variable(y_data)
|
||||
|
||||
simplenet = SimpleNet(
|
||||
hidden_size=20,
|
||||
vocab_size=32,
|
||||
num_steps=3,
|
||||
init_scale=0.1,
|
||||
is_sparse=False,
|
||||
dtype="float32")
|
||||
|
||||
# origin, don't register any hook
|
||||
outs_origin = simplenet(input, y)
|
||||
self.assertFalse(call_forward_hook)
|
||||
self.assertFalse(call_forward_pre_hook)
|
||||
|
||||
# register forward_hook and forward_pre_hook
|
||||
forward_hook_handle = simplenet.register_forward_post_hook(
|
||||
forward_hook)
|
||||
forward_pre_hook_handle = simplenet.register_forward_pre_hook(
|
||||
forward_pre_hook)
|
||||
outs_hook = simplenet(input, y)
|
||||
self.assertTrue(call_forward_hook)
|
||||
self.assertTrue(call_forward_pre_hook)
|
||||
|
||||
outs_hook = simplenet(input, y)
|
||||
self.assertTrue(call_forward_hook)
|
||||
self.assertTrue(call_forward_pre_hook)
|
||||
|
||||
# remove forward_hook
|
||||
forward_hook_handle.remove()
|
||||
call_forward_hook = False
|
||||
call_forward_pre_hook = False
|
||||
outs_remove_forward_hook = simplenet(input, y)
|
||||
self.assertFalse(call_forward_hook)
|
||||
self.assertTrue(call_forward_pre_hook)
|
||||
|
||||
# remove forward_pre_hook
|
||||
forward_pre_hook_handle.remove()
|
||||
call_forward_hook = False
|
||||
call_forward_pre_hook = False
|
||||
outs_remove_hook = simplenet(input, y)
|
||||
self.assertFalse(call_forward_hook)
|
||||
self.assertFalse(call_forward_pre_hook)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue