|
|
|
@ -14,6 +14,7 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/framework/executor.h"
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include "paddle/framework/scope.h"
|
|
|
|
|
#include "paddle/platform/device_context.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -58,9 +59,10 @@ void GraphView::Initialize(const ProgramDesc*) {
|
|
|
|
|
|
|
|
|
|
class ExecutorImpl : public Executor {
|
|
|
|
|
public:
|
|
|
|
|
ExecutorImpl(const platform::DeviceContext* ctx, const ProgramDesc* pdesc,
|
|
|
|
|
bool is_linear)
|
|
|
|
|
: device_context_(ctx),
|
|
|
|
|
ExecutorImpl(Scope* scope, const platform::DeviceContext* ctx,
|
|
|
|
|
const ProgramDesc* pdesc, bool is_linear)
|
|
|
|
|
: scope_(scope),
|
|
|
|
|
device_context_(ctx),
|
|
|
|
|
program_desc_(pdesc),
|
|
|
|
|
view_(ProgramDescView::Create(is_linear)) {}
|
|
|
|
|
|
|
|
|
@ -73,6 +75,7 @@ class ExecutorImpl : public Executor {
|
|
|
|
|
void Initialize();
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
Scope* scope_;
|
|
|
|
|
const platform::DeviceContext* device_context_;
|
|
|
|
|
const ProgramDesc* program_desc_;
|
|
|
|
|
ProgramDescView* view_;
|
|
|
|
@ -97,6 +100,12 @@ platform::CUDADeviceContext* GetCUDADeviceContext() {
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
framework::Scope* GetScope() {
|
|
|
|
|
static std::unique_ptr<framework::Scope> g_scope =
|
|
|
|
|
make_unique<framework::Scope>();
|
|
|
|
|
return g_scope.get();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Executor* NewLocalExecutor(const platform::Place& place,
|
|
|
|
|
const ProgramDesc& pdesc, bool is_linear) {
|
|
|
|
|
platform::DeviceContext* device_context = nullptr;
|
|
|
|
@ -110,11 +119,12 @@ Executor* NewLocalExecutor(const platform::Place& place,
|
|
|
|
|
PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
return new ExecutorImpl(device_context, &pdesc, is_linear);
|
|
|
|
|
return new ExecutorImpl(GetScope(), device_context, &pdesc, is_linear);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ExecutorImpl::Run() {
|
|
|
|
|
// operators running
|
|
|
|
|
scope_->NewVar();
|
|
|
|
|
device_context_->Wait();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|