Use heap variables

helinwang-patch-1
Yu Yang 7 years ago
parent 222763296f
commit 9af870854e

@ -16,11 +16,17 @@
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace framework {
namespace details {
struct OpHandleBase {
class OpHandleBase {
private:
DISABLE_COPY_AND_ASSIGN(OpHandleBase);
public:
std::vector<VarHandleBase *> inputs_;
std::vector<VarHandleBase *> outputs_;
std::unordered_map<platform::Place, platform::DeviceContext *,
@ -31,6 +37,8 @@ struct OpHandleBase {
std::unordered_map<int, cudaEvent_t> events_;
#endif
OpHandleBase() {}
std::string DebugString() const;
virtual std::string Name() const = 0;

@ -67,7 +67,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
}
// Step 2. Insert FetchOps
std::vector<FetchOpHandle> fetch_ops;
std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
std::vector<DummyVarHandle> dummy_vars;
FeedFetchList fetch_data(fetch_tensors.size());
@ -84,9 +84,9 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
for (size_t i = 0; i < fetch_tensors.size(); ++i) {
auto &var_name = fetch_tensors[i];
auto &vars = fetched_vars[var_name];
fetch_ops.emplace_back(&fetch_data, i, &local_scopes_);
details::FetchOpHandle *op = &fetch_ops.back();
auto &vars = fetched_vars.at(var_name);
auto *op = new FetchOpHandle(&fetch_data, i, &local_scopes_);
fetch_ops.emplace_back(op);
// FIXME: Use new device context
for (auto &p : places_) {
@ -138,7 +138,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
for (auto &op : pending_ops) {
VLOG(10) << op.first->DebugString();
}
// keep waiting the ready variables
continue;
}

@ -231,6 +231,9 @@ class TestMNIST(TestParallelExecutorBase):
class TestResnet(TestParallelExecutorBase):
@classmethod
def setUpClass(cls):
import os
if os.path.exists('./flowers.recordio'):
return
with fluid.program_guard(fluid.Program(), fluid.Program()):
reader = paddle.batch(flowers.train(), batch_size=4)
feeder = fluid.DataFeeder(

Loading…
Cancel
Save