|
|
@ -23,6 +23,7 @@ limitations under the License. */
|
|
|
|
#include "paddle/framework/lod_tensor_array.h"
|
|
|
|
#include "paddle/framework/lod_tensor_array.h"
|
|
|
|
#include "paddle/framework/op_registry.h"
|
|
|
|
#include "paddle/framework/op_registry.h"
|
|
|
|
#include "paddle/framework/reader.h"
|
|
|
|
#include "paddle/framework/reader.h"
|
|
|
|
|
|
|
|
#include "paddle/operators/nccl/nccl_gpu_common.h" // platform::Communicator
|
|
|
|
#include "paddle/platform/place.h"
|
|
|
|
#include "paddle/platform/place.h"
|
|
|
|
#include "paddle/platform/profiler.h"
|
|
|
|
#include "paddle/platform/profiler.h"
|
|
|
|
|
|
|
|
|
|
|
@ -53,6 +54,8 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
|
|
|
|
var->GetMutable<LoDTensorArray>();
|
|
|
|
var->GetMutable<LoDTensorArray>();
|
|
|
|
} else if (var_type == proto::VarDesc::PLACE_LIST) {
|
|
|
|
} else if (var_type == proto::VarDesc::PLACE_LIST) {
|
|
|
|
var->GetMutable<platform::PlaceList>();
|
|
|
|
var->GetMutable<platform::PlaceList>();
|
|
|
|
|
|
|
|
} else if (var_type == proto::VarDesc::NCCL_COM) {
|
|
|
|
|
|
|
|
var->GetMutable<platform::Communicator>();
|
|
|
|
} else if (var_type == proto::VarDesc::READER) {
|
|
|
|
} else if (var_type == proto::VarDesc::READER) {
|
|
|
|
var->GetMutable<ReaderHolder>();
|
|
|
|
var->GetMutable<ReaderHolder>();
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
@ -118,13 +121,12 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
|
|
|
|
|
|
|
|
|
|
|
|
for (auto& op_desc : block.AllOps()) {
|
|
|
|
for (auto& op_desc : block.AllOps()) {
|
|
|
|
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
|
|
|
|
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
|
|
|
|
VLOG(4) << op->DebugStringEx(local_scope);
|
|
|
|
VLOG(3) << op->DebugStringEx(local_scope);
|
|
|
|
|
|
|
|
|
|
|
|
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
|
|
|
|
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
|
|
|
|
platform::RecordEvent record_event(op->Type(), pool.Get(place_));
|
|
|
|
platform::RecordEvent record_event(op->Type(), pool.Get(place_));
|
|
|
|
|
|
|
|
|
|
|
|
op->Run(*local_scope, place_);
|
|
|
|
op->Run(*local_scope, place_);
|
|
|
|
VLOG(3) << op->DebugStringEx(local_scope);
|
|
|
|
|
|
|
|
if (FLAGS_benchmark) {
|
|
|
|
if (FLAGS_benchmark) {
|
|
|
|
VLOG(2) << "Memory used after operator " + op->Type() + " running: "
|
|
|
|
VLOG(2) << "Memory used after operator " + op->Type() + " running: "
|
|
|
|
<< memory::memory_usage(place_);
|
|
|
|
<< memory::memory_usage(place_);
|
|
|
|