[Dygraph to static graph]JIT/Trace (#20775)
* jit/trace 1st version, test=develop * add more unittests, test=developyaoxuefeng
parent
6e6eab07e8
commit
8ff6b289bd
@ -0,0 +1,2 @@
|
|||||||
|
cc_library(op_desc_meta SRCS op_desc_meta.cc DEPS proto_desc layer)
|
||||||
|
cc_library(program_desc_tracer SRCS program_desc_tracer.cc DEPS op_desc_meta)
|
@ -0,0 +1,50 @@
|
|||||||
|
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "paddle/fluid/imperative/jit/op_desc_meta.h"
|
||||||
|
#include "paddle/fluid/framework/op_info.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace imperative {
|
||||||
|
namespace jit {
|
||||||
|
|
||||||
|
OpDescMeta::OpDescMeta(const std::string &type, const NameVarBaseMap &inputs,
|
||||||
|
const NameVarBaseMap &outputs,
|
||||||
|
const framework::AttributeMap &attrs)
|
||||||
|
: type_(type), attrs_(attrs) {
|
||||||
|
auto *proto = framework::OpInfoMap::Instance().GetNullable(type_);
|
||||||
|
if (proto && proto->Checker()) {
|
||||||
|
proto->Checker()->Check(&attrs_);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto &pair : inputs) {
|
||||||
|
inputs_[pair.first].assign(pair.second.begin(), pair.second.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto &pair : outputs) {
|
||||||
|
outputs_[pair.first].assign(pair.second.begin(), pair.second.end());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::string &OpDescMeta::Type() const { return type_; }
|
||||||
|
|
||||||
|
const WeakNameVarBaseMap &OpDescMeta::Inputs() const { return inputs_; }
|
||||||
|
|
||||||
|
const WeakNameVarBaseMap &OpDescMeta::Outputs() const { return outputs_; }
|
||||||
|
|
||||||
|
const framework::AttributeMap &OpDescMeta::Attrs() const { return attrs_; }
|
||||||
|
|
||||||
|
} // namespace jit
|
||||||
|
} // namespace imperative
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,48 @@
|
|||||||
|
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include "paddle/fluid/imperative/layer.h"
|
||||||
|
#include "paddle/fluid/imperative/type_defs.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace imperative {
|
||||||
|
namespace jit {
|
||||||
|
|
||||||
|
class OpDescMeta {
|
||||||
|
public:
|
||||||
|
OpDescMeta(const std::string &type, const NameVarBaseMap &inputs,
|
||||||
|
const NameVarBaseMap &outputs,
|
||||||
|
const framework::AttributeMap &attrs);
|
||||||
|
|
||||||
|
const std::string &Type() const;
|
||||||
|
|
||||||
|
const WeakNameVarBaseMap &Inputs() const;
|
||||||
|
|
||||||
|
const WeakNameVarBaseMap &Outputs() const;
|
||||||
|
|
||||||
|
const framework::AttributeMap &Attrs() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::string type_;
|
||||||
|
WeakNameVarBaseMap inputs_;
|
||||||
|
WeakNameVarBaseMap outputs_;
|
||||||
|
framework::AttributeMap attrs_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace jit
|
||||||
|
} // namespace imperative
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,235 @@
|
|||||||
|
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "paddle/fluid/imperative/jit/program_desc_tracer.h"
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace imperative {
|
||||||
|
namespace jit {
|
||||||
|
|
||||||
|
void ProgramDescTracer::SetNamePrefix(const std::string &name_prefix) {
|
||||||
|
name_prefix_ = name_prefix;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ProgramDescTracer::SetFeedVars(
|
||||||
|
const std::vector<std::shared_ptr<VarBase>> &feed_vars,
|
||||||
|
std::vector<std::string> feed_names) {
|
||||||
|
feed_vars_.clear();
|
||||||
|
|
||||||
|
if (feed_names.empty()) {
|
||||||
|
feed_names.reserve(feed_vars.size());
|
||||||
|
for (auto &var : feed_vars) {
|
||||||
|
feed_names.emplace_back(var->Name());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
PADDLE_ENFORCE_EQ(feed_names.size(), feed_vars.size(),
|
||||||
|
"The feeded variable names number must be equal to the "
|
||||||
|
"feeded variable number");
|
||||||
|
|
||||||
|
for (size_t i = 0; i < feed_names.size(); ++i) {
|
||||||
|
feed_vars_[feed_vars[i]] = feed_names[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ProgramDescTracer::SetFetchVars(
|
||||||
|
const std::vector<std::shared_ptr<VarBase>> &fetch_vars,
|
||||||
|
std::vector<std::string> fetch_names) {
|
||||||
|
fetch_vars_.clear();
|
||||||
|
|
||||||
|
if (fetch_names.empty()) {
|
||||||
|
fetch_names.reserve(fetch_vars.size());
|
||||||
|
for (auto &var : fetch_vars) {
|
||||||
|
fetch_names.emplace_back(var->Name());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
PADDLE_ENFORCE_EQ(fetch_names.size(), fetch_vars.size(),
|
||||||
|
"The fetched variable names number must be equal to the "
|
||||||
|
"fetched variable number");
|
||||||
|
for (size_t i = 0; i < fetch_names.size(); ++i) {
|
||||||
|
fetch_vars_[fetch_vars[i]] = fetch_names[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ProgramDescTracer::InsertOp(const std::string &type,
|
||||||
|
const NameVarBaseMap &inputs,
|
||||||
|
const NameVarBaseMap &outputs,
|
||||||
|
const framework::AttributeMap &attrs) {
|
||||||
|
ops_.emplace_back(new OpDescMeta(type, inputs, outputs, attrs));
|
||||||
|
auto &new_op = ops_.back();
|
||||||
|
for (auto &pair : new_op->Inputs()) {
|
||||||
|
for (auto &var : pair.second) {
|
||||||
|
InsertVarIfNotExist(var.lock());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto &pair : new_op->Outputs()) {
|
||||||
|
for (auto &var : pair.second) {
|
||||||
|
InsertVarIfNotExist(var.lock());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<framework::ProgramDesc> ProgramDescTracer::CreateProgramDesc()
|
||||||
|
const {
|
||||||
|
std::unique_ptr<framework::ProgramDesc> prog(new framework::ProgramDesc());
|
||||||
|
auto *block = prog->MutableBlock(0);
|
||||||
|
|
||||||
|
size_t var_num = vars_.size();
|
||||||
|
std::vector<framework::VarDesc *> var_descs(var_num, nullptr);
|
||||||
|
std::unordered_map<framework::VarDesc *, std::weak_ptr<VarBase>>
|
||||||
|
var_desc_to_var_base;
|
||||||
|
|
||||||
|
for (auto &pair : vars_) {
|
||||||
|
size_t var_id = pair.second.first;
|
||||||
|
PADDLE_ENFORCE_LT(var_id, var_num);
|
||||||
|
var_descs[var_id] = pair.second.second.get();
|
||||||
|
PADDLE_ENFORCE_NOT_NULL(var_descs[var_id]);
|
||||||
|
var_desc_to_var_base[var_descs[var_id]] = pair.first;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unordered_set<std::string> existing_var_names;
|
||||||
|
for (auto *var_desc : var_descs) {
|
||||||
|
if (var_desc->Persistable()) {
|
||||||
|
existing_var_names.insert(var_desc->Name());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto &pair : feed_vars_) {
|
||||||
|
existing_var_names.insert(pair.second);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto &pair : fetch_vars_) {
|
||||||
|
existing_var_names.insert(pair.second);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t counter = 0;
|
||||||
|
auto generate_unique_name = [&]() -> std::string {
|
||||||
|
do {
|
||||||
|
auto name = name_prefix_ + std::to_string(counter++);
|
||||||
|
if (existing_var_names.count(name) == 0) {
|
||||||
|
existing_var_names.insert(name);
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
} while (counter > 0);
|
||||||
|
PADDLE_THROW("Too many vars in the program");
|
||||||
|
};
|
||||||
|
|
||||||
|
std::map<std::weak_ptr<VarBase>, std::string,
|
||||||
|
std::owner_less<std::weak_ptr<VarBase>>>
|
||||||
|
var_to_name;
|
||||||
|
for (auto *var_desc : var_descs) {
|
||||||
|
auto var_name = var_desc->Name();
|
||||||
|
PADDLE_ENFORCE_EQ(var_desc_to_var_base.count(var_desc), 1);
|
||||||
|
std::weak_ptr<VarBase> var_base = var_desc_to_var_base.at(var_desc);
|
||||||
|
if (feed_vars_.count(var_base) > 0) {
|
||||||
|
var_name = feed_vars_.at(var_base);
|
||||||
|
} else if (fetch_vars_.count(var_base) > 0) {
|
||||||
|
var_name = fetch_vars_.at(var_base);
|
||||||
|
} else if (!var_desc->Persistable()) {
|
||||||
|
var_name = generate_unique_name();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto *new_var_desc = block->Var(var_name);
|
||||||
|
*new_var_desc = *var_desc;
|
||||||
|
new_var_desc->SetName(std::move(var_name));
|
||||||
|
var_to_name[var_base] = new_var_desc->Name();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto &op : ops_) {
|
||||||
|
auto *op_desc = block->AppendOp();
|
||||||
|
op_desc->SetType(op->Type());
|
||||||
|
op_desc->SetAttrMap(op->Attrs());
|
||||||
|
|
||||||
|
for (auto &pair : op->Inputs()) {
|
||||||
|
std::vector<std::string> names;
|
||||||
|
names.reserve(pair.second.size());
|
||||||
|
for (auto &var : pair.second) {
|
||||||
|
auto iter = var_to_name.find(var);
|
||||||
|
PADDLE_ENFORCE_EQ(iter != var_to_name.end(), true,
|
||||||
|
"Cannot find input variable");
|
||||||
|
names.emplace_back(iter->second);
|
||||||
|
}
|
||||||
|
|
||||||
|
op_desc->SetInput(pair.first, std::move(names));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto &pair : op->Outputs()) {
|
||||||
|
std::vector<std::string> names;
|
||||||
|
names.reserve(pair.second.size());
|
||||||
|
for (auto &var : pair.second) {
|
||||||
|
auto iter = var_to_name.find(var);
|
||||||
|
PADDLE_ENFORCE_EQ(iter != var_to_name.end(), true,
|
||||||
|
"Cannot find output variable");
|
||||||
|
names.emplace_back(iter->second);
|
||||||
|
}
|
||||||
|
|
||||||
|
op_desc->SetOutput(pair.first, std::move(names));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
prog->Flush();
|
||||||
|
return prog;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ProgramDescTracer::InsertVarIfNotExist(
|
||||||
|
const std::shared_ptr<VarBase> &new_var) {
|
||||||
|
PADDLE_ENFORCE_NOT_NULL(new_var);
|
||||||
|
if (vars_.count(new_var) != 0) return;
|
||||||
|
|
||||||
|
size_t var_id = vars_.size();
|
||||||
|
auto new_var_desc = new framework::VarDesc("");
|
||||||
|
vars_[new_var] =
|
||||||
|
std::make_pair(var_id, std::unique_ptr<framework::VarDesc>(new_var_desc));
|
||||||
|
|
||||||
|
if (new_var->Persistable()) {
|
||||||
|
new_var_desc->SetName(new_var->Name());
|
||||||
|
new_var_desc->SetPersistable(true);
|
||||||
|
} else {
|
||||||
|
new_var_desc->SetPersistable(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto &inner_var = new_var->Var();
|
||||||
|
PADDLE_ENFORCE_EQ(inner_var.IsInitialized(), true);
|
||||||
|
if (inner_var.IsType<framework::LoDTensor>()) {
|
||||||
|
const auto &tensor = inner_var.Get<framework::LoDTensor>();
|
||||||
|
new_var_desc->SetType(framework::proto::VarType::LOD_TENSOR);
|
||||||
|
new_var_desc->SetShape(framework::vectorize<int64_t>(tensor.dims()));
|
||||||
|
new_var_desc->SetLoDLevel(tensor.lod().size());
|
||||||
|
if (tensor.IsInitialized()) {
|
||||||
|
new_var_desc->SetDataType(tensor.type());
|
||||||
|
} else {
|
||||||
|
new_var_desc->SetDataType(framework::proto::VarType::FP32);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
PADDLE_THROW("Not support variable type %s",
|
||||||
|
framework::ToTypeName(inner_var.Type()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ProgramDescTracer::Reset() {
|
||||||
|
ops_.clear();
|
||||||
|
vars_.clear();
|
||||||
|
feed_vars_.clear();
|
||||||
|
fetch_vars_.clear();
|
||||||
|
name_prefix_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace jit
|
||||||
|
} // namespace imperative
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,80 @@
|
|||||||
|
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <forward_list>
|
||||||
|
#include <map>
|
||||||
|
#include <memory>
|
||||||
|
#include <set>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
#include "paddle/fluid/framework/program_desc.h"
|
||||||
|
#include "paddle/fluid/imperative/jit/op_desc_meta.h"
|
||||||
|
#include "paddle/fluid/imperative/layer.h"
|
||||||
|
#include "paddle/fluid/imperative/type_defs.h"
|
||||||
|
#include "paddle/fluid/platform/macros.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace imperative {
|
||||||
|
namespace jit {
|
||||||
|
|
||||||
|
class ProgramDescTracer {
|
||||||
|
DISABLE_COPY_AND_ASSIGN(ProgramDescTracer);
|
||||||
|
|
||||||
|
public:
|
||||||
|
ProgramDescTracer() = default;
|
||||||
|
|
||||||
|
void SetNamePrefix(const std::string &name_prefix);
|
||||||
|
|
||||||
|
void SetFeedVars(const std::vector<std::shared_ptr<VarBase>> &feed_vars,
|
||||||
|
std::vector<std::string> feed_names);
|
||||||
|
|
||||||
|
void SetFetchVars(const std::vector<std::shared_ptr<VarBase>> &fetch_vars,
|
||||||
|
std::vector<std::string> fetch_names);
|
||||||
|
|
||||||
|
void InsertOp(const std::string &type, const NameVarBaseMap &inputs,
|
||||||
|
const NameVarBaseMap &outputs,
|
||||||
|
const framework::AttributeMap &attrs);
|
||||||
|
|
||||||
|
std::unique_ptr<framework::ProgramDesc> CreateProgramDesc() const;
|
||||||
|
|
||||||
|
void Reset();
|
||||||
|
|
||||||
|
private:
|
||||||
|
void InsertVarIfNotExist(const std::shared_ptr<VarBase> &new_var);
|
||||||
|
|
||||||
|
std::vector<std::unique_ptr<OpDescMeta>> ops_;
|
||||||
|
|
||||||
|
std::map<std::weak_ptr<VarBase>,
|
||||||
|
std::pair<size_t, std::unique_ptr<framework::VarDesc>>,
|
||||||
|
std::owner_less<std::weak_ptr<VarBase>>>
|
||||||
|
vars_;
|
||||||
|
|
||||||
|
// The following fields are used to polish the converted ProgramDesc
|
||||||
|
std::map<std::weak_ptr<VarBase>, std::string,
|
||||||
|
std::owner_less<std::weak_ptr<VarBase>>>
|
||||||
|
feed_vars_;
|
||||||
|
|
||||||
|
std::map<std::weak_ptr<VarBase>, std::string,
|
||||||
|
std::owner_less<std::weak_ptr<VarBase>>>
|
||||||
|
fetch_vars_;
|
||||||
|
|
||||||
|
std::string name_prefix_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace jit
|
||||||
|
} // namespace imperative
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,81 @@
|
|||||||
|
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
__all__ = ['trace']
|
||||||
|
|
||||||
|
from . import layers
|
||||||
|
from .base import program_desc_tracing_guard
|
||||||
|
from .layers import Layer
|
||||||
|
from paddle.fluid.framework import Program, Block, Variable, _dygraph_tracer, dygraph_only, _dygraph_guard
|
||||||
|
|
||||||
|
|
||||||
|
def create_program_from_desc(program_desc):
|
||||||
|
program = Program()
|
||||||
|
program.desc = program_desc
|
||||||
|
program.blocks = [Block(program, 0)]
|
||||||
|
program._sync_with_cpp()
|
||||||
|
return program
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_vars(inputs, result_list):
|
||||||
|
if isinstance(inputs, Variable):
|
||||||
|
result_list.append(inputs._ivar)
|
||||||
|
|
||||||
|
if isinstance(inputs, (list, tuple)):
|
||||||
|
for var in inputs:
|
||||||
|
_extract_vars(var, result_list)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_vars(inputs):
|
||||||
|
result_list = []
|
||||||
|
_extract_vars(inputs, result_list)
|
||||||
|
return result_list
|
||||||
|
|
||||||
|
|
||||||
|
@dygraph_only
|
||||||
|
def trace(module, inputs, feed_names=None, fetch_names=None):
|
||||||
|
assert isinstance(module, Layer)
|
||||||
|
|
||||||
|
if not isinstance(inputs, (list, tuple)):
|
||||||
|
inputs = [inputs]
|
||||||
|
|
||||||
|
if feed_names is None:
|
||||||
|
feed_names = []
|
||||||
|
|
||||||
|
if fetch_names is None:
|
||||||
|
fetch_names = []
|
||||||
|
|
||||||
|
tracer = _dygraph_tracer()._get_program_desc_tracer()
|
||||||
|
|
||||||
|
var_list = extract_vars(inputs)
|
||||||
|
tracer.set_feed_vars(var_list, feed_names)
|
||||||
|
|
||||||
|
with program_desc_tracing_guard(True):
|
||||||
|
original_outputs = module.__call__(*inputs)
|
||||||
|
if not isinstance(original_outputs, (list, tuple)):
|
||||||
|
outputs = [original_outputs]
|
||||||
|
else:
|
||||||
|
outputs = original_outputs
|
||||||
|
out_vars = [var._ivar for var in outputs]
|
||||||
|
|
||||||
|
tracer.set_fetch_vars(out_vars, fetch_names)
|
||||||
|
tracer.set_name_prefix('t_')
|
||||||
|
|
||||||
|
program_desc = tracer.create_program_desc()
|
||||||
|
tracer.reset()
|
||||||
|
|
||||||
|
with _dygraph_guard(None):
|
||||||
|
program = create_program_from_desc(program_desc)
|
||||||
|
|
||||||
|
return original_outputs, program
|
@ -0,0 +1,190 @@
|
|||||||
|
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from paddle.fluid.framework import _dygraph_guard
|
||||||
|
import paddle.fluid as fluid
|
||||||
|
from paddle.fluid.framework import Variable
|
||||||
|
import paddle.fluid.dygraph.jit as jit
|
||||||
|
from paddle.fluid.dygraph.jit import extract_vars
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
__all__ = ['DyGraphProgramDescTracerTestHelper', ]
|
||||||
|
|
||||||
|
|
||||||
|
def is_equal_program(prog1, prog2):
|
||||||
|
with _dygraph_guard(None):
|
||||||
|
return _is_equal_program(prog1, prog2)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_equal_program(prog1, prog2):
|
||||||
|
block_num = prog1.num_blocks
|
||||||
|
if block_num != prog2.num_blocks:
|
||||||
|
return False
|
||||||
|
|
||||||
|
for block_id in range(block_num):
|
||||||
|
block1 = prog1.block(block_id)
|
||||||
|
block2 = prog2.block(block_id)
|
||||||
|
|
||||||
|
if len(block1.ops) != len(block2.ops):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if len(block1.vars) != len(block2.vars):
|
||||||
|
return False
|
||||||
|
|
||||||
|
for op1, op2 in zip(block1.ops, block2.ops):
|
||||||
|
if op1.input_arg_names != op2.input_arg_names:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if op1.output_arg_names != op2.output_arg_names:
|
||||||
|
return False
|
||||||
|
|
||||||
|
attr1 = op1.all_attrs()
|
||||||
|
attr2 = op2.all_attrs()
|
||||||
|
|
||||||
|
if len(attr1) != len(attr2):
|
||||||
|
return False
|
||||||
|
|
||||||
|
for key1, value1 in attr1.items():
|
||||||
|
if key1 not in attr2:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if value1 != attr2.get(key1):
|
||||||
|
return False
|
||||||
|
|
||||||
|
for var1 in block1.vars.values():
|
||||||
|
if var1.name not in block2.vars:
|
||||||
|
return False
|
||||||
|
|
||||||
|
var2 = block2.vars.get(var1.name)
|
||||||
|
if var1.name != var2.name:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if var1.type != var2.type:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if var1.dtype != var2.dtype:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if var1.lod_level != var2.lod_level:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if var1.persistable != var2.persistable:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def load_dygraph_vars_to_scope(model_path, scope, place):
|
||||||
|
def load_dict_to_scope(scope, dictionary):
|
||||||
|
if scope is None:
|
||||||
|
scope = fluid.global_scope()
|
||||||
|
|
||||||
|
for k, v in dictionary.items():
|
||||||
|
dst_t = scope.var(k).get_tensor()
|
||||||
|
src_t = v.value().get_tensor()
|
||||||
|
dst_t.set(np.array(src_t), place)
|
||||||
|
dst_t.set_lod(src_t.lod())
|
||||||
|
|
||||||
|
param_dict, opti_dict = fluid.load_dygraph(model_path)
|
||||||
|
if param_dict:
|
||||||
|
load_dict_to_scope(scope, param_dict)
|
||||||
|
|
||||||
|
if opti_dict:
|
||||||
|
load_dict_to_scope(scope, opti_dict)
|
||||||
|
|
||||||
|
|
||||||
|
class DyGraphProgramDescTracerTestHelper(object):
|
||||||
|
def __init__(self,
|
||||||
|
module,
|
||||||
|
unittest_obj,
|
||||||
|
model_path=None,
|
||||||
|
scope=None,
|
||||||
|
place=None):
|
||||||
|
self.module = module
|
||||||
|
self.unittest_obj = unittest_obj
|
||||||
|
self.scope = fluid.Scope() if scope is None else scope
|
||||||
|
|
||||||
|
self.model_path = model_path
|
||||||
|
if model_path is None:
|
||||||
|
millis = int(round(time.time() * 1000))
|
||||||
|
self.model_path = "id_{}_{}".format(id(module), millis)
|
||||||
|
|
||||||
|
self.place = place
|
||||||
|
if place is None:
|
||||||
|
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
|
||||||
|
) else fluid.CPUPlace()
|
||||||
|
|
||||||
|
self.program = None
|
||||||
|
|
||||||
|
self.executor = fluid.Executor(self.place)
|
||||||
|
|
||||||
|
def _remove_model_path(self):
|
||||||
|
if os.path.exists(self.model_path + ".pdparams"):
|
||||||
|
os.remove(self.model_path + ".pdparams")
|
||||||
|
|
||||||
|
if os.path.exists(self.model_path + ".pdopt"):
|
||||||
|
os.remove(self.model_path + ".pdopt")
|
||||||
|
|
||||||
|
def _run_static_graph(self, inputs, feed_names, fetch_names):
|
||||||
|
var_list = extract_vars(inputs)
|
||||||
|
assert len(var_list) == len(feed_names)
|
||||||
|
|
||||||
|
feed_dict = {}
|
||||||
|
for name, var in zip(feed_names, var_list):
|
||||||
|
feed_dict[name] = np.array(var.value().get_tensor())
|
||||||
|
|
||||||
|
with fluid.scope_guard(self.scope):
|
||||||
|
with _dygraph_guard(None):
|
||||||
|
return self.executor.run(self.program,
|
||||||
|
feed=feed_dict,
|
||||||
|
fetch_list=fetch_names)
|
||||||
|
|
||||||
|
def run(self, inputs, feed_names, fetch_names):
|
||||||
|
out_dygraph, program = jit.trace(
|
||||||
|
self.module, inputs, feed_names=feed_names, fetch_names=fetch_names)
|
||||||
|
|
||||||
|
if self.program is not None:
|
||||||
|
self.unittest_obj.assertTrue(
|
||||||
|
is_equal_program(self.program, program))
|
||||||
|
|
||||||
|
self.program = program
|
||||||
|
|
||||||
|
fluid.save_dygraph(self.module.state_dict(), self.model_path)
|
||||||
|
load_dygraph_vars_to_scope(self.model_path, self.scope, self.place)
|
||||||
|
|
||||||
|
self._remove_model_path()
|
||||||
|
|
||||||
|
out_static_graph = self._run_static_graph(inputs, feed_names,
|
||||||
|
fetch_names)
|
||||||
|
|
||||||
|
if not isinstance(out_dygraph, (list, tuple)):
|
||||||
|
assert len(out_static_graph) == 1
|
||||||
|
out_static_graph = out_static_graph[0]
|
||||||
|
|
||||||
|
return out_dygraph, out_static_graph
|
||||||
|
|
||||||
|
def assertEachVar(self, out_dygraph, out_static_graph, func=None):
|
||||||
|
if func is None:
|
||||||
|
func = lambda x, y: np.array_equal(x, y)
|
||||||
|
|
||||||
|
if not isinstance(out_dygraph, (list, tuple)):
|
||||||
|
out_dygraph = [out_dygraph]
|
||||||
|
|
||||||
|
if not isinstance(out_static_graph, (list, tuple)):
|
||||||
|
out_static_graph = [out_static_graph]
|
||||||
|
|
||||||
|
for v1, v2 in zip(out_dygraph, out_static_graph):
|
||||||
|
self.unittest_obj.assertTrue(func(v1.numpy(), v2))
|
Loading…
Reference in new issue