[Dy2stat] Fix Memory Optimization in run_program_op and Add SimNet as Unit Test (#25383)

Add Similarity Net (SimNet) as a unit test. While adding the unit test, we found three problems:

1. run_program_op has a memory optimization error when a dy2stat net is run multiple times.
2. The support for SelectedRows can cause problems in dy2stat.
3. The handling of return statements has a problem.

This PR fixes problem 1; for problems 2 and 3 it only adjusts the code enough to keep this PR small. I will fix those two problems in the next PR(s).
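The core of the fix: any variable that a backward (grad) op reads but that no backward op produces must come from the forward pass, so it is appended to the executor's skip-deletion list and the eager garbage collector leaves it alone between runs. Below is a minimal Python sketch of that selection logic; the helper name and the simplified op tuples are illustrative, not Paddle's actual C++ API.

    def safe_eager_deletion_skip_vars(backward_ops):
        # backward_ops: iterable of (input_names, output_names) pairs, a
        # simplified stand-in for the grad ops of the program's backward block.
        grad_inputs, grad_outputs = set(), set()
        for input_names, output_names in backward_ops:
            grad_inputs.update(input_names)
            grad_outputs.update(output_names)
        # A variable read by a grad op but written by no grad op must have been
        # produced by the forward pass; keep it alive across repeated runs.
        return sorted(grad_inputs - grad_outputs)

    # Example: d_out is produced by a grad op, but fwd_act only by the forward
    # pass, so fwd_act must be protected from eager deletion.
    print(safe_eager_deletion_skip_vars([(["d_out", "fwd_act"], ["d_x"])]))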
Huihuang Zheng 5 years ago committed by GitHub
parent c42d662e2a
commit f9ac5fb992

@@ -91,7 +91,7 @@ set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_ten
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc matrix_inverse)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper boost)
if (WITH_GPU)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu bert_encoder_functor)
endif()

@@ -17,10 +17,12 @@ limitations under the License. */
#include <algorithm>
#include <iterator>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
@@ -149,14 +151,46 @@ static void ShareVarsFromScope(const std::vector<Variable *> &vars,
}
}
static void AppendSkipDeletionVars(
std::vector<std::string> *all_vars,
const std::vector<std::string> &append_vars) {
static void AppendSkipDeletionVars(const std::vector<std::string> &append_vars,
std::vector<std::string> *all_vars) {
for (auto &var : append_vars) {
all_vars->emplace_back(var);
}
}
static void AppendSafeEagerDeletionSkipVars(
const framework::ProgramDesc &program,
std::vector<std::string> *skip_vars) {
const framework::BlockDesc &block = program.Block(0);
const std::vector<framework::OpDesc *> &all_ops = block.AllOps();
std::unordered_set<std::string> grad_op_output;
std::unordered_set<std::string> grad_op_input;
for (const framework::OpDesc *op : all_ops) {
int op_role = BOOST_GET_CONST(
int, op->GetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName()));
if ((op_role & static_cast<int>(framework::OpRole::kBackward)) == 0) {
continue;
}
for (const std::string &in_arg_name : op->InputArgumentNames()) {
grad_op_input.emplace(in_arg_name);
}
for (const std::string &out_arg_name : op->OutputArgumentNames()) {
grad_op_output.emplace(out_arg_name);
}
}
// For each grad op input variable: if it is not an output of any grad op,
// it may be an output of a forward op, so mark it as a skip var to prevent
// it from being deleted when the grad ops are run multiple times.
for (const std::string &var_name : grad_op_input) {
if (grad_op_output.find(var_name) == grad_op_output.end()) {
skip_vars->emplace_back(var_name);
}
}
}
} // namespace details
template <typename DeviceContext, typename T>
@@ -192,7 +226,7 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
// skip delete vars
std::vector<std::string> skip_vars;
details::AppendSkipDeletionVars(&skip_vars, output_var_names);
details::AppendSkipDeletionVars(output_var_names, &skip_vars);
VLOG(2) << "Prepare to skip " << skip_vars.size()
<< " var(s): " << string::join_strings(skip_vars, ' ');
@@ -261,20 +295,21 @@ class RunProgramGradOpKernel : public framework::OpKernel<T> {
out_scope_vec->size(), 1,
platform::errors::InvalidArgument(
"The OutScope of RunProgramGradOp should only hold one scope."));
auto &scope = *(out_scope_vec->front());
// Step 2. prepare executor and scope
framework::Executor exe(ctx.GetPlace());
// skip delete vars
std::vector<std::string> skip_vars;
details::AppendSkipDeletionVars(&skip_vars, input_grad_var_names);
details::AppendSkipDeletionVars(&skip_vars, param_grad_names);
details::AppendSkipDeletionVars(input_grad_var_names, &skip_vars);
details::AppendSkipDeletionVars(param_grad_names, &skip_vars);
details::AppendSafeEagerDeletionSkipVars(*program, &skip_vars);
VLOG(2) << "Prepare to skip " << skip_vars.size()
<< " var(s): " << string::join_strings(skip_vars, ' ');
auto exe_ctx = exe.Prepare(*program, 0, skip_vars);
auto &scope = *(out_scope_vec->front());
details::ShareVarsIntoScope(output_grad_vars, output_grad_var_names,
&scope);

@@ -0,0 +1,174 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import numpy as np
import paddle
import paddle.fluid as fluid
import random
import unittest
from paddle.fluid.dygraph import ProgramTranslator
from simnet_dygraph_model import BOW, HingeLoss
SEED = 102
random.seed(SEED)
def create_conf_dict():
conf_dict = {}
conf_dict["task_mode"] = "train"
conf_dict["net"] = {"emb_dim": 128, "bow_dim": 128, "hidden_dim": 128}
conf_dict["loss"] = {"margin": 0.1}
return conf_dict
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--batch_size",
type=int,
default=32,
help="Total examples' number in batch for training.")
parser.add_argument(
"--seq_len", type=int, default=32, help="The length of each sentence.")
parser.add_argument(
"--epoch", type=int, default=1, help="The number of training epoch.")
parser.add_argument(
"--fake_sample_size",
type=int,
default=128,
help="The number of samples of fake data.")
args = parser.parse_args([])
return args
args = parse_args()
def fake_vocabulary():
vocab = {}
vocab["<unk>"] = 0
for i in range(26):
c = chr(ord('a') + i)
vocab[c] = i + 1
return vocab
vocab = fake_vocabulary()
class FakeReaderProcessor(object):
def __init__(self, args, vocab):
self.vocab = vocab
self.seq_len = args.seq_len
self.sample_size = args.fake_sample_size
self.data_samples = []
for i in range(self.sample_size):
query = [random.randint(0, 26) for i in range(self.seq_len)]
pos_title = query[:]
neg_title = [26 - q for q in query]
self.data_samples.append(
np.array([query, pos_title, neg_title]).astype(np.int64))
def get_reader(self, mode, epoch=0):
def reader_with_pairwise():
if mode == "train":
for i in range(self.sample_size):
yield self.data_samples[i]
return reader_with_pairwise
simnet_process = FakeReaderProcessor(args, vocab)
def train(conf_dict, to_static):
"""
train process
"""
program_translator = ProgramTranslator()
program_translator.enable(to_static)
# Get device
if fluid.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
else:
place = fluid.CPUPlace()
with fluid.dygraph.guard(place):
fluid.default_startup_program().random_seed = SEED
fluid.default_main_program().random_seed = SEED
conf_dict['dict_size'] = len(vocab)
conf_dict['seq_len'] = args.seq_len
net = BOW(conf_dict)
loss = HingeLoss(conf_dict)
optimizer = fluid.optimizer.AdamOptimizer(
learning_rate=0.001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
parameter_list=net.parameters())
metric = fluid.metrics.Auc(name="auc")
global_step = 0
losses = []
train_loader = fluid.io.DataLoader.from_generator(
capacity=16,
return_list=True,
iterable=True,
use_double_buffer=True)
get_train_examples = simnet_process.get_reader(
"train", epoch=args.epoch)
train_loader.set_sample_list_generator(
paddle.batch(
get_train_examples, batch_size=args.batch_size), place)
for left, pos_right, neg_right in train_loader():
left = fluid.layers.reshape(left, shape=[-1, 1])
pos_right = fluid.layers.reshape(pos_right, shape=[-1, 1])
neg_right = fluid.layers.reshape(neg_right, shape=[-1, 1])
net.train()
global_step += 1
left_feat, pos_score = net(left, pos_right)
pred = pos_score
_, neg_score = net(left, neg_right)
avg_cost = loss.compute(pos_score, neg_score)
#avg_cost = loss.compute(pos_score, pos_score)
losses.append(np.mean(avg_cost.numpy()))
avg_cost.backward()
optimizer.minimize(avg_cost)
net.clear_gradients()
return losses
class TestSimnet(unittest.TestCase):
def test_dygraph_static_same_loss(self):
if fluid.is_compiled_with_cuda():
fluid.set_flags({"FLAGS_cudnn_deterministic": True})
conf_dict = create_conf_dict()
dygraph_loss = train(conf_dict, to_static=False)
static_loss = train(conf_dict, to_static=True)
self.assertEqual(len(dygraph_loss), len(static_loss))
for i in range(len(dygraph_loss)):
self.assertAlmostEqual(dygraph_loss[i], static_loss[i])
if __name__ == '__main__':
unittest.main()