support cond in clone, test=develop (#22657)

* support cond in clone, test=develop

* refine code, test=develop

* refine code, test=develop

* follow comments, test=develop

* refine code, test=develop
Leo Chen authored 6 years ago, committed by GitHub
parent 2143bd5738
commit b2c1be851a
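A minimal sketch of the scenario this change targets, assuming the fluid 1.x API used throughout this diff (the network mirrors the cond_net helper added to the tests below): a program whose loss is computed inside a conditional branch can now be cloned with for_test=True.

import paddle.fluid as fluid

# Build a program whose loss lives inside a conditional branch.
x = fluid.layers.data(name="x", shape=[4], dtype='float32')
label = fluid.layers.data('label', shape=[1], dtype='int64')
prediction = fluid.layers.fc(input=x, size=1, act=None)

def loss1(pred, label):
    loss = fluid.layers.cross_entropy(input=pred, label=label)
    return fluid.layers.mean(loss)

def loss2(pred, label):
    loss = fluid.layers.softmax_with_cross_entropy(logits=pred, label=label)
    return fluid.layers.mean(loss)

two = fluid.layers.fill_constant([1], 'int32', 2)
cond = (two == 0)
avg_loss = fluid.layers.case([(cond, lambda: loss1(prediction, label))],
                             lambda: loss2(prediction, label))
fluid.optimizer.SGD(learning_rate=0.01).minimize(avg_loss)

# cond/case creates sub-blocks; pruning their backward parts is what this
# commit adds, so the test-mode clone below now works for such programs.
test_program = fluid.default_main_program().clone(for_test=True)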

File diff suppressed because it is too large.

@@ -14,9 +14,11 @@ limitations under the License. */
#pragma once
#include <map>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/enforce.h"
@@ -28,7 +30,7 @@ void Prune(const proto::ProgramDesc& input,
const std::set<std::string>& feed_var_names,
proto::ProgramDesc* output);
std::unique_ptr<framework::ProgramDesc> PruneBackward(
std::tuple<framework::ProgramDesc, std::map<int, int>> PruneBackward(
const framework::ProgramDesc& origin);
} // namespace framework

@@ -1136,9 +1136,23 @@ All parameter, weight, gradient are variables in Paddle.
Prune(*prog_with_targets.Proto(), feeded_var_names, &pruned_desc);
return new ProgramDesc(pruned_desc);
});
m.def("prune_backward", [](const framework::ProgramDesc &program) {
return PruneBackward(program);
});
m.def("prune_backward",
[](const framework::ProgramDesc &program) {
return PruneBackward(program);
},
R"DOC(
Prune the backward part of a program, mostly called in
program.clone(for_test=True).
Args:
program (ProgramDesc): The original program.
Returns:
tuple(ProgramDesc, map<int, int>): The first part is
the pruned program desc, and the second part is a map
from each block id in the pruned program to the id of
the corresponding block in the origin program.
)DOC");
m.def("empty_var_name",
[]() { return std::string(framework::kEmptyVarName); });
m.def("grad_var_suffix",

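The new binding returns a pair instead of a bare ProgramDesc. A rough sketch of how it is consumed from Python, matching the Program.clone changes below (train_program is assumed to be a fluid.Program that already has backward ops appended):

import six
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import Block, Program

# Illustrative placeholder for a training program with gradients appended.
train_program = fluid.default_main_program()

# prune_backward returns (pruned ProgramDesc, {pruned_block_id: origin_block_id}).
forward_prog = Program()
forward_prog.desc, block_id_map = core.prune_backward(train_program.desc)
forward_prog.blocks = [
    Block(forward_prog, i)
    for i in six.moves.range(forward_prog.desc.num_blocks())
]
forward_prog._sync_with_cpp()
print(block_id_map)  # e.g. {0: 0, 1: 1} once backward sub-blocks are dropped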
@@ -3991,18 +3991,17 @@ class Program(object):
The two code snippets above will generate and print the same programs.
"""
pruned_origin_block_id_map = None
if for_test:
if self._appending_grad_times > 0:
forward_prog = Program()
forward_prog.desc = core.prune_backward(self.desc)
forward_prog.blocks = [
Block(forward_prog, i)
for i in six.moves.range(forward_prog.desc.num_blocks())
]
forward_prog._sync_with_cpp()
p = forward_prog._inference_optimize(prune_read_op=False)
else:
p = self._inference_optimize(prune_read_op=False)
forward_prog = Program()
forward_prog.desc, pruned_origin_block_id_map = core.prune_backward(
self.desc)
forward_prog.blocks = [
Block(forward_prog, i)
for i in six.moves.range(forward_prog.desc.num_blocks())
]
forward_prog._sync_with_cpp()
p = forward_prog._inference_optimize(prune_read_op=False)
else:
p = Program()
p.current_block_idx = self.current_block_idx
@@ -4019,7 +4018,7 @@ class Program(object):
p._sync_with_cpp()
p._copy_param_info_from(self)
p._copy_data_info_from(self)
p._copy_data_info_from(self, pruned_origin_block_id_map)
p._copy_dist_param_info_from(self)
return p
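A quick, illustrative way to observe the effect of the new clone path (a rough check only; the program_compare helper in the tests below verifies this more thoroughly; train_program stands for a fluid.Program that had gradients appended):

test_prog = train_program.clone(for_test=True)

# No backward ops should survive in any block of the test-mode clone,
# including the sub-blocks created by cond/case.
for block in test_prog.blocks:
    assert all('_grad' not in op.type for op in block.ops), block.idx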
@@ -4445,9 +4444,6 @@ class Program(object):
raise TypeError("_copy_param_info_from should be invoked with "
"Program")
if len(self.blocks) != len(other.blocks):
raise ValueError("_copy_param_info_from should be invoked with two "
"program, with represent the same topology")
self.global_block()._copy_param_info_from(other.global_block())
def _copy_dist_param_info_from(self, other):
@@ -4470,7 +4466,7 @@ class Program(object):
self._ps_endpoint = other._ps_endpoint
self._distributed_lookup_table = other._distributed_lookup_table
def _copy_data_info_from(self, other):
def _copy_data_info_from(self, other, pruned_origin_block_id_map=None):
"""
Copy the information of data variables from other program.
@@ -4479,6 +4475,10 @@ class Program(object):
Args:
other(Program): Other program
pruned_origin_block_id_map(dict{int:int}): A dict which maps the block id in program
self to the block id in program other. For example, {0:0, 1:1, 2:3} means block 0 in self is
cloned from block 0 in other, etc. Default is None, which means the identity mapping
{0:0, 1:1, ..., n:n} is used.
Returns:
None
@@ -4487,22 +4487,24 @@ class Program(object):
raise TypeError("_copy_data_info_from should be invoked with "
"Program")
if len(self.blocks) != len(other.blocks):
raise ValueError("_copy_data_info_from should be invoked with two "
"program, with represent the same topology")
if not pruned_origin_block_id_map:
pruned_origin_block_id_map = {
i: i
for i in six.moves.range(self.desc.num_blocks())
}
# NOTE(zhiqiu): All vars in cloned program exist in original program.
# The reverse is not true, due to backward pruning.
for i, block in enumerate(other.blocks):
for i, block in enumerate(self.blocks):
other_block = other.blocks[pruned_origin_block_id_map[i]]
for var in list(block.vars.values()):
if not self.blocks[i].has_var(var.name):
continue
if var.is_data:
self.blocks[i].var(var.name).is_data = True
if var.desc.need_check_feed():
self.blocks[i].var(var.name).desc.set_need_check_feed(True)
if var.stop_gradient:
self.blocks[i].var(var.name).stop_gradient = True
other_var = other_block.var(var.name)
if other_var.is_data:
var.is_data = True
if other_var.desc.need_check_feed():
var.desc.set_need_check_feed(True)
if other_var.stop_gradient:
var.stop_gradient = True
@dygraph_not_support
def list_vars(self):

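To make the mapping concrete, a toy illustration of pruned_origin_block_id_map using the example from the docstring above (plain Python, not Paddle API):

# Suppose the origin program has blocks [0, 1, 2, 3] and block 2 is a backward
# sub-block that gets pruned; the surviving blocks are renumbered, so block 2
# of the clone corresponds to block 3 of the origin.
pruned_origin_block_id_map = {0: 0, 1: 1, 2: 3}

# _copy_data_info_from uses this map so that flags such as is_data,
# need_check_feed and stop_gradient are read from the corresponding
# origin block rather than from the block with the same index.
for pruned_id, origin_id in pruned_origin_block_id_map.items():
    print("clone block %d <- origin block %d" % (pruned_id, origin_id))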
@@ -128,9 +128,9 @@ def check_if_mkldnn_batchnorm_primitives_exist_in_bwd(
for arg in grad_op_desc.output_arg_names():
grad_var = block.desc.find_var(arg.encode("ascii"))
grad_var.set_dtype(core.VarDesc.VarType.FP32)
program._sync_with_cpp()
exe = fluid.Executor(place)
# Do at least 2 iterations
for i in range(2):
out = exe.run(

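The program._sync_with_cpp() call shown above matters whenever a block's underlying C++ desc is edited directly, as this test does with find_var/set_dtype; a minimal sketch of the pattern (the variable name is illustrative):

import paddle.fluid as fluid
from paddle.fluid import core

program = fluid.default_main_program()
block = program.global_block()

# Edits made through block.desc bypass the Python-side wrappers, so the
# Program must be re-synced before it is cloned or executed.
var_desc = block.desc.find_var("some_var".encode("ascii"))
if var_desc is not None:
    var_desc.set_dtype(core.VarDesc.VarType.FP32)
program._sync_with_cpp()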
@@ -18,7 +18,7 @@ fluid.core._set_eager_deletion_mode(-1, -1, False)
import paddle.fluid.layers.ops as ops
from paddle.fluid.initializer import init_on_cpu
from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
from paddle.fluid.layers.learning_rate_scheduler import cosine_decay
from simple_nets import init_data
import math
import os
@@ -161,20 +161,6 @@ def SE_ResNeXt50Small(use_feed):
return loss
def cosine_decay(learning_rate, step_each_epoch, epochs=120):
"""
Applies cosine decay to the learning rate.
lr = 0.05 * (math.cos(epoch * (math.pi / 120)) + 1)
"""
global_step = _decay_step_counter()
with init_on_cpu():
epoch = ops.floor(global_step / step_each_epoch)
decayed_lr = learning_rate * \
(ops.cos(epoch * (math.pi / epochs)) + 1)/2
return decayed_lr
def optimizer(learning_rate=0.01):
optimizer = fluid.optimizer.Momentum(
learning_rate=cosine_decay(

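The local cosine_decay helper is dropped in favor of the library scheduler imported above; a hedged usage sketch (step_each_epoch, epochs and momentum values are placeholders):

import paddle.fluid as fluid
from paddle.fluid.layers.learning_rate_scheduler import cosine_decay

def optimizer(learning_rate=0.01):
    # cosine_decay(learning_rate, step_each_epoch, epochs) implements the same
    # schedule as the removed helper:
    #   lr = learning_rate * (cos(epoch * pi / epochs) + 1) / 2
    return fluid.optimizer.Momentum(
        learning_rate=cosine_decay(
            learning_rate=learning_rate, step_each_epoch=2, epochs=1),
        momentum=0.9)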
@@ -71,6 +71,58 @@ def simple_fc_net_with_accuracy(use_feed):
return loss
def cond_net(use_feed=None):
x = fluid.layers.data(name="x", shape=[4], dtype='float32')
label = fluid.layers.data('label', shape=[1], dtype='int64')
prediction = fluid.layers.fc(input=x, size=1, act=None)
def loss1(pred, label):
x = fluid.layers.data(name="x", shape=[4], dtype='float32')
loss = fluid.layers.cross_entropy(input=pred, label=label)
avg_loss = fluid.layers.mean(loss, name='mean_cross_entropy_loss')
return avg_loss
def loss2(pred, label):
loss = fluid.layers.softmax_with_cross_entropy(logits=pred, label=label)
avg_loss = fluid.layers.mean(loss, name='mean_softmax_loss')
return avg_loss
two = fluid.layers.fill_constant([1], 'int32', 2)
pred = (two == 0)
avg_loss = fluid.layers.case([(pred, lambda: loss1(prediction, label))],
lambda: loss2(prediction, label))
return avg_loss
def optimization_in_cond_net(with_optimize=False):
x = fluid.layers.data(name="x", shape=[4], dtype='float32')
label = fluid.layers.data('label', shape=[1], dtype='int64')
prediction = fluid.layers.fc(input=x, size=1, act=None)
def loss1(opt, pred, label, with_optimize):
x = fluid.layers.data(name="x", shape=[4], dtype='float32')
loss = fluid.layers.cross_entropy(input=pred, label=label)
avg_loss = fluid.layers.mean(loss, name='mean_cross_entropy_loss')
if with_optimize:
opt.minimize(avg_loss)
return avg_loss
def loss2(opt, pred, label, with_optimize):
loss = fluid.layers.softmax_with_cross_entropy(logits=pred, label=label)
avg_loss = fluid.layers.mean(loss, name='mean_softmax_loss')
if with_optimize:
opt.minimize(avg_loss)
return avg_loss
sgd = fluid.optimizer.SGD(learning_rate=0.1)
two = fluid.layers.fill_constant([1], 'int32', 2)
pred = (two == 0)
avg_loss = fluid.layers.case(
[(pred, lambda: loss1(sgd, prediction, label, with_optimize))],
lambda: loss2(sgd, prediction, label, with_optimize))
return avg_loss
class TestProgramPruneBackward(unittest.TestCase):
def program_compare(self, program_a, program_b):
assert isinstance(
@@ -99,19 +151,24 @@ class TestProgramPruneBackward(unittest.TestCase):
test_prog_orig = main_program.clone(for_test=True)
optimizer().minimize(loss)
test_prog_prune = main_program.clone(for_test=True)
self.program_compare(test_prog_orig, test_prog_prune)
place = core.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
loss_data_prune, = exe.run(test_prog_prune,
feed=feed_dict,
fetch_list=[loss.name])
loss_data_orig, = exe.run(test_prog_orig,
feed=feed_dict,
fetch_list=[loss.name])
self.assertEqual(loss_data_orig, loss_data_prune)
for place in places:
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
loss_data_prune, = exe.run(test_prog_prune,
feed=feed_dict,
fetch_list=[loss.name])
loss_data_orig, = exe.run(test_prog_orig,
feed=feed_dict,
fetch_list=[loss.name])
self.assertEqual(loss_data_orig, loss_data_prune)
def test_simple_fc_net(self):
def optimizer():
@@ -198,6 +255,48 @@ class TestProgramPruneBackward(unittest.TestCase):
self.check_prune_correctness(
method=lstm_net, feed_dict=feed_data, optimizer=optimizer)
def test_cond(self):
def optimizer():
optimizer = fluid.optimizer.SGD(learning_rate=0.01)
return optimizer
with self.program_scope_guard():
x_in = np.random.random(size=(10, 4)).astype('float32')
label_in = np.random.randint(1, size=(10, 1)).astype('int64')
feed_dict = {'x': x_in, 'label': label_in}
self.check_prune_correctness(
method=cond_net, feed_dict=feed_dict, optimizer=optimizer)
def test_optimization_in_cond(self):
x_in = np.random.random(size=(10, 4)).astype('float32')
label_in = np.random.randint(1, size=(10, 1)).astype('int64')
feed_dict = {'x': x_in, 'label': label_in}
with self.program_scope_guard():
loss = optimization_in_cond_net(False)
main_program = fluid.default_main_program()
test_prog_orig = main_program.clone(for_test=True)
place = core.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
loss_data_orig, = exe.run(test_prog_orig,
feed=feed_dict,
fetch_list=[loss.name])
with self.program_scope_guard():
loss = optimization_in_cond_net(True)
main_program = fluid.default_main_program()
test_prog_prune = main_program.clone(for_test=True)
place = core.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
loss_data_prune, = exe.run(test_prog_prune,
feed=feed_dict,
fetch_list=[loss.name])
self.program_compare(test_prog_orig, test_prog_prune)
self.assertEqual(loss_data_orig, loss_data_prune)
@contextlib.contextmanager
def program_scope_guard(self):
prog = fluid.Program()
@@ -205,7 +304,8 @@ class TestProgramPruneBackward(unittest.TestCase):
scope = fluid.core.Scope()
with fluid.scope_guard(scope):
with fluid.program_guard(prog, startup_prog):
yield
with fluid.unique_name.guard():
yield
if __name__ == '__main__':

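The unique_name.guard() added to program_scope_guard resets fluid's name generator for each program built inside the guard, so programs constructed back to back get identical parameter names and can be compared or fed the same feed_dict; a minimal sketch:

import paddle.fluid as fluid

def build_fc_program():
    prog, startup = fluid.Program(), fluid.Program()
    with fluid.program_guard(prog, startup):
        with fluid.unique_name.guard():
            x = fluid.layers.data(name='x', shape=[4], dtype='float32')
            fluid.layers.fc(input=x, size=1)
    return prog

# Both programs contain parameters with the same names (e.g. fc_0.w_0, fc_0.b_0).
prog_a = build_fc_program()
prog_b = build_fc_program()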