fix comments of 16410, test=develop (#16499)

* fix comments of 16410, test=develop

* modify inplace_op_inference_test according to pass interface change, test=develop

@@ -195,8 +195,7 @@ cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
proto_desc)
cc_test(inplace_op_inference_test SRCS inplace_op_inference_test.cc DEPS op_registry proto_desc op_info memory_optimize_helper)
cc_test(inplace_op_inference_test SRCS inplace_op_inference_test.cc DEPS inplace_op_pass op_registry proto_desc op_info memory_optimize_helper pass_builder)
cc_library(selected_rows SRCS selected_rows.cc DEPS tensor)
cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows)

@@ -156,7 +156,6 @@ void InplacePass::ApplyImpl(ir::Graph* graph) const {
continue;
TryInplaceOpInputOutput(op, graph);
}
// graph->ResolveHazard(var_nodes_);
}
void InplacePass::InplaceModifyDesc(const std::string& var,
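
Note: the hunk above reflects the changed pass interface: ApplyImpl now takes a raw ir::Graph* and mutates it in place. That is presumably also why the cc_test above now links inplace_op_pass and pass_builder. The sketch below is a minimal, self-contained illustration of that calling convention; ToyGraph, ToyPass, and ToyInplacePass are hypothetical stand-ins, not the PaddlePaddle classes.

#include <iostream>
#include <string>
#include <vector>

struct ToyGraph {                       // stand-in for ir::Graph
  std::vector<std::string> op_names;
};

class ToyPass {                         // stand-in for ir::Pass
 public:
  virtual ~ToyPass() = default;
  void Apply(ToyGraph* graph) const { ApplyImpl(graph); }

 protected:
  virtual void ApplyImpl(ToyGraph* graph) const = 0;
};

class ToyInplacePass : public ToyPass {
 protected:
  void ApplyImpl(ToyGraph* graph) const override {
    // Visit every op and try to reuse an input buffer for an output.
    for (const auto& op : graph->op_names) {
      std::cout << "Try to inplace op " << op << "\n";
    }
  }
};

int main() {
  ToyGraph graph{{"scale", "relu", "softmax"}};
  ToyInplacePass pass;
  pass.Apply(&graph);                   // the pass mutates the graph it is handed
  return 0;
}
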
@@ -168,7 +167,7 @@ void InplacePass::InplaceModifyDesc(const std::string& var,
auto* op_desc = op->Op();
op_desc->RenameInput(var, cache_var);
op_desc->RenameOutput(var, cache_var);
if (op_desc->Block()->HasVar(var)) op_desc->Block()->RemoveVar(var);
op_desc->Flush();
}
}
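
Note: InplaceModifyDesc keeps the program description consistent after a buffer is reused: every reference to var in the op is renamed to cache_var and the stale variable is removed from the block, so later passes only ever see the surviving name. A rough, self-contained sketch of that bookkeeping (ToyOpDesc and its fields are illustrative names, not the real OpDesc API):

#include <algorithm>
#include <iostream>
#include <set>
#include <string>
#include <vector>

struct ToyOpDesc {                       // illustrative stand-in for an op desc
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
  void Rename(const std::string& from, const std::string& to) {
    std::replace(inputs.begin(), inputs.end(), from, to);    // ~ RenameInput
    std::replace(outputs.begin(), outputs.end(), from, to);  // ~ RenameOutput
  }
};

int main() {
  std::set<std::string> block_vars = {"x", "y"};
  ToyOpDesc op{{"x"}, {"y"}};

  // Reuse x's buffer for y: rename y -> x everywhere, then drop y from the block.
  op.Rename("y", "x");
  if (block_vars.count("y")) block_vars.erase("y");

  std::cout << "op output now refers to: " << op.outputs[0] << "\n";  // prints x
  return 0;
}
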
@@ -265,8 +264,6 @@ void InplacePass::WithdrawModify(const NodeSwapQueue& nodes,
void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
ir::Graph* graph) const {
VLOG(4) << "Try to inplace op " << op->Name();
// PADDLE_ENFORCE(op->Op() != nullptr && op->Op()->Block() != nullptr,
// "op_desc is nullptr");
// some pre-requirements need to be met if the op wants to be inplaced.
PADDLE_ENFORCE(op->Op() != nullptr, "op_desc is nullptr");
@@ -446,6 +443,7 @@ bool GraphView::CheckDeps(ir::Node* var, ir::Node* current_op) const {
// check if op2 depends on op1's output
bool GraphView::CheckOpDeps(ir::Node* op1, ir::Node* op2) const {
if (VLOG_IS_ON(4)) {
auto print_op = [&](ir::Node* op, const char* name) {
std::ostringstream os;
os << " " << name << " : " << op->Name() << " ";
@@ -458,7 +456,7 @@ bool GraphView::CheckOpDeps(ir::Node* op1, ir::Node* op2) const {
};
print_op(op1, "OP1");
print_op(op2, "OP2");
}
if (op1 == op2) return true;
if (op_level_.at(op1) >= op_level_.at(op2)) return false;
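
Note: CheckOpDeps first compares precomputed topological levels: an op can only depend on an op whose level is strictly smaller, so pairs that fail that test are rejected without walking the graph at all. A small self-contained sketch of the idea (the Deps map, Level, and CheckOpDeps helpers below are illustrative, not the GraphView API):

#include <algorithm>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Hypothetical mini-graph: each op maps to the ops it directly depends on.
using Deps = std::map<std::string, std::vector<std::string>>;

int Level(const Deps& deps, const std::string& op,
          std::map<std::string, int>* cache) {
  auto it = cache->find(op);
  if (it != cache->end()) return it->second;
  int lvl = 0;
  for (const auto& p : deps.at(op)) lvl = std::max(lvl, Level(deps, p, cache) + 1);
  return (*cache)[op] = lvl;
}

// Does op2 (transitively) depend on op1?
bool CheckOpDeps(const Deps& deps, const std::string& op1,
                 const std::string& op2, std::map<std::string, int>* levels) {
  if (op1 == op2) return true;
  // Cheap pruning: a dependency edge always points to a strictly smaller level.
  if (Level(deps, op1, levels) >= Level(deps, op2, levels)) return false;
  for (const auto& p : deps.at(op2)) {
    if (CheckOpDeps(deps, op1, p, levels)) return true;
  }
  return false;
}

int main() {
  Deps deps = {{"mul", {}}, {"add", {"mul"}}, {"relu", {"add"}}, {"scale", {}}};
  std::map<std::string, int> levels;
  std::cout << CheckOpDeps(deps, "mul", "relu", &levels) << "\n";   // 1
  std::cout << CheckOpDeps(deps, "relu", "mul", &levels) << "\n";   // 0 (pruned)
  std::cout << CheckOpDeps(deps, "scale", "relu", &levels) << "\n"; // 0 (no path)
  return 0;
}
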

@@ -142,16 +142,15 @@ TEST(OrderedSet, FindBestFitNode) {
for (auto& node : nodes) {
pool.Insert(node.get());
}
// FIXME(liuwei1031) this API has changed,
// disable these tests temporarily
// FindNextBestFitNode
// auto* n = nodes[0].get();
// auto* cache = pool.FindBestFitNode(n);
// PADDLE_ENFORCE(cache->Name() == "a");
// cache = pool.FindNextBestFitNode(n, cache);
// PADDLE_ENFORCE(cache->Name() == "c");
// cache = pool.FindNextBestFitNode(n, cache);
// PADDLE_ENFORCE(cache->Name() == "b");
auto* n = nodes[0].get();
auto* cache = pool.FindBestFitNode(n);
ASSERT_TRUE(cache->Name() == "a" || cache->Name() == "c");
auto* cache_b = pool.FindNextBestFitNode(n, cache);
ASSERT_TRUE(cache_b->Name() != cache->Name());
ASSERT_TRUE(cache_b->Name() == "a" || cache_b->Name() == "c");
cache = pool.FindNextBestFitNode(n, cache_b);
ASSERT_TRUE(cache == nullptr);
}
} // namespace details
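
Note: the rewritten test accepts either "a" or "c" as the first best-fit candidate, since two cached nodes that fit the request equally well may come back in either order; FindNextBestFitNode is then expected to return the remaining one and finally nullptr once no acceptable candidate is left. A rough self-contained sketch of that search pattern (NodeInfo, FindBestFit, and the size threshold are illustrative, not the OrderedSet API):

#include <cstddef>
#include <iostream>
#include <set>
#include <string>
#include <vector>

struct NodeInfo {                 // hypothetical cached variable: name and byte size
  std::string name;
  size_t bytes;
};

size_t Gap(size_t a, size_t b) { return a > b ? a - b : b - a; }

// Best fit among nodes not tried yet; the threshold below is an illustrative
// stand-in for the pool's real "close enough in size" rule.
const NodeInfo* FindBestFit(const std::vector<NodeInfo>& pool, size_t want,
                            const std::set<std::string>& tried) {
  const NodeInfo* best = nullptr;
  for (const auto& n : pool) {
    if (tried.count(n.name) || Gap(n.bytes, want) > want / 2) continue;
    if (best == nullptr || Gap(n.bytes, want) < Gap(best->bytes, want)) best = &n;
  }
  return best;
}

int main() {
  std::vector<NodeInfo> pool = {{"a", 400}, {"b", 64}, {"c", 400}};
  std::set<std::string> tried;

  // "a" and "c" fit a 400-byte request equally well, so either may be returned
  // first; "b" never qualifies, so the search eventually yields nullptr.
  for (const NodeInfo* n = FindBestFit(pool, 400, tried); n != nullptr;
       n = FindBestFit(pool, 400, tried)) {
    std::cout << n->name << "\n";
    tried.insert(n->name);        // next best fit = best fit among untried nodes
  }
  return 0;
}
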

File diff suppressed because it is too large.

@@ -56,7 +56,7 @@ proto::VarType::Type GetDataTypeOfVar(const Variable* var) {
}
}
static DDim GetDims(const Scope& scope, const std::string& name,
static DDim GetDimsDebug(const Scope& scope, const std::string& name,
bool get_actual_dim = false) {
Variable* var = scope.FindVar(name);
if (var == nullptr) {
@@ -65,9 +65,9 @@ static DDim GetDims(const Scope& scope, const std::string& name,
if (var->IsType<LoDTensor>()) {
const LoDTensor& tensor = var->Get<LoDTensor>();
// if (UNLIKELY(!tensor.IsInitialized())) {
// return DDim({-1});
// }
if (UNLIKELY(!tensor.IsInitialized())) {
return DDim({-1});
}
return tensor.dims();
} else if (var->IsType<SelectedRows>()) {
if (get_actual_dim) {
@@ -123,7 +123,7 @@ static int GetRowSize(const Scope& scope, const std::string& name) {
return -1;
}
static LoD GetLoD(const Scope& scope, const std::string& name) {
static LoD GetLoDDebug(const Scope& scope, const std::string& name) {
Variable* var = scope.FindVar(name);
auto default_lod = LoD({{}});
@@ -133,9 +133,9 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
if (var->IsType<LoDTensor>()) {
const LoDTensor& tensor = var->Get<LoDTensor>();
// if (UNLIKELY(!tensor.IsInitialized())) {
// return default_lod;
// }
if (UNLIKELY(!tensor.IsInitialized())) {
return default_lod;
}
return tensor.lod();
} else {
return default_lod;
@@ -274,8 +274,8 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
}
std::string dtype = GetDtype(*scope, var_name);
ss << ":" << dtype;
ss << "[" << GetDims(*scope, var_name, true) << "]";
ss << "(" << GetLoD(*scope, var_name) << ")";
ss << "[" << GetDimsDebug(*scope, var_name, true) << "]";
ss << "(" << GetLoDDebug(*scope, var_name) << ")";
}
}
if (i != input.second.size() - 1) {
@@ -305,8 +305,8 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
}
std::string dtype = GetDtype(*scope, output.second[i]);
ss << ":" << dtype;
ss << "[" << GetDims(*scope, var_name, true) << "]";
ss << "(" << GetLoD(*scope, var_name) << ")";
ss << "[" << GetDimsDebug(*scope, var_name, true) << "]";
ss << "(" << GetLoDDebug(*scope, var_name) << ")";
}
}
if (i != output.second.size() - 1) {
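
Note: the hunks above rename the two helpers to GetDimsDebug and GetLoDDebug, presumably to make clear they only feed the operator debug string, and re-enable the previously commented-out guard so an uninitialized LoDTensor is reported as [-1] (or as the default LoD) rather than having its dims or lod read while empty. A minimal self-contained sketch of that guarded formatting (ToyTensor and DimsForDebug are illustrative names):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

struct ToyTensor {                      // stand-in for LoDTensor
  bool initialized;
  std::vector<int64_t> dims;
};

// Debug-only helper: never read the dims of an uninitialized tensor, report a
// placeholder {-1} instead (mirrors the re-enabled guard above).
std::vector<int64_t> DimsForDebug(const ToyTensor& t) {
  if (!t.initialized) return {-1};
  return t.dims;
}

std::string ToString(const std::vector<int64_t>& dims) {
  std::string s = "[";
  for (size_t i = 0; i < dims.size(); ++i) {
    s += std::to_string(dims[i]);
    if (i + 1 < dims.size()) s += ", ";
  }
  return s + "]";
}

int main() {
  ToyTensor empty{false, {}};
  ToyTensor full{true, {32, 128}};
  std::cout << ToString(DimsForDebug(empty)) << "\n";  // [-1]
  std::cout << ToString(DimsForDebug(full)) << "\n";   // [32, 128]
  return 0;
}
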
