Dynamic Graph first prototype (#11415)
parent a77dfeee56
commit d827c6e87a
@@ -0,0 +1,25 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

if(APPLE)
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=pessimizing-move")
endif(APPLE)

cc_library(tape_variable SRCS variable.cc DEPS ${FLUID_CORE_MODULES})
cc_library(tape SRCS tape.cc DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB} tape_variable)

cc_test(test_tape
  SRCS test_tape.cc
  DEPS tape tape_variable)
@@ -0,0 +1,246 @@
# Dynamic Graph on Fluid

PaddlePaddle Fluid targets autodiff without a tape, which, however, is very challenging, and we are still a long way from there. DyNet and PyTorch provide a good design idea, the *tape*, that significantly eases the challenge. DyNet also provides a C++ API that is as convenient as Python but more efficient, and that integrates easily with industrial/production systems. This package, `tape`, combines the best of

1. the tape from PyTorch and DyNet,
2. the C++ API and core from DyNet, and
3. the rich set of operators from PaddlePaddle.

## Overview

We can implement a DyNet-like tape (see this survey) by wrapping Paddle Fluid's `Operator`
and `Variable`.

The user API is straightforward since

1. it is imperative and uses the host language's control-flow logic, and
2. it avoids extra concepts such as `Scope` and `Executor`.

All of these benefits come at the cost of adding just one line, `reset_global_tape`,
at every iteration.

## Code Structure

In short, the `Tape` contains a vector of `OpHandle`s, and an `OpHandle` contains its
`type`, the pointers to its `Variable`s, and the necessary attributes.

```c++
class Variable {
 public:
  VariableHandle Grad();  // returns its gradient variable

 private:
  framework::VarDesc desc_;  // compile time infershape, necessary for lazy execution
  framework::Variable var_;  // run time variable, holds data memory
};

using VariableHandle = shared_ptr<Variable>;

struct OpHandle {
  string type_;
  map<string, vector<VariableHandle>> inputs_;
  map<string, vector<VariableHandle>> outputs_;
  AttributeMap attrs_;
};

class Tape {
 public:
  void AddOp(OpHandle);  // add op
  void Forward();        // execute the tape_
  void Backward();       // execute the backward of the tape_

 private:
  vector<OpHandle> tape_;
};
```

We use `Function`s to represent layers. A `Function` takes care of parameter
initialization and calls `AddOp` on the tape when it is invoked.

```c++
class Linear {
 public:
  Linear(int in_dim, int out_dim, const std::string &act)
      : w_(new Variable("LinearWeight")),
        b_(new Variable("LinearBias")),
        act_(act) {
    Tape init_tape;

    std::string initializer = "fill_constant";
    framework::AttributeMap attrs;
    attrs["dtype"] = paddle::framework::proto::VarType::Type::VarType_Type_FP32;
    attrs["shape"] = std::vector<int>{in_dim, out_dim};
    attrs["value"] = 1.0f;
    init_tape.AddOp(initializer, {}, {{"Out", {w_}}}, attrs);

    attrs["dtype"] = paddle::framework::proto::VarType::Type::VarType_Type_FP32;
    attrs["shape"] = std::vector<int>{out_dim};
    attrs["value"] = 1.0f;
    init_tape.AddOp(initializer, {}, {{"Out", {b_}}}, attrs);

    init_tape.Forward();
  }

  VariableHandle operator()(VariableHandle input) {
    VariableHandle pre_bias(new Variable("linear"));
    get_global_tape().AddOp("mul",
                            {{"X", {input}}, {"Y", {w_}}},
                            {{"Out", {pre_bias}}},
                            {{"x_num_col_dims", 1}, {"y_num_col_dims", 1}});
    VariableHandle pre_act(new Variable("linear"));
    get_global_tape().AddOp("elementwise_add",
                            {{"X", {pre_bias}}, {"Y", {b_}}},
                            {{"Out", {pre_act}}},
                            {{"axis", 1}});
    VariableHandle post_act(new Variable("linear"));
    get_global_tape().AddOp(act_,
                            {{"X", {pre_act}}},
                            {{"Out", {post_act}}},
                            {});
    return post_act;
  }

  std::vector<VariableHandle> Params() { return {w_, b_}; }

 private:
  VariableHandle w_;
  VariableHandle b_;
  std::string act_;
};
```

## User API

```c++
// Model function
paddle::tape::Linear linear1(3, 3, "relu");  // init weight and bias
paddle::tape::Linear linear2(3, 3, "relu");  // init weight and bias
paddle::tape::Mean mean;

// Optimizer
paddle::tape::SGD sgd(0.001);

// Data Feeder
paddle::tape::Fill data_feeder(...);
VariableHandle input(new paddle::tape::Variable("input"));

for (int i = 0; i < 2; ++i) {
  reset_global_tape();

  data_feeder(input);

  auto loss = mean(linear2(linear1(input)));  // compile time InferShape & InferVarType
  LOG(INFO) << loss.value();  // Run forward up to loss

  // Run backward, store gradient of w at w->Grad()
  get_global_tape().Backward(loss);

  // Update w
  sgd(linear1.Params());
  sgd(linear2.Params());
}
```

<details>
<summary></summary>

digraph G {

  subgraph cluster_0 {
    node [shape=record,style=filled];
    style=filled;
    color=lightgrey;
    linear1 [label="{type: mul | {input | {<before_mul1>X: before_mul1 |<weight1> Y: weight1}} | {output |<before_bias1> Out: before_bias1}}"];
    elementwise_add1 [label="{type: elementwise_add | {input | {<before_bias1>X: before_bias1 |<bias1> Y: bias1}} | {output |<before_act1> Out: before_act1}}"];
    relu1 [label="{type: relu | {input | {<before_act1>X: before_act1 }} | {output |<after_act1> Out: after_act1}}"];

    linear1 -> elementwise_add1 -> relu1;
    label = "forward tape";
  }

  linear1:before_mul1->before_mul1
  linear1:weight1->weight1
  linear1:before_bias1->before_bias1

  elementwise_add1:bias1->bias1
  elementwise_add1:before_bias1->before_bias1
  elementwise_add1:before_act1->before_act1

  relu1:before_act1->before_act1
  relu1:after_act1->after_act1

  subgraph cluster_1 {
    node [shape=record,style=filled];
    style=filled;
    color=lightgrey;
    linear1_grad [label="{type: mul_grad | {input | {<before_mul1>X: before_mul1 |<weight1> Y: weight1|<before_bias1_grad> Out_grad: before_bias1_grad}} | {output |{<before_mul1_grad>X_grad: before_mul1_grad |<weight1_grad> Y_grad: weight1_grad}}}"];

    elementwise_add1_grad [label="{type: elementwise_add_grad | {input | <before_act1_grad> Out_grad: before_act1_grad} | {output |{<before_bias1_grad>X_grad: before_bias1_grad |<bias1_grad> Y_grad: bias1_grad}}}"];

    relu1_grad [label="{type: relu_grad | {input |<after_act1_grad> Out_grad: after_act1_grad} | {output | {<before_act1_grad>X_grad: before_act1_grad }}}"];

    linear1_grad -> elementwise_add1_grad -> relu1_grad [dir=back];
    label = "backward tape";
  }

  relu1_grad:after_act1_grad->after_act1_grad
  relu1_grad:before_act1_grad->before_act1_grad

  elementwise_add1_grad:before_act1_grad->before_act1_grad
  elementwise_add1_grad:before_bias1_grad->before_bias1_grad
  elementwise_add1_grad:bias1_grad->bias1_grad

  linear1_grad:before_mul1->before_mul1
  linear1_grad:weight1->weight1
  linear1_grad:before_bias1_grad->before_bias1_grad
  linear1_grad:before_mul1_grad->before_mul1_grad
  linear1_grad:weight1_grad->weight1_grad

  subgraph cluster_2 {
    node [shape=record];
    label = "Linear1";
    weight1
    bias1
  }

  weight1 -> weight1_grad [ label="Grad()", style="dashed" ];
  bias1 -> bias1_grad [ label="Grad()", style="dashed" ];

}
</details>



## Code Reuse

We want to stay as close to Paddle Fluid as possible.

### Reuse All Operators

As all Ops are registered in `OpInfoMap`, the effort of adding a new `Function`
is about 10 lines of code, similar to exposing an operator to Python.
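
For illustration, a hypothetical `Relu` wrapper could mirror the existing `Mean` class in `function.h`; only the op name `"relu"` comes from Paddle's operator registry, everything else follows the pattern already used in this package.

```c++
// Hypothetical Function wrapper for the registered "relu" operator,
// modeled directly on the Mean class in function.h.
class Relu {
 public:
  VariableHandle operator()(VariableHandle var) {
    VariableHandle out(new Variable("relu"));
    get_global_tape().AddOp("relu", {{"X", {var}}}, {{"Out", {out}}}, {});
    return out;
  }
};
```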

### Reuse Compile Time InferShape and InferVarType

Since all the symbolic information is stored in `tape::Variable::desc_` instead
of `ProgramDesc.block.vars`, we create a temporary `BlockDesc` to run `InferShape` and
`InferVarType` every time we `AddOp` to the tape.
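
The diff of `tape.cc` is suppressed in this view, so the following is only a rough sketch of how `AddOp` could drive compile-time inference; it assumes the framework's `ProgramDesc`/`BlockDesc`/`OpDesc` APIs are used roughly as shown and may differ from the actual implementation.

```c++
// Rough sketch only, not the actual tape.cc: copy each Variable's VarDesc into
// a throwaway BlockDesc, run InferVarType/InferShape there, then record the op.
void Tape::AddOp(const std::string &type,
                 const VariableHandleMap &in_vars,
                 VariableHandleMap out_vars,
                 const framework::AttributeMap &attrs) {
  framework::ProgramDesc program;
  framework::BlockDesc *block = program.MutableBlock(0);

  framework::OpDesc *op = block->AppendOp();
  op->SetType(type);
  for (auto &kv : in_vars) {
    std::vector<std::string> names;
    for (auto &v : kv.second) {
      *block->Var(v->Name()) = *v->MutableDesc();  // expose the symbolic desc
      names.push_back(v->Name());
    }
    op->SetInput(kv.first, names);
  }
  for (auto &kv : out_vars) {
    std::vector<std::string> names;
    for (auto &v : kv.second) {
      block->Var(v->Name());  // declare the output in the temporary block
      names.push_back(v->Name());
    }
    op->SetOutput(kv.first, names);
  }
  op->SetAttrMap(attrs);

  op->InferVarType(block);  // fills in output var types
  op->InferShape(*block);   // fills in output shapes

  // Copy the inferred descs back into the tape::Variables.
  for (auto &kv : out_vars)
    for (auto &v : kv.second) *v->MutableDesc() = *block->Var(v->Name());

  tape_.emplace_back(type, in_vars, out_vars, attrs);
}
```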

### Reuse Operator::Run

We use smart pointers, instead of `Scope`, to manage memory, so we create a temporary
`Scope` for every `Operator::Run()`.
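
Again as a sketch only (the real `tape.cc` is not shown here): executing one recorded `OpHandle` might look like the following, where wiring the temporary scope's variables to the tape's `framework::Variable`s is deliberately left out.

```c++
// Sketch of running one recorded op inside a short-lived Scope; illustrative
// only, since the actual tape.cc diff is suppressed in this view.
void RunOp(const OpHandle &op, const platform::Place &place) {
  framework::Scope scope;  // temporary scope, destroyed right after Run()

  framework::VariableNameMap ins, outs;
  for (auto &kv : op.inputs_)
    for (auto &v : kv.second) ins[kv.first].push_back(v->Name());
  for (auto &kv : op.outputs_)
    for (auto &v : kv.second) outs[kv.first].push_back(v->Name());

  // NOTE: a real implementation must make the scope's variables alias the
  // corresponding tape::Variable::Var() so results outlive the scope.

  auto fluid_op =
      framework::OpRegistry::CreateOp(op.type_, ins, outs, op.attrs_);
  fluid_op->Run(scope, place);
}
```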

## Possible Features

### Release Memory on Backward

We can release memory aggressively: during backward, we can delete an `OpHandle` once
its backward op has finished. Since all variables are managed by smart pointers, the
memory is automatically released when the reference count drops to 0.
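
A minimal sketch of the idea, using a hypothetical helper rather than anything in this PR:

```c++
// Hypothetical helper (not part of this PR): pop forward records as soon as
// their gradients have been computed. Dropping an OpHandle releases its
// VariableHandles; any Variable they solely owned is freed immediately.
void ReleaseConsumedOps(std::vector<OpHandle> *forward_tape, size_t finished) {
  // The last `finished` ops have already been differentiated.
  for (size_t i = 0; i < finished && !forward_tape->empty(); ++i) {
    forward_tape->pop_back();  // shared_ptr ref_counts drop here
  }
}
```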

### Kernel Fusion

Because a symbolic representation of the tape is constructed before the actual
execution, it is possible to perform graph optimizations on it. One use case is kernel
fusion.
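
For instance (a sketch only, no fusion pass exists in this PR), a pass over the recorded `OpHandle`s could look for the `mul`, `elementwise_add`, activation pattern emitted by `Linear` and replace those three ops with a single fused op:

```c++
// Sketch of a pattern matcher a future fusion pass might use; the fused
// "fc"-style op it would emit is hypothetical.
bool MatchesLinearPattern(const std::vector<OpHandle> &tape, size_t i) {
  if (i + 2 >= tape.size()) return false;
  return tape[i].type_ == "mul" &&
         tape[i + 1].type_ == "elementwise_add" &&
         (tape[i + 2].type_ == "relu" || tape[i + 2].type_ == "sigmoid");
}
```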
@@ -0,0 +1,130 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>

#include "paddle/contrib/dynamic/tape.h"
#include "paddle/contrib/dynamic/variable.h"
#include "paddle/fluid/framework/type_defs.h"

namespace paddle {
namespace dynamic {

class Function {};

class Fill {
 public:
  Fill(const std::string &initializer, const framework::AttributeMap &attrs)
      : initializer_(initializer), attrs_(attrs) {}

  void operator()(VariableHandle var) {
    get_global_tape().AddOp(initializer_, {}, {{"Out", {var}}}, attrs_);
  }

 private:
  const std::string initializer_;
  const framework::AttributeMap attrs_;
};

class Mean {
 public:
  VariableHandle operator()(VariableHandle var) {
    VariableHandle out(new Variable("mean"));
    get_global_tape().AddOp("mean", {{"X", {var}}}, {{"Out", {out}}}, {});
    return out;
  }
};

class Linear {
 public:
  Linear(int in_dim, int out_dim, const std::string &act)
      : w_(new Variable("LinearWeight")),
        b_(new Variable("LinearBias")),
        act_(act) {
    Tape init_tape;

    std::string initializer = "fill_constant";
    framework::AttributeMap attrs;
    attrs["dtype"] = paddle::framework::proto::VarType::Type::VarType_Type_FP32;
    attrs["shape"] = std::vector<int>{in_dim, out_dim};
    attrs["value"] = 1.0f;
    init_tape.AddOp(initializer, {}, {{"Out", {w_}}}, attrs);

    attrs["dtype"] = paddle::framework::proto::VarType::Type::VarType_Type_FP32;
    attrs["shape"] = std::vector<int>{out_dim};
    attrs["value"] = 1.0f;
    init_tape.AddOp(initializer, {}, {{"Out", {b_}}}, attrs);

    init_tape.Forward();
  }

  VariableHandle operator()(VariableHandle input) {
    VariableHandle pre_bias(new Variable("linear"));
    get_global_tape().AddOp("mul",
                            {{"X", {input}}, {"Y", {w_}}},
                            {{"Out", {pre_bias}}},
                            {{"x_num_col_dims", 1}, {"y_num_col_dims", 1}});
    VariableHandle pre_act(new Variable("linear"));
    get_global_tape().AddOp("elementwise_add",
                            {{"X", {pre_bias}}, {"Y", {b_}}},
                            {{"Out", {pre_act}}},
                            {{"axis", 1}});
    VariableHandle post_act(new Variable("linear"));
    get_global_tape().AddOp(
        act_, {{"X", {pre_act}}}, {{"Out", {post_act}}}, {});
    return post_act;
  }

  std::vector<VariableHandle> Params() { return {w_, b_}; }

 private:
  VariableHandle w_;
  VariableHandle b_;
  std::string act_;
};

class SGD {
 public:
  SGD(float learning_rate) : learning_rate_(new Variable("sgd")) {
    Tape init_tape;

    std::string initializer = "fill_constant";
    framework::AttributeMap attrs;
    attrs["dtype"] = paddle::framework::proto::VarType::Type::VarType_Type_FP32;
    attrs["shape"] = std::vector<int>{1};
    attrs["value"] = learning_rate;
    init_tape.AddOp(initializer, {}, {{"Out", {learning_rate_}}}, attrs);

    init_tape.Forward();
  }

  void operator()(VariableHandle input) {
    Tape temp_tape;
    temp_tape.AddOp("sgd",
                    {{"Param", {input}},
                     {"LearningRate", {learning_rate_}},
                     {"Grad", {input->Grad()}}},
                    {{"ParamOut", {input}}},
                    {});
    temp_tape.Forward();
    input->ResetGrad();
  }

 private:
  VariableHandle learning_rate_;
};
}
}
File diff suppressed because it is too large
@@ -0,0 +1,62 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <map>
#include <memory>
#include <string>
#include <vector>

#include "paddle/contrib/dynamic/variable.h"

namespace paddle {
namespace dynamic {

using VariableHandleMap = std::map<std::string, std::vector<VariableHandle>>;

struct OpHandle {
  OpHandle(const std::string &type,
           const VariableHandleMap &in_vars,
           const VariableHandleMap &out_vars,
           const framework::AttributeMap &attrs)
      : type_(type), inputs_(in_vars), outputs_(out_vars), attrs_(attrs) {}

  std::string type_;
  VariableHandleMap inputs_;
  VariableHandleMap outputs_;
  framework::AttributeMap attrs_;
};

class Tape {
 public:
  void AddOp(const std::string &type,
             const VariableHandleMap &in_vars,
             VariableHandleMap out_vars,
             const framework::AttributeMap &attrs);
  void Forward();
  void Backward(VariableHandle target);

 private:
  bool has_been_backwarded_ = false;
  size_t current_position_ = 0;

  std::vector<OpHandle> tape_;
  std::shared_ptr<Tape> backward_tape_;
};

Tape &get_global_tape();

void reset_global_tape();
}
}
@@ -0,0 +1,61 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "gtest/gtest.h"
#include "paddle/contrib/dynamic/function.h"

using namespace paddle::dynamic;

TEST(Tape, TestMLP) {
  LOG(INFO) << "TestMLP";
  Linear linear1(3, 3, "relu");
  Linear linear2(3, 3, "relu");
  Mean mean;

  SGD sgd(0.001);

  std::string initializer = "fill_constant";
  paddle::framework::AttributeMap attrs;
  attrs["dtype"] = paddle::framework::proto::VarType::Type::VarType_Type_FP32;
  attrs["shape"] = std::vector<int>{3, 3};
  attrs["value"] = 1.0f;
  Fill filler(initializer, attrs);

  for (int i = 0; i < 2; ++i) {
    reset_global_tape();

    VariableHandle input(new Variable("input"));
    filler(input);

    auto loss = mean(linear2(linear1(input)));

    get_global_tape().Backward(loss);

    for (auto w : linear1.Params()) {
      sgd(w);
    }
    for (auto w : linear2.Params()) {
      sgd(w);
    }
  }
}

int main(int argc, char** argv) {
  std::vector<paddle::platform::Place> places;
  places.emplace_back(paddle::platform::CPUPlace());
  paddle::platform::DeviceContextPool::Init(places);

  testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}
@@ -0,0 +1,33 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/contrib/dynamic/variable.h"

namespace paddle {
namespace dynamic {

void Variable::InitializeVariable() {
  LOG(INFO) << "Initializing " << desc_.Name() << " as " << desc_.GetType();
  framework::proto::VarType::Type var_type = desc_.GetType();
  if (var_type == framework::proto::VarType::LOD_TENSOR) {
    var_.GetMutable<framework::LoDTensor>();
  } else if (var_type == framework::proto::VarType::SELECTED_ROWS) {
    var_.GetMutable<framework::SelectedRows>();
  } else {
    PADDLE_THROW("Variable type %d is not in [LOD_TENSOR, SELECTED_ROWS]",
                 var_type);
  }
}
}
}
@@ -0,0 +1,85 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>

#include "paddle/fluid/framework/operator.h"  // framework::kGradVarSuffix
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/variable.h"

namespace paddle {
namespace dynamic {

class Variable;
using VariableHandle = std::shared_ptr<Variable>;

/*
 * Combination of
 *     framework::VarDesc desc_;
 *     framework::Variable var_;
 */
class Variable {
 public:
  Variable(const std::string pre_fix)
      : desc_(pre_fix + std::to_string(count())) {}

  Variable(const std::string pre_fix, bool is_grad)
      : desc_(pre_fix + (is_grad ? framework::kGradVarSuffix
                                 : std::to_string(count()))) {}

  ~Variable() { LOG(INFO) << "Deleting " << Name(); }

  // Instantiate LoDTensor/SelectedRows
  void InitializeVariable();

  VariableHandle Grad() {
    if (grad_ == nullptr) {
      grad_.reset(new Variable(desc_.Name(), true));
    }

    return grad_;
  }

  void ResetGrad() { grad_ = nullptr; }

  // Stochastic Gradient Descent with Momentum
  //   VariableHandle Momentum();

  // void init(const std::string& initializer,
  //           const framework::AttributeMap& attrs);

  // void value() {};

  const framework::VarDesc& Desc() const { return desc_; }
  framework::VarDesc* MutableDesc() { return &desc_; }

  // TODO(tonyyang-svail): No need to expose name
  std::string Name() const { return desc_.Name(); }

  framework::Variable* Var() { return &var_; }

 private:
  int count() {
    static int counter = 0;
    return counter++;
  }

  framework::VarDesc desc_;
  framework::Variable var_;

  VariableHandle grad_;
};
}
}