parent
5670530ca7
commit
a93a9eef8f
@ -1,140 +0,0 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/details/computation_op_handle.h"
|
||||
#include "paddle/fluid/framework/details/op_handle_base.h"
|
||||
#include "paddle/fluid/framework/details/var_handle.h"
|
||||
#include "paddle/fluid/framework/garbage_collector.h"
|
||||
#include "paddle/fluid/framework/lod_tensor_array.h"
|
||||
#include "paddle/fluid/framework/scope.h"
|
||||
#include "paddle/fluid/framework/selected_rows.h"
|
||||
#include "paddle/fluid/framework/tensor.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
class EarlyDeleteOpHandle : public OpHandleBase {
|
||||
public:
|
||||
EarlyDeleteOpHandle(ir::Node* node, const Scope* scope,
|
||||
const platform::Place& place,
|
||||
const std::vector<std::string>& names,
|
||||
GarbageCollector* gc)
|
||||
: OpHandleBase(node),
|
||||
scope_(scope),
|
||||
place_(place),
|
||||
names_(names),
|
||||
gc_(gc) {
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
if (IsStreamGarabageCollector()) {
|
||||
auto gpu_place = boost::get<platform::CUDAPlace>(place);
|
||||
PADDLE_ENFORCE(cudaSetDevice(gpu_place.device));
|
||||
PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
~EarlyDeleteOpHandle() {
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
if (IsStreamGarabageCollector()) {
|
||||
auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx_->GetPlace());
|
||||
PADDLE_ENFORCE(cudaSetDevice(gpu_place.device));
|
||||
PADDLE_ENFORCE(cudaEventDestroy(event_));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
std::string Name() const override { return "early_delete"; }
|
||||
|
||||
protected:
|
||||
void RunImpl() override {
|
||||
std::vector<std::shared_ptr<memory::Allocation>> tensors;
|
||||
auto* local_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope*>();
|
||||
for (auto& var_name : names_) {
|
||||
auto* var = local_scope->FindVar(var_name);
|
||||
PADDLE_ENFORCE(var != nullptr,
|
||||
string::Sprintf("Local Scope not has var %s", var_name));
|
||||
if (var->IsType<LoDTensor>()) {
|
||||
tensors.emplace_back(var->GetMutable<LoDTensor>()->MoveMemoryHolder());
|
||||
} else if (var->IsType<SelectedRows>()) {
|
||||
tensors.emplace_back(var->GetMutable<SelectedRows>()
|
||||
->mutable_value()
|
||||
->MoveMemoryHolder());
|
||||
} else if (var->IsType<LoDTensorArray>()) {
|
||||
LoDTensorArray* tensor_array = var->GetMutable<LoDTensorArray>();
|
||||
for (auto& tensor : *tensor_array) {
|
||||
tensors.emplace_back(tensor.MoveMemoryHolder());
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!tensors.empty()) {
|
||||
ClearTensors(tensors);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void ClearTensors(
|
||||
const std::vector<std::shared_ptr<memory::Allocation>>& tensors) {
|
||||
if (platform::is_cpu_place(place_)) {
|
||||
ClearCPUTensors(tensors);
|
||||
} else {
|
||||
ClearGPUTensors(tensors);
|
||||
}
|
||||
}
|
||||
|
||||
void ClearCPUTensors(
|
||||
const std::vector<std::shared_ptr<memory::Allocation>>& tensors) {
|
||||
auto* gc = dynamic_cast<CPUGarbageCollector*>(gc_);
|
||||
if (gc != nullptr) {
|
||||
gc->Add(tensors);
|
||||
}
|
||||
}
|
||||
|
||||
void ClearGPUTensors(
|
||||
const std::vector<std::shared_ptr<memory::Allocation>>& tensors) {
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
auto* gc = dynamic_cast<StreamGarbageCollector*>(gc_);
|
||||
if (gc != nullptr) {
|
||||
auto compute_stream = dev_ctx_->stream();
|
||||
auto callback_stream = gc->stream();
|
||||
auto callback_func = [=]() {
|
||||
PADDLE_ENFORCE(cudaEventRecord(event_, compute_stream));
|
||||
PADDLE_ENFORCE(cudaStreamWaitEvent(callback_stream, event_, 0));
|
||||
};
|
||||
gc_->Add(tensors, callback_func);
|
||||
} else {
|
||||
gc_->Add(tensors);
|
||||
}
|
||||
}
|
||||
|
||||
bool IsStreamGarabageCollector() const {
|
||||
return dynamic_cast<const StreamGarbageCollector*>(gc_) != nullptr;
|
||||
#endif
|
||||
}
|
||||
|
||||
const Scope* scope_;
|
||||
const platform::Place place_;
|
||||
std::vector<std::string> names_;
|
||||
GarbageCollector* gc_;
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
platform::CUDADeviceContext* dev_ctx_;
|
||||
cudaEvent_t event_;
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,185 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/executor_gc_helper.h"
|
||||
#include <deque>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "glog/logging.h"
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/framework/lod_tensor_array.h"
|
||||
#include "paddle/fluid/framework/selected_rows.h"
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
struct OpInOutInfo {
|
||||
public:
|
||||
void Build(const OperatorBase *op) {
|
||||
is_built_ = true;
|
||||
auto &inferer = op->Info().NoNeedBufferVarsInferer();
|
||||
if (inferer) {
|
||||
no_need_buffer_ins_ = inferer(op->Inputs(), op->Outputs(), op->Attrs());
|
||||
|
||||
if (no_need_buffer_ins_.empty()) return;
|
||||
|
||||
for (auto &in_name_pair : op->Inputs()) {
|
||||
if (no_need_buffer_ins_.count(in_name_pair.first) != 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (auto &in_arg_name : in_name_pair.second) {
|
||||
other_args_set_.insert(in_arg_name);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &out_name_pair : op->Outputs()) {
|
||||
for (auto &out_arg_name : out_name_pair.second) {
|
||||
other_args_set_.insert(out_arg_name);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool IsBuilt() const { return is_built_; }
|
||||
|
||||
bool IsInArgBufferNeeded(const std::string &in_arg_name) const {
|
||||
return no_need_buffer_ins_.empty() ||
|
||||
other_args_set_.count(in_arg_name) != 0;
|
||||
}
|
||||
|
||||
private:
|
||||
std::unordered_set<std::string> no_need_buffer_ins_;
|
||||
std::unordered_set<std::string> other_args_set_;
|
||||
bool is_built_{false};
|
||||
};
|
||||
|
||||
static bool VarCanBeDeleted(const std::string &name, const BlockDesc &block,
|
||||
const std::unordered_set<std::string> &skip_vars) {
|
||||
if (skip_vars.count(name) != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto *var_desc = block.FindVar(name);
|
||||
if (var_desc == nullptr || var_desc->Persistable()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto type = var_desc->Proto()->type().type();
|
||||
|
||||
return type == proto::VarType::LOD_TENSOR ||
|
||||
type == proto::VarType::SELECTED_ROWS ||
|
||||
type == proto::VarType::LOD_TENSOR_ARRAY;
|
||||
}
|
||||
|
||||
std::unordered_map<OperatorBase *, std::vector<std::string>> GetUnusedVars(
|
||||
const BlockDesc &block,
|
||||
const std::vector<std::unique_ptr<OperatorBase>> &ops,
|
||||
const std::vector<std::string> &skip_var_list) {
|
||||
std::unordered_set<std::string> skip_vars(skip_var_list.begin(),
|
||||
skip_var_list.end());
|
||||
|
||||
std::unordered_map<std::string, size_t> var_op_idx_map;
|
||||
|
||||
for (size_t i = 0; i < ops.size(); ++i) {
|
||||
auto *op = ops[i].get();
|
||||
|
||||
OpInOutInfo info;
|
||||
for (auto &name_pair : op->Inputs()) {
|
||||
for (auto &name : name_pair.second) {
|
||||
if (!VarCanBeDeleted(name, block, skip_vars)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// var can be gc-ed
|
||||
if (!info.IsBuilt()) {
|
||||
info.Build(op);
|
||||
}
|
||||
|
||||
if (info.IsInArgBufferNeeded(name)) {
|
||||
var_op_idx_map[name] = i;
|
||||
} else {
|
||||
VLOG(10) << "Skip reference count computing of variable "
|
||||
<< name_pair.first << "(" << name << ") in Operator "
|
||||
<< op->Type();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &name_pair : op->Outputs()) {
|
||||
for (auto &name : name_pair.second) {
|
||||
if (VarCanBeDeleted(name, block, skip_vars)) {
|
||||
var_op_idx_map[name] = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::unordered_map<OperatorBase *, std::vector<std::string>> result;
|
||||
for (auto &name_op_idx_pair : var_op_idx_map) {
|
||||
auto &name = name_op_idx_pair.first;
|
||||
size_t op_idx = name_op_idx_pair.second;
|
||||
result[ops[op_idx].get()].emplace_back(name);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void DeleteUnusedTensors(
|
||||
const Scope &scope, OperatorBase *op,
|
||||
const std::unordered_map<OperatorBase *, std::vector<std::string>>
|
||||
&delete_vars_map,
|
||||
GarbageCollector *gc) {
|
||||
auto iter = delete_vars_map.find(op);
|
||||
if (iter == delete_vars_map.end()) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto &delete_vars = iter->second;
|
||||
|
||||
std::deque<std::shared_ptr<memory::Allocation>> garbages;
|
||||
|
||||
for (auto &var_name : delete_vars) {
|
||||
auto *var = scope.FindVar(var_name);
|
||||
if (var == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
VLOG(2) << "Erase variable " << var_name;
|
||||
if (var->IsType<LoDTensor>()) {
|
||||
garbages.emplace_back(var->GetMutable<LoDTensor>()->MoveMemoryHolder());
|
||||
} else if (var->IsType<SelectedRows>()) {
|
||||
garbages.emplace_back(
|
||||
var->GetMutable<SelectedRows>()->mutable_value()->MoveMemoryHolder());
|
||||
} else if (var->IsType<LoDTensorArray>()) {
|
||||
auto *lod_tensor_arr = var->GetMutable<LoDTensorArray>();
|
||||
for (auto &t : *lod_tensor_arr) {
|
||||
garbages.emplace_back(t.MoveMemoryHolder());
|
||||
}
|
||||
} else {
|
||||
PADDLE_THROW("Type %s of %s is not supported eager deletion",
|
||||
framework::ToTypeName(var->Type()), var_name);
|
||||
}
|
||||
}
|
||||
|
||||
if (!garbages.empty()) {
|
||||
gc->Add(std::move(garbages));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,40 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/garbage_collector.h"
|
||||
#include "paddle/fluid/framework/operator.h"
|
||||
#include "paddle/fluid/framework/scope.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
// Returns a map from each operator in `ops` to the names of the variables
// whose last use is that operator. Only non-persistable LoDTensor,
// SelectedRows and LoDTensorArray variables described in `block` are
// considered; names listed in `skip_vars` are never returned.
std::unordered_map<OperatorBase *, std::vector<std::string>> GetUnusedVars(
|
||||
const BlockDesc &block,
|
||||
const std::vector<std::unique_ptr<OperatorBase>> &ops,
|
||||
const std::vector<std::string> &skip_vars);
|
||||
|
||||
// Moves the memory holders out of the variables that `delete_vars_map`
// lists for `op` and passes them to `gc`. Does nothing when `op` has no
// entry in the map; variables missing from `scope` are skipped.
void DeleteUnusedTensors(
|
||||
const Scope &scope, OperatorBase *op,
|
||||
const std::unordered_map<OperatorBase *, std::vector<std::string>>
|
||||
&delete_vars_map,
|
||||
GarbageCollector *gc);
|
||||
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,23 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
// Reserve empty source file
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,60 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/op_desc.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
class NoNeedBufferVarsInference {
|
||||
public:
|
||||
NoNeedBufferVarsInference(const VariableNameMap &inputs,
|
||||
const VariableNameMap &outputs,
|
||||
const AttributeMap &attrs)
|
||||
: inputs_(inputs), outputs_(outputs), attrs_(attrs) {}
|
||||
|
||||
virtual ~NoNeedBufferVarsInference() = default;
|
||||
|
||||
const VariableNameMap &Inputs() const { return inputs_; }
|
||||
|
||||
const VariableNameMap &Outputs() const { return outputs_; }
|
||||
|
||||
const AttributeMap &Attrs() const { return attrs_; }
|
||||
|
||||
virtual std::unordered_set<std::string> operator()() const = 0;
|
||||
|
||||
private:
|
||||
const VariableNameMap &inputs_;
|
||||
const VariableNameMap &outputs_;
|
||||
const AttributeMap &attrs_;
|
||||
};
|
||||
|
||||
// Declares a class named `class_type` that derives from
// NoNeedBufferVarsInference, inherits its constructor, and whose operator()
// returns an unordered_set built from the remaining macro arguments, e.g.
//   DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(MyInference, "X", "Y");
// The expansion ends without a semicolon, so the caller supplies it.
#define DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(class_type, ...) \
|
||||
class class_type : public ::paddle::framework::NoNeedBufferVarsInference { \
|
||||
public: \
|
||||
using ::paddle::framework::NoNeedBufferVarsInference:: \
|
||||
NoNeedBufferVarsInference; \
|
||||
\
|
||||
std::unordered_set<std::string> operator()() const override { \
|
||||
return {__VA_ARGS__}; \
|
||||
} \
|
||||
}
|
||||
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,165 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
os.environ['FLAGS_eager_delete_tensor_gb'] = '0.0'
|
||||
os.environ['CPU_NUM'] = '4'
|
||||
|
||||
import paddle.fluid as fluid
|
||||
import six
|
||||
import unittest
|
||||
import multiprocessing
|
||||
|
||||
|
||||
def simple_fc_net():
    """Build a small fully-connected classifier in the default program.

    The network is four tanh FC layers of width 200 followed by a 10-way
    softmax layer, trained with cross-entropy loss and an Adam optimizer.

    Returns:
        The (image, label, loss) variables.
    """
    image = fluid.layers.data(name='image', shape=[784], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')

    features = image
    num_hidden_layers = 4
    for _layer in range(num_hidden_layers):
        # A fresh ParamAttr is created for every layer, as before.
        bias_attr = fluid.ParamAttr(
            initializer=fluid.initializer.Constant(value=1.0))
        features = fluid.layers.fc(
            features, size=200, act='tanh', bias_attr=bias_attr)

    prediction = fluid.layers.fc(features, size=10, act='softmax')
    per_sample_loss = fluid.layers.cross_entropy(
        input=prediction, label=label)
    loss = fluid.layers.mean(per_sample_loss)

    adam = fluid.optimizer.Adam(learning_rate=1e-3)
    adam.minimize(loss)
    return image, label, loss
|
||||
|
||||
|
||||
def get_persistables_and_non_persistables(prog, fetch_list):
    """Split the variable names of every block in `prog` into two sets.

    A variable goes into the first set when it is persistable or its name
    appears in `fetch_list`; every other variable goes into the second set.

    Returns:
        A tuple (kept_names, temporary_names) of sets of variable names.
    """
    kept_names = set()
    temporary_names = set()
    for block_id in six.moves.range(prog.num_blocks):
        for _key, var in prog.block(block_id).vars.items():
            must_keep = var.persistable or var.name in fetch_list
            target = kept_names if must_keep else temporary_names
            target.add(var.name)

    return kept_names, temporary_names
|
||||
|
||||
|
||||
class TestExecutor(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.place = fluid.CPUPlace()
|
||||
|
||||
def test_executor_main(self):
|
||||
with fluid.program_guard(fluid.Program(), fluid.Program()):
|
||||
with fluid.scope_guard(fluid.Scope()):
|
||||
self.executor_main()
|
||||
|
||||
def test_parallel_executor_main(self):
|
||||
with fluid.program_guard(fluid.Program(), fluid.Program()):
|
||||
with fluid.scope_guard(fluid.Scope()):
|
||||
self.pe_main()
|
||||
|
||||
def prepare_feed(self, image, label, dev_cnt=1):
|
||||
batch_size = 32 * dev_cnt
|
||||
image_shape = (batch_size, ) + tuple(image.shape[1:])
|
||||
label_shape = (batch_size, ) + tuple(label.shape[1:])
|
||||
|
||||
image_np = np.random.random(size=image_shape).astype('float32')
|
||||
label_np = np.random.random_integers(
|
||||
low=0, high=9, size=label_shape).astype('int64')
|
||||
|
||||
return image_np, label_np
|
||||
|
||||
def assertScopeVar(self, scope, persitables, non_persistables):
|
||||
for name in persitables:
|
||||
var = scope.find_var(name)
|
||||
self.assertTrue(var is not None)
|
||||
t = var.get_tensor()
|
||||
self.assertTrue(t._is_initialized())
|
||||
|
||||
for name in non_persistables:
|
||||
var = scope.find_var(name)
|
||||
self.assertTrue(var is not None)
|
||||
t = var.get_tensor()
|
||||
if t._is_initialized():
|
||||
print('WARNING: Variable {} is alive'.format(name))
|
||||
self.assertTrue(not t._is_initialized())
|
||||
|
||||
def executor_main(self):
|
||||
image, label, loss = simple_fc_net()
|
||||
loss.persistable = False
|
||||
persistables, non_persistables = get_persistables_and_non_persistables(
|
||||
fluid.default_main_program(), [loss.name])
|
||||
|
||||
exe = fluid.Executor(self.place)
|
||||
exe.run(fluid.default_startup_program())
|
||||
|
||||
p = fluid.core.Place()
|
||||
p.set_place(self.place)
|
||||
exe = fluid.core.Executor(p)
|
||||
|
||||
for _ in six.moves.range(10):
|
||||
image_np, label_np = self.prepare_feed(image, label)
|
||||
fluid.global_scope().var(image.name).get_tensor().set(image_np,
|
||||
self.place)
|
||||
fluid.global_scope().var(label.name).get_tensor().set(label_np,
|
||||
self.place)
|
||||
# exe.run would not create local scope
|
||||
# so that we can detect whether gc clears temporary variables
|
||||
exe.run(fluid.default_main_program().desc,
|
||||
fluid.global_scope(), 0, False, True, [loss.name])
|
||||
self.assertScopeVar(fluid.global_scope(), persistables,
|
||||
non_persistables)
|
||||
|
||||
def pe_main(self):
|
||||
image, label, loss = simple_fc_net()
|
||||
loss.persistable = False
|
||||
persitables, non_persistables = get_persistables_and_non_persistables(
|
||||
fluid.default_main_program(), [loss.name])
|
||||
|
||||
exe = fluid.Executor(self.place)
|
||||
exe.run(fluid.default_startup_program())
|
||||
|
||||
exec_strategy = fluid.ExecutionStrategy()
|
||||
exec_strategy.num_iteration_per_drop_scope = 100
|
||||
|
||||
prog = fluid.CompiledProgram(fluid.default_main_program(
|
||||
)).with_data_parallel(
|
||||
loss_name=loss.name, exec_strategy=exec_strategy)
|
||||
|
||||
dev_cnt = fluid.core.get_cuda_device_count() if isinstance(self.place, fluid.CUDAPlace) \
|
||||
else int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
|
||||
|
||||
for idx in six.moves.range(10):
|
||||
image_np, label_np = self.prepare_feed(image, label, dev_cnt)
|
||||
feed = {image.name: image_np, label.name: label_np}
|
||||
|
||||
exe.run(program=prog, feed=feed, fetch_list=[loss])
|
||||
|
||||
local_scopes = prog._local_scopes
|
||||
for scope in local_scopes:
|
||||
kids = scope._kids()
|
||||
self.assertTrue(len(kids) == 1)
|
||||
self.assertScopeVar(kids[0], persistables, non_persistables)
|
||||
|
||||
|
||||
class TestExecutor2(TestExecutor):
    """Repeats the TestExecutor checks on CUDAPlace(0) when CUDA is available."""

    def setUp(self):
        if fluid.core.is_compiled_with_cuda():
            self.place = fluid.CUDAPlace(0)
        else:
            self.place = fluid.CPUPlace()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue