Split unittest. (#30727)
parent
caf3680bbc
commit
3491acfb1e
@ -0,0 +1,94 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
import paddle.fluid.framework as framework
|
||||
from test_imperative_base import new_program_scope
|
||||
|
||||
import numpy as np
|
||||
import six
|
||||
import pickle
|
||||
import os
|
||||
|
||||
|
||||
class TestStaticSaveLoadLargeParameters(unittest.TestCase):
|
||||
def test_large_parameters_static_save(self):
|
||||
# enable static mode
|
||||
paddle.enable_static()
|
||||
LARGE_PARAM = 2**26
|
||||
with new_program_scope():
|
||||
# create network
|
||||
x = paddle.static.data(
|
||||
name="static_save_load_large_x",
|
||||
shape=[None, 10],
|
||||
dtype='float32')
|
||||
z = paddle.static.nn.fc(x, LARGE_PARAM, bias_attr=False)
|
||||
place = paddle.CPUPlace()
|
||||
exe = paddle.static.Executor(place)
|
||||
exe.run(paddle.static.default_startup_program())
|
||||
prog = paddle.static.default_main_program()
|
||||
|
||||
base_map = {}
|
||||
for var in prog.list_vars():
|
||||
if isinstance(var, framework.Parameter) or var.persistable:
|
||||
t = np.array(fluid.global_scope().find_var(var.name)
|
||||
.get_tensor())
|
||||
# make sure all the paramerter or optimizer var have been update
|
||||
self.assertTrue(np.sum(np.abs(t)) != 0)
|
||||
base_map[var.name] = t
|
||||
|
||||
path = os.path.join("test_static_save_load_large_param",
|
||||
"static_save")
|
||||
paddle.fluid.save(prog, path)
|
||||
# set var to zero
|
||||
for var in prog.list_vars():
|
||||
if isinstance(var, framework.Parameter) or var.persistable:
|
||||
ten = fluid.global_scope().find_var(var.name).get_tensor()
|
||||
ten.set(np.zeros_like(np.array(ten)), place)
|
||||
|
||||
new_t = np.array(fluid.global_scope().find_var(var.name)
|
||||
.get_tensor())
|
||||
self.assertTrue(np.sum(np.abs(new_t)) == 0)
|
||||
|
||||
paddle.fluid.load(prog, path)
|
||||
|
||||
for var in prog.list_vars():
|
||||
if isinstance(var, framework.Parameter) or var.persistable:
|
||||
new_t = np.array(fluid.global_scope().find_var(var.name)
|
||||
.get_tensor())
|
||||
base_t = base_map[var.name]
|
||||
self.assertTrue(np.array_equal(new_t, base_t))
|
||||
|
||||
# set var to zero
|
||||
for var in prog.list_vars():
|
||||
if isinstance(var, framework.Parameter) or var.persistable:
|
||||
ten = fluid.global_scope().find_var(var.name).get_tensor()
|
||||
ten.set(np.zeros_like(np.array(ten)), place)
|
||||
|
||||
new_t = np.array(fluid.global_scope().find_var(var.name)
|
||||
.get_tensor())
|
||||
self.assertTrue(np.sum(np.abs(new_t)) == 0)
|
||||
|
||||
program_state = fluid.load_program_state(path)
|
||||
fluid.set_program_state(prog, program_state)
|
||||
for var in prog.list_vars():
|
||||
if isinstance(var, framework.Parameter) or var.persistable:
|
||||
new_t = np.array(fluid.global_scope().find_var(var.name)
|
||||
.get_tensor())
|
||||
base_t = base_map[var.name]
|
||||
self.assertTrue(np.array_equal(new_t, base_t))
|
Binary file not shown.
Loading…
Reference in new issue