[Dygraph to static graph] JIT/Trace (#20775)

* jit/trace 1st version, test=develop
* add more unittests, test=develop

Author: yaoxuefeng
parent 6e6eab07e8
commit 8ff6b289bd
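This change adds a ProgramDesc tracer for dygraph: running a Layer once under tracing records every executed op and converts the records into a static Program. A minimal usage sketch of the new Python API added in this commit (ExampleLayer is a hypothetical layer, not part of the diff; jit.trace's signature is the one defined below):

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.dygraph.jit as jit

class ExampleLayer(fluid.dygraph.Layer):  # hypothetical example layer
    def __init__(self, name_scope):
        super(ExampleLayer, self).__init__(name_scope)
        self._fc = fluid.dygraph.FC(self.full_name(), 10)

    def forward(self, x):
        return self._fc(x)

with fluid.dygraph.guard():
    layer = ExampleLayer("example")
    x = fluid.dygraph.to_variable(
        np.random.random([2, 8]).astype('float32'))
    # Runs the layer once in dygraph mode while recording a ProgramDesc;
    # returns the dygraph outputs plus an equivalent static Program.
    outputs, program = jit.trace(layer, inputs=[x],
                                 feed_names=['feed_x'],
                                 fetch_names=['fetch_out'])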
paddle/fluid/imperative/jit/CMakeLists.txt
@@ -0,0 +1,2 @@
cc_library(op_desc_meta SRCS op_desc_meta.cc DEPS proto_desc layer)
cc_library(program_desc_tracer SRCS program_desc_tracer.cc DEPS op_desc_meta)
paddle/fluid/imperative/jit/op_desc_meta.cc
@@ -0,0 +1,50 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/jit/op_desc_meta.h"
#include "paddle/fluid/framework/op_info.h"

namespace paddle {
namespace imperative {
namespace jit {

// OpDescMeta snapshots one traced op: its type, its checked attribute map,
// and weak references to its input/output VarBases.
OpDescMeta::OpDescMeta(const std::string &type, const NameVarBaseMap &inputs,
                       const NameVarBaseMap &outputs,
                       const framework::AttributeMap &attrs)
    : type_(type), attrs_(attrs) {
  auto *proto = framework::OpInfoMap::Instance().GetNullable(type_);
  // Fill in default attribute values if the op has an attribute checker.
  if (proto && proto->Checker()) {
    proto->Checker()->Check(&attrs_);
  }

  for (auto &pair : inputs) {
    inputs_[pair.first].assign(pair.second.begin(), pair.second.end());
  }

  for (auto &pair : outputs) {
    outputs_[pair.first].assign(pair.second.begin(), pair.second.end());
  }
}

const std::string &OpDescMeta::Type() const { return type_; }

const WeakNameVarBaseMap &OpDescMeta::Inputs() const { return inputs_; }

const WeakNameVarBaseMap &OpDescMeta::Outputs() const { return outputs_; }

const framework::AttributeMap &OpDescMeta::Attrs() const { return attrs_; }

}  // namespace jit
}  // namespace imperative
}  // namespace paddle
paddle/fluid/imperative/jit/op_desc_meta.h
@@ -0,0 +1,48 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/type_defs.h"

namespace paddle {
namespace imperative {
namespace jit {

class OpDescMeta {
 public:
  OpDescMeta(const std::string &type, const NameVarBaseMap &inputs,
             const NameVarBaseMap &outputs,
             const framework::AttributeMap &attrs);

  const std::string &Type() const;

  const WeakNameVarBaseMap &Inputs() const;

  const WeakNameVarBaseMap &Outputs() const;

  const framework::AttributeMap &Attrs() const;

 private:
  std::string type_;
  WeakNameVarBaseMap inputs_;
  WeakNameVarBaseMap outputs_;
  framework::AttributeMap attrs_;
};

}  // namespace jit
}  // namespace imperative
}  // namespace paddle
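For intuition, a rough sketch of what one recorded OpDescMeta holds after tracing a fully-connected layer (the concrete op and variable names here are illustrative only, not taken from this diff):

# type    : 'mul'                      (the FC's matmul op)
# inputs  : {'X': [weak_ref(x)], 'Y': [weak_ref(fc_0.w_0)]}
# outputs : {'Out': [weak_ref(tmp_var)]}
# attrs   : attribute map completed by the op's attribute checker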
paddle/fluid/imperative/jit/program_desc_tracer.cc
@@ -0,0 +1,235 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/jit/program_desc_tracer.h"
#include <unordered_map>
#include <unordered_set>
#include <utility>

namespace paddle {
namespace imperative {
namespace jit {

void ProgramDescTracer::SetNamePrefix(const std::string &name_prefix) {
  name_prefix_ = name_prefix;
}

void ProgramDescTracer::SetFeedVars(
    const std::vector<std::shared_ptr<VarBase>> &feed_vars,
    std::vector<std::string> feed_names) {
  feed_vars_.clear();

  // If no names are given, fall back to the dygraph variable names.
  if (feed_names.empty()) {
    feed_names.reserve(feed_vars.size());
    for (auto &var : feed_vars) {
      feed_names.emplace_back(var->Name());
    }
  }

  PADDLE_ENFORCE_EQ(feed_names.size(), feed_vars.size(),
                    "The number of feed variable names must be equal to the "
                    "number of feed variables");

  for (size_t i = 0; i < feed_names.size(); ++i) {
    feed_vars_[feed_vars[i]] = feed_names[i];
  }
}

void ProgramDescTracer::SetFetchVars(
    const std::vector<std::shared_ptr<VarBase>> &fetch_vars,
    std::vector<std::string> fetch_names) {
  fetch_vars_.clear();

  if (fetch_names.empty()) {
    fetch_names.reserve(fetch_vars.size());
    for (auto &var : fetch_vars) {
      fetch_names.emplace_back(var->Name());
    }
  }

  PADDLE_ENFORCE_EQ(fetch_names.size(), fetch_vars.size(),
                    "The number of fetch variable names must be equal to the "
                    "number of fetch variables");
  for (size_t i = 0; i < fetch_names.size(); ++i) {
    fetch_vars_[fetch_vars[i]] = fetch_names[i];
  }
}

void ProgramDescTracer::InsertOp(const std::string &type,
                                 const NameVarBaseMap &inputs,
                                 const NameVarBaseMap &outputs,
                                 const framework::AttributeMap &attrs) {
  ops_.emplace_back(new OpDescMeta(type, inputs, outputs, attrs));
  auto &new_op = ops_.back();
  for (auto &pair : new_op->Inputs()) {
    for (auto &var : pair.second) {
      InsertVarIfNotExist(var.lock());
    }
  }

  for (auto &pair : new_op->Outputs()) {
    for (auto &var : pair.second) {
      InsertVarIfNotExist(var.lock());
    }
  }
}

std::unique_ptr<framework::ProgramDesc> ProgramDescTracer::CreateProgramDesc()
    const {
  std::unique_ptr<framework::ProgramDesc> prog(new framework::ProgramDesc());
  auto *block = prog->MutableBlock(0);

  size_t var_num = vars_.size();
  std::vector<framework::VarDesc *> var_descs(var_num, nullptr);
  std::unordered_map<framework::VarDesc *, std::weak_ptr<VarBase>>
      var_desc_to_var_base;

  for (auto &pair : vars_) {
    size_t var_id = pair.second.first;
    PADDLE_ENFORCE_LT(var_id, var_num);
    var_descs[var_id] = pair.second.second.get();
    PADDLE_ENFORCE_NOT_NULL(var_descs[var_id]);
    var_desc_to_var_base[var_descs[var_id]] = pair.first;
  }

  // Collect the names that must be kept as-is, so that generated names
  // cannot collide with them.
  std::unordered_set<std::string> existing_var_names;
  for (auto *var_desc : var_descs) {
    if (var_desc->Persistable()) {
      existing_var_names.insert(var_desc->Name());
    }
  }

  for (auto &pair : feed_vars_) {
    existing_var_names.insert(pair.second);
  }

  for (auto &pair : fetch_vars_) {
    existing_var_names.insert(pair.second);
  }

  size_t counter = 0;
  auto generate_unique_name = [&]() -> std::string {
    do {
      auto name = name_prefix_ + std::to_string(counter++);
      if (existing_var_names.count(name) == 0) {
        existing_var_names.insert(name);
        return name;
      }
    } while (counter > 0);  // stop only if the counter wraps around
    PADDLE_THROW("Too many vars in the program");
  };

  std::map<std::weak_ptr<VarBase>, std::string,
           std::owner_less<std::weak_ptr<VarBase>>>
      var_to_name;
  // Feed/fetch vars keep the names given by the user, persistable vars keep
  // their dygraph names, and every other var gets a generated unique name.
  for (auto *var_desc : var_descs) {
    auto var_name = var_desc->Name();
    PADDLE_ENFORCE_EQ(var_desc_to_var_base.count(var_desc), 1);
    std::weak_ptr<VarBase> var_base = var_desc_to_var_base.at(var_desc);
    if (feed_vars_.count(var_base) > 0) {
      var_name = feed_vars_.at(var_base);
    } else if (fetch_vars_.count(var_base) > 0) {
      var_name = fetch_vars_.at(var_base);
    } else if (!var_desc->Persistable()) {
      var_name = generate_unique_name();
    }

    auto *new_var_desc = block->Var(var_name);
    *new_var_desc = *var_desc;
    new_var_desc->SetName(std::move(var_name));
    var_to_name[var_base] = new_var_desc->Name();
  }

  for (auto &op : ops_) {
    auto *op_desc = block->AppendOp();
    op_desc->SetType(op->Type());
    op_desc->SetAttrMap(op->Attrs());

    for (auto &pair : op->Inputs()) {
      std::vector<std::string> names;
      names.reserve(pair.second.size());
      for (auto &var : pair.second) {
        auto iter = var_to_name.find(var);
        PADDLE_ENFORCE_EQ(iter != var_to_name.end(), true,
                          "Cannot find input variable");
        names.emplace_back(iter->second);
      }

      op_desc->SetInput(pair.first, std::move(names));
    }

    for (auto &pair : op->Outputs()) {
      std::vector<std::string> names;
      names.reserve(pair.second.size());
      for (auto &var : pair.second) {
        auto iter = var_to_name.find(var);
        PADDLE_ENFORCE_EQ(iter != var_to_name.end(), true,
                          "Cannot find output variable");
        names.emplace_back(iter->second);
      }

      op_desc->SetOutput(pair.first, std::move(names));
    }
  }

  prog->Flush();
  return prog;
}

void ProgramDescTracer::InsertVarIfNotExist(
    const std::shared_ptr<VarBase> &new_var) {
  PADDLE_ENFORCE_NOT_NULL(new_var);
  if (vars_.count(new_var) != 0) return;

  size_t var_id = vars_.size();
  auto *new_var_desc = new framework::VarDesc("");
  vars_[new_var] =
      std::make_pair(var_id, std::unique_ptr<framework::VarDesc>(new_var_desc));

  // Only persistable vars (parameters) keep their dygraph names here; all
  // other names are assigned later in CreateProgramDesc().
  if (new_var->Persistable()) {
    new_var_desc->SetName(new_var->Name());
    new_var_desc->SetPersistable(true);
  } else {
    new_var_desc->SetPersistable(false);
  }

  const auto &inner_var = new_var->Var();
  PADDLE_ENFORCE_EQ(inner_var.IsInitialized(), true);
  if (inner_var.IsType<framework::LoDTensor>()) {
    const auto &tensor = inner_var.Get<framework::LoDTensor>();
    new_var_desc->SetType(framework::proto::VarType::LOD_TENSOR);
    new_var_desc->SetShape(framework::vectorize<int64_t>(tensor.dims()));
    new_var_desc->SetLoDLevel(tensor.lod().size());
    if (tensor.IsInitialized()) {
      new_var_desc->SetDataType(tensor.type());
    } else {
      new_var_desc->SetDataType(framework::proto::VarType::FP32);
    }
  } else {
    PADDLE_THROW("Unsupported variable type %s",
                 framework::ToTypeName(inner_var.Type()));
  }
}

void ProgramDescTracer::Reset() {
  ops_.clear();
  vars_.clear();
  feed_vars_.clear();
  fetch_vars_.clear();
  name_prefix_.clear();
}

}  // namespace jit
}  // namespace imperative
}  // namespace paddle
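The naming rules applied by CreateProgramDesc above, summarized with hypothetical names:

# feed var    -> the name given via SetFeedVars,  e.g. 'feed_x'
# fetch var   -> the name given via SetFetchVars, e.g. 'fetch_out'
# persistable -> keeps its dygraph name,          e.g. 'fc_0.w_0'
# any other   -> name_prefix_ + counter,          e.g. 't_0', 't_1', ...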
paddle/fluid/imperative/jit/program_desc_tracer.h
@@ -0,0 +1,80 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <forward_list>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/imperative/jit/op_desc_meta.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/platform/macros.h"

namespace paddle {
namespace imperative {
namespace jit {

class ProgramDescTracer {
  DISABLE_COPY_AND_ASSIGN(ProgramDescTracer);

 public:
  ProgramDescTracer() = default;

  void SetNamePrefix(const std::string &name_prefix);

  void SetFeedVars(const std::vector<std::shared_ptr<VarBase>> &feed_vars,
                   std::vector<std::string> feed_names);

  void SetFetchVars(const std::vector<std::shared_ptr<VarBase>> &fetch_vars,
                    std::vector<std::string> fetch_names);

  void InsertOp(const std::string &type, const NameVarBaseMap &inputs,
                const NameVarBaseMap &outputs,
                const framework::AttributeMap &attrs);

  std::unique_ptr<framework::ProgramDesc> CreateProgramDesc() const;

  void Reset();

 private:
  void InsertVarIfNotExist(const std::shared_ptr<VarBase> &new_var);

  std::vector<std::unique_ptr<OpDescMeta>> ops_;

  std::map<std::weak_ptr<VarBase>,
           std::pair<size_t, std::unique_ptr<framework::VarDesc>>,
           std::owner_less<std::weak_ptr<VarBase>>>
      vars_;

  // The following fields are used to polish the converted ProgramDesc
  std::map<std::weak_ptr<VarBase>, std::string,
           std::owner_less<std::weak_ptr<VarBase>>>
      feed_vars_;

  std::map<std::weak_ptr<VarBase>, std::string,
           std::owner_less<std::weak_ptr<VarBase>>>
      fetch_vars_;

  std::string name_prefix_;
};

}  // namespace jit
}  // namespace imperative
}  // namespace paddle
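The intended call sequence, condensed from the Python trace() implementation added below (module, inputs, var_list, out_vars, feed_names, and fetch_names are as in that code; the comments map each call to the C++ method above):

from paddle.fluid.framework import _dygraph_tracer
from paddle.fluid.dygraph.base import program_desc_tracing_guard

tracer = _dygraph_tracer()._get_program_desc_tracer()
tracer.set_feed_vars(var_list, feed_names)        # SetFeedVars
with program_desc_tracing_guard(True):
    outputs = module(*inputs)                     # every executed op -> InsertOp
    tracer.set_fetch_vars(out_vars, fetch_names)  # SetFetchVars
    tracer.set_name_prefix('t_')                  # SetNamePrefix
    program_desc = tracer.create_program_desc()   # CreateProgramDesc
    tracer.reset()                                # Reset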
python/paddle/fluid/dygraph/jit.py
@@ -0,0 +1,81 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ['trace']

from . import layers
from .base import program_desc_tracing_guard
from .layers import Layer
from paddle.fluid.framework import Program, Block, Variable, _dygraph_tracer, dygraph_only, _dygraph_guard


def create_program_from_desc(program_desc):
    program = Program()
    program.desc = program_desc
    program.blocks = [Block(program, 0)]
    program._sync_with_cpp()
    return program


def _extract_vars(inputs, result_list):
    if isinstance(inputs, Variable):
        result_list.append(inputs._ivar)

    if isinstance(inputs, (list, tuple)):
        for var in inputs:
            _extract_vars(var, result_list)


def extract_vars(inputs):
    result_list = []
    _extract_vars(inputs, result_list)
    return result_list


@dygraph_only
def trace(module, inputs, feed_names=None, fetch_names=None):
    """
    Trace a dygraph module by running it once with the given inputs while
    recording every executed op, then convert the records into a static
    Program. Returns (original_outputs, program).
    """
    assert isinstance(module, Layer)

    if not isinstance(inputs, (list, tuple)):
        inputs = [inputs]

    if feed_names is None:
        feed_names = []

    if fetch_names is None:
        fetch_names = []

    tracer = _dygraph_tracer()._get_program_desc_tracer()

    var_list = extract_vars(inputs)
    tracer.set_feed_vars(var_list, feed_names)

    with program_desc_tracing_guard(True):
        original_outputs = module.__call__(*inputs)
        if not isinstance(original_outputs, (list, tuple)):
            outputs = [original_outputs]
        else:
            outputs = original_outputs
        out_vars = [var._ivar for var in outputs]

        tracer.set_fetch_vars(out_vars, fetch_names)
        tracer.set_name_prefix('t_')

        program_desc = tracer.create_program_desc()
        tracer.reset()

    with _dygraph_guard(None):
        program = create_program_from_desc(program_desc)

    return original_outputs, program
@@ -0,0 +1,190 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.fluid.framework import _dygraph_guard
import paddle.fluid as fluid
from paddle.fluid.framework import Variable
import paddle.fluid.dygraph.jit as jit
from paddle.fluid.dygraph.jit import extract_vars
import numpy as np
import os
import time

__all__ = ['DyGraphProgramDescTracerTestHelper', ]


def is_equal_program(prog1, prog2):
    with _dygraph_guard(None):
        return _is_equal_program(prog1, prog2)


def _is_equal_program(prog1, prog2):
    block_num = prog1.num_blocks
    if block_num != prog2.num_blocks:
        return False

    for block_id in range(block_num):
        block1 = prog1.block(block_id)
        block2 = prog2.block(block_id)

        if len(block1.ops) != len(block2.ops):
            return False

        if len(block1.vars) != len(block2.vars):
            return False

        for op1, op2 in zip(block1.ops, block2.ops):
            if op1.input_arg_names != op2.input_arg_names:
                return False

            if op1.output_arg_names != op2.output_arg_names:
                return False

            attr1 = op1.all_attrs()
            attr2 = op2.all_attrs()

            if len(attr1) != len(attr2):
                return False

            for key1, value1 in attr1.items():
                if key1 not in attr2:
                    return False

                if value1 != attr2.get(key1):
                    return False

        for var1 in block1.vars.values():
            if var1.name not in block2.vars:
                return False

            var2 = block2.vars.get(var1.name)
            if var1.name != var2.name:
                return False

            if var1.type != var2.type:
                return False

            if var1.dtype != var2.dtype:
                return False

            if var1.lod_level != var2.lod_level:
                return False

            if var1.persistable != var2.persistable:
                return False

    return True


def load_dygraph_vars_to_scope(model_path, scope, place):
    def load_dict_to_scope(scope, dictionary):
        if scope is None:
            scope = fluid.global_scope()

        for k, v in dictionary.items():
            dst_t = scope.var(k).get_tensor()
            src_t = v.value().get_tensor()
            dst_t.set(np.array(src_t), place)
            dst_t.set_lod(src_t.lod())

    param_dict, opti_dict = fluid.load_dygraph(model_path)
    if param_dict:
        load_dict_to_scope(scope, param_dict)

    if opti_dict:
        load_dict_to_scope(scope, opti_dict)


class DyGraphProgramDescTracerTestHelper(object):
    def __init__(self,
                 module,
                 unittest_obj,
                 model_path=None,
                 scope=None,
                 place=None):
        self.module = module
        self.unittest_obj = unittest_obj
        self.scope = fluid.Scope() if scope is None else scope

        self.model_path = model_path
        if model_path is None:
            millis = int(round(time.time() * 1000))
            self.model_path = "id_{}_{}".format(id(module), millis)

        self.place = place
        if place is None:
            self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
            ) else fluid.CPUPlace()

        self.program = None

        self.executor = fluid.Executor(self.place)

    def _remove_model_path(self):
        if os.path.exists(self.model_path + ".pdparams"):
            os.remove(self.model_path + ".pdparams")

        if os.path.exists(self.model_path + ".pdopt"):
            os.remove(self.model_path + ".pdopt")

    def _run_static_graph(self, inputs, feed_names, fetch_names):
        var_list = extract_vars(inputs)
        assert len(var_list) == len(feed_names)

        feed_dict = {}
        for name, var in zip(feed_names, var_list):
            feed_dict[name] = np.array(var.value().get_tensor())

        with fluid.scope_guard(self.scope):
            with _dygraph_guard(None):
                return self.executor.run(self.program,
                                         feed=feed_dict,
                                         fetch_list=fetch_names)

    def run(self, inputs, feed_names, fetch_names):
        out_dygraph, program = jit.trace(
            self.module, inputs, feed_names=feed_names, fetch_names=fetch_names)

        # Tracing the same module again must yield an identical program.
        if self.program is not None:
            self.unittest_obj.assertTrue(
                is_equal_program(self.program, program))

        self.program = program

        # Dump the dygraph parameters and load them into the scope used by
        # the static-graph executor, so both sides share the same weights.
        fluid.save_dygraph(self.module.state_dict(), self.model_path)
        load_dygraph_vars_to_scope(self.model_path, self.scope, self.place)

        self._remove_model_path()

        out_static_graph = self._run_static_graph(inputs, feed_names,
                                                  fetch_names)

        if not isinstance(out_dygraph, (list, tuple)):
            assert len(out_static_graph) == 1
            out_static_graph = out_static_graph[0]

        return out_dygraph, out_static_graph

    def assertEachVar(self, out_dygraph, out_static_graph, func=None):
        if func is None:
            func = lambda x, y: np.array_equal(x, y)

        if not isinstance(out_dygraph, (list, tuple)):
            out_dygraph = [out_dygraph]

        if not isinstance(out_static_graph, (list, tuple)):
            out_static_graph = [out_static_graph]

        for v1, v2 in zip(out_dygraph, out_static_graph):
            self.unittest_obj.assertTrue(func(v1.numpy(), v2))
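A sketch of how this helper is meant to be used inside a unittest (MyLayer is an assumed dygraph Layer subclass; the feed/fetch names are illustrative):

import unittest
import numpy as np
import paddle.fluid as fluid

class TestTraceMyLayer(unittest.TestCase):
    def test_trace_matches_dygraph(self):
        with fluid.dygraph.guard():
            layer = MyLayer("my_layer")  # assumed Layer subclass
            helper = DyGraphProgramDescTracerTestHelper(layer, self)
            x = fluid.dygraph.to_variable(
                np.random.random([4, 8]).astype('float32'))
            # Runs the layer via jit.trace and via the static executor,
            # then compares the two outputs element-wise.
            out_dy, out_st = helper.run([x], feed_names=['x'],
                                        fetch_names=['out'])
            helper.assertEachVar(out_dy, out_st)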