revert-14324-fix_vlog
Xin Pan 7 years ago
parent 0a89650507
commit 8c11d3fed6

@@ -33,7 +33,7 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
       pool_(strategy.num_threads_ +
             1),  // add one more thread for generate op_deps
       fetch_ctxs_(places) {
-  for (auto &op : ir::GetFilteredNodes<OpHandleBase>(*graph_)) {
+  for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
     int dep = static_cast<int>(op->NotReadyInputSize());
     op_deps_.emplace(op, dep);
     if (dep == 0) {
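
For context on the loop above: the executor records, for every op, how many inputs are not yet ready, and ops whose count is already zero are queued immediately. Below is a minimal standalone sketch of that dependency-count scheduling idea; the `Op` struct and the tiny graph are toy stand-ins, not PaddlePaddle's `OpHandleBase`.

```cpp
// Toy sketch of dependency-count scheduling (an assumed simplification,
// not the real FastThreadedSSAGraphExecutor; compile with -std=c++14).
#include <initializer_list>
#include <iostream>
#include <queue>
#include <string>
#include <vector>

struct Op {
  std::string name;
  std::vector<Op *> consumers;  // ops that wait on this op's outputs
  int pending_inputs = 0;       // counterpart of NotReadyInputSize()
};

int main() {
  Op read{"read"}, fc{"fc"}, loss{"loss"};
  read.consumers = {&fc};
  fc.consumers = {&loss};
  fc.pending_inputs = 1;    // waits on read
  loss.pending_inputs = 1;  // waits on fc

  // Seed the ready queue with ops whose dep count is already zero,
  // mirroring the `if (dep == 0)` branch in the hunk above.
  std::queue<Op *> ready;
  for (Op *op : {&read, &fc, &loss}) {
    if (op->pending_inputs == 0) ready.push(op);
  }

  // Run an op, then decrement its consumers' counters; a consumer is
  // dispatched the moment its last pending input resolves.
  while (!ready.empty()) {
    Op *op = ready.front();
    ready.pop();
    std::cout << "run " << op->name << "\n";
    for (Op *next : op->consumers) {
      if (--next->pending_inputs == 0) ready.push(next);
    }
  }
}
```

Decrementing a consumer's counter when a producer finishes is what lets the real executor hand ops to its thread pool as soon as their last dependency resolves.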

@@ -46,7 +46,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
     insert_pending_var(var);
   }
-  for (OpHandleBase *op : ir::GetFilteredNodes<OpHandleBase>(*graph)) {
+  for (OpHandleBase *op : ir::FilterByNodeWrapper<OpHandleBase>(*graph)) {
     if (op->Inputs().empty()) {
       ready_ops.insert(op);
     } else {

@@ -36,6 +36,7 @@ namespace framework {
 namespace details {
 namespace {
+// TODO(panyx0718): Clean this up as well.
 // all operators. NOTE that even we use a vector here, the operators is
 // unordered.
 typedef std::vector<OpHandleBase *> GraphOps;

@@ -63,7 +63,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
   });
   size_t op_id = 0;
-  for (auto &op : ir::GetFilteredNodes<OpHandleBase>(graph)) {
+  for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(graph)) {
     std::string op_name = "op_" + std::to_string(op_id++);
     sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
          << std::endl;

@@ -157,7 +157,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
     }
   };
-  auto all_ops = ir::GetFilteredNodes<OpHandleBase>(*graph);
+  auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
   for (auto &op : all_ops) {
     auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs());
     auto out_var_names = get_ref_cnts_from_compute_op(op, op->Outputs());
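
The pass touched above frees variables eagerly by counting how many ops consume each one. A hypothetical miniature of that reference-counting idea follows; the op list, variable names, and the `ref_cnt` map are illustrative assumptions, not the pass's real data structures.

```cpp
// Hypothetical miniature of reference-count-based early freeing.
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

int main() {
  // Each entry: op name -> variables the op reads.
  std::vector<std::pair<std::string, std::vector<std::string>>> ops = {
      {"fc", {"x", "w"}},
      {"relu", {"fc_out"}},
      {"loss", {"relu_out", "label"}},
  };

  // Pass 1: count how many ops consume each variable, the analogue of
  // what get_ref_cnts_from_compute_op gathers in the hunk above.
  std::unordered_map<std::string, int> ref_cnt;
  for (const auto &op : ops) {
    for (const auto &var : op.second) ++ref_cnt[var];
  }

  // Pass 2 (execution): after an op runs, decrement its inputs; a
  // variable whose count hits zero has no remaining readers, so its
  // memory can be reclaimed immediately.
  for (const auto &op : ops) {
    std::cout << "run " << op.first << "\n";
    for (const auto &var : op.second) {
      if (--ref_cnt[var] == 0) std::cout << "  free " << var << "\n";
    }
  }
}
```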

@@ -60,7 +60,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
     InsertPendingVar(&pending_vars, ready_vars.get(), var);
   }
-  for (auto &op : ir::GetFilteredNodes<OpHandleBase>(*graph_)) {
+  for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
     if (op->Inputs().empty()) {  // Special case, Op has no input.
       ready_ops.insert(op);
     } else {

@@ -38,7 +38,7 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
     const Graph &graph);
 template <typename T>
-std::vector<T *> GetFilteredNodes(const Graph &graph) {
+std::vector<T *> FilterByNodeWrapper(const Graph &graph) {
   std::vector<T *> ret;
   for (ir::Node *n : graph.Nodes()) {
     if (n->IsWrappedBy<T>()) ret.push_back(&n->Wrapper<T>());
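
This hunk is the substance of the commit: the template helper is renamed from `GetFilteredNodes` to `FilterByNodeWrapper`, which better describes what it does: walk all `ir::Node`s and keep only those wrapped by the requested handle type. Here is a self-contained sketch of that filter-by-wrapper pattern using stand-in types; the real `IsWrappedBy`/`Wrapper` live on `ir::Node` and are implemented differently.

```cpp
// Standalone sketch of the filter-by-wrapper pattern (toy types, not
// PaddlePaddle's ir::Node; compile with -std=c++14).
#include <iostream>
#include <memory>
#include <vector>

struct HandleBase {
  virtual ~HandleBase() = default;
};
struct OpHandle : HandleBase {
  explicit OpHandle(const char *n) : name(n) {}
  const char *name;
};
struct VarHandle : HandleBase {};

// A node owns at most one "wrapper" handle of some concrete type.
struct Node {
  std::unique_ptr<HandleBase> wrapper;
  template <typename T>
  bool IsWrappedBy() const {
    return dynamic_cast<const T *>(wrapper.get()) != nullptr;
  }
  template <typename T>
  T &Wrapper() {
    return *dynamic_cast<T *>(wrapper.get());
  }
};

// Same shape as the renamed helper: collect only nodes wrapping T.
template <typename T>
std::vector<T *> FilterByNodeWrapper(std::vector<Node> &nodes) {
  std::vector<T *> ret;
  for (Node &n : nodes) {
    if (n.IsWrappedBy<T>()) ret.push_back(&n.Wrapper<T>());
  }
  return ret;
}

int main() {
  std::vector<Node> nodes;
  nodes.push_back(Node{std::make_unique<OpHandle>("conv2d")});
  nodes.push_back(Node{std::make_unique<VarHandle>()});
  nodes.push_back(Node{std::make_unique<OpHandle>("relu")});
  for (OpHandle *op : FilterByNodeWrapper<OpHandle>(nodes)) {
    std::cout << op->name << "\n";  // prints: conv2d, relu
  }
}
```

The payoff of the pattern is that passes and executors can iterate one heterogeneous node set and pull out exactly the handle type they care about, as every call site in this commit does with `OpHandleBase`.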

@@ -14,7 +14,6 @@
 from __future__ import print_function
 import os
-import sys
 import paddle.fluid as fluid
 import paddle
 import numpy as np
@@ -91,13 +90,11 @@ class TestReaderReset(unittest.TestCase):
             try:
                 data_val, label_val = parallel_exe.run(fetch_list,
                                                        return_numpy=True)
-                sys.stderr.write('fetched %s\n' % label_val)
                 ins_num = data_val.shape[0]
                 broadcasted_label = np.ones((ins_num, ) + tuple(
                     self.ins_shape)) * label_val.reshape((ins_num, 1))
                 self.assertEqual(data_val.all(), broadcasted_label.all())
                 for l in label_val:
-                    sys.stderr.write('label_val: %s\n' % l[0])
                     self.assertFalse(data_appeared[l[0]])
                     data_appeared[l[0]] = True
@@ -107,7 +104,6 @@ class TestReaderReset(unittest.TestCase):
             data_appeared = data_appeared[:-parallel_exe.device_count *
                                           self.batch_size]
             for i in data_appeared:
-                sys.stderr.write('appeared %s\n' % i)
                 self.assertTrue(i)
             if pass_count < self.test_pass_num:
                 data_appeared = [False] * self.total_ins_num
