commit
e8e8ad0491
@ -0,0 +1,76 @@
|
|||||||
|
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include "paddle/fluid/framework/executor.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
namespace details {
|
||||||
|
ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
|
||||||
|
ExecutionStrategy strategy, std::vector<Scope *> local_scopes,
|
||||||
|
std::vector<VariableInfo> var_infos, std::vector<platform::Place> places,
|
||||||
|
std::unique_ptr<SSAGraphExecutor> &&underlying_executor)
|
||||||
|
: strategy_(std::move(strategy)),
|
||||||
|
underlying_executor_(std::move(underlying_executor)),
|
||||||
|
local_scopes_(std::move(local_scopes)),
|
||||||
|
var_infos_(std::move(var_infos)),
|
||||||
|
places_(std::move(places)) {}
|
||||||
|
|
||||||
|
FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
|
||||||
|
const std::vector<std::string> &fetch_tensors) {
|
||||||
|
if (drop_scope_counter_ == 0) {
|
||||||
|
// Create local scopes.
|
||||||
|
for (auto it = local_scopes_.rbegin(); it != local_scopes_.rend(); ++it) {
|
||||||
|
auto &scope = *it;
|
||||||
|
Scope &local_scope = scope->NewScope();
|
||||||
|
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
|
||||||
|
&local_scope;
|
||||||
|
|
||||||
|
for (auto &info : var_infos_) {
|
||||||
|
if (scope->FindVar(info.name_) != nullptr) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (info.persistable_) { // Persistable
|
||||||
|
InitializeVariable(scope->Var(info.name_), info.type_);
|
||||||
|
} else {
|
||||||
|
InitializeVariable(local_scope.Var(info.name_), info.type_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto fetch_data = underlying_executor_->Run(fetch_tensors);
|
||||||
|
drop_scope_counter_ += 1;
|
||||||
|
if (!fetch_tensors.empty() ||
|
||||||
|
drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
|
||||||
|
drop_scope_counter_ = 0;
|
||||||
|
// Wait All computational streams
|
||||||
|
for (auto p : places_) {
|
||||||
|
platform::DeviceContextPool::Instance().Get(p)->Wait();
|
||||||
|
}
|
||||||
|
for (auto &scope : local_scopes_) {
|
||||||
|
auto &local_scope =
|
||||||
|
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>();
|
||||||
|
scope->DeleteScope(local_scope);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fetch_data;
|
||||||
|
}
|
||||||
|
} // namespace details
|
||||||
|
} // namespace framework
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,53 @@
|
|||||||
|
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include "paddle/fluid/framework/details/execution_strategy.h"
|
||||||
|
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
|
||||||
|
#include "paddle/fluid/framework/scope.h"
|
||||||
|
#include "paddle/fluid/platform/place.h"
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
namespace details {
|
||||||
|
|
||||||
|
struct VariableInfo {
|
||||||
|
std::string name_;
|
||||||
|
proto::VarType::Type type_;
|
||||||
|
bool persistable_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
|
||||||
|
public:
|
||||||
|
ScopeBufferedSSAGraphExecutor(
|
||||||
|
ExecutionStrategy strategy, std::vector<Scope*> local_scopes,
|
||||||
|
std::vector<VariableInfo> var_infos, std::vector<platform::Place> places,
|
||||||
|
std::unique_ptr<SSAGraphExecutor>&& underlying_executor);
|
||||||
|
FeedFetchList Run(const std::vector<std::string>& fetch_tensors) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
size_t drop_scope_counter_{0};
|
||||||
|
|
||||||
|
ExecutionStrategy strategy_;
|
||||||
|
std::unique_ptr<SSAGraphExecutor> underlying_executor_;
|
||||||
|
std::vector<Scope*> local_scopes_;
|
||||||
|
std::vector<VariableInfo> var_infos_;
|
||||||
|
std::vector<platform::Place> places_;
|
||||||
|
};
|
||||||
|
} // namespace details
|
||||||
|
} // namespace framework
|
||||||
|
} // namespace paddle
|
Loading…
Reference in new issue