add coalesce_tensor into white list when checking re-creation of parameters (#31800)
parent
a70de87d76
commit
4046f1303a
@ -0,0 +1,53 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from unittest import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
def create_model():
|
||||||
|
hidden_size = 32
|
||||||
|
bilstm = paddle.nn.LSTM(
|
||||||
|
hidden_size, hidden_size, num_layers=1, direction='bidirectional')
|
||||||
|
return bilstm
|
||||||
|
|
||||||
|
|
||||||
|
class TestRNNProgramClone(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
paddle.enable_static()
|
||||||
|
|
||||||
|
def test_rnn_with_cudnn_clone(self):
|
||||||
|
train_program = paddle.static.Program()
|
||||||
|
test_program = paddle.static.Program()
|
||||||
|
startup_prog = paddle.static.Program()
|
||||||
|
|
||||||
|
# test a typical case in static graph usage: create two nearly
|
||||||
|
# identical program with a shared startup program to share their
|
||||||
|
# parameters
|
||||||
|
#
|
||||||
|
# when creating a parameter, the name is checked. If there is already
|
||||||
|
# a parameter with the same name, which is the output of a operator
|
||||||
|
# (i.e. its creator), its re-creation is skipped.
|
||||||
|
#
|
||||||
|
# but if that parameter has been the output of more than one operator,
|
||||||
|
# an exception is raised. For special cases, white list is added.
|
||||||
|
# flattening rnn's parameters for the need to call cudnn kernel is such
|
||||||
|
# a case.
|
||||||
|
with paddle.static.program_guard(train_program, startup_prog):
|
||||||
|
with paddle.fluid.unique_name.guard():
|
||||||
|
bilstm = create_model()
|
||||||
|
|
||||||
|
with paddle.fluid.program_guard(test_program, startup_prog):
|
||||||
|
with paddle.fluid.unique_name.guard():
|
||||||
|
bilstm = create_model()
|
Loading…
Reference in new issue