commit df59889984

@@ -0,0 +1,73 @@

# Design for TensorArray
TensorArray is a new concept borrowed from TensorFlow;
it is meant to be used with dynamic iteration primitives such as `while_loop` and `map_fn`.

This concept can be used to support our new design of dynamic operations, and to help refactor some existing layers that handle variable-length sequences,
such as `RecurrentGradientMachine`.

In [our design for dynamic RNN](https://github.com/PaddlePaddle/Paddle/pull/4401),
`TensorArray` is used to segment inputs and store states in all time steps.
By providing some methods similar to a C++ array,
the definition of state-based dynamic models such as RNN becomes more natural and highly flexible.

## Dynamic-Related Methods
Some basic methods should be proposed as follows:

### stack()
Pack the values in a `TensorArray` into a tensor with a rank one higher than each tensor in `values`.
### unstack(axis=0)
Unpack the given dimension of a rank-`R` tensor into rank-`(R-1)` tensors.
### concat()
Return the values in the `TensorArray` as a concatenated tensor.
### write(index, value, data_shared=true)
Write `value` into position `index` of the `TensorArray`.
### read(index)
Read the value at location `index` in the `TensorArray`.
### size()
Return the number of values.
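
To make these methods concrete, here is a minimal usage sketch written against the proposed interface above (illustrative pseudo-C++, not an existing API; `n`, `make_step_output`, and the shapes are hypothetical):

```c++
TensorArray ta;
for (size_t i = 0; i < n; i++) {
  // a hypothetical per-step result of shape [batch, hidden]
  Tensor step_out = make_step_output(i);
  ta.write(i, step_out, true/*data_shared*/);
}
// ta.size() == n; stack() packs the n values into one tensor of shape
// [n, batch, hidden], while concat() would join them along the existing
// first axis instead of adding a new one.
Tensor all_steps = ta.stack();
```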

## LoDTensor-related Supports
The `RecurrentGradientMachine` in Paddle serves as a flexible RNN layer; it takes variable-length sequences as input.
Because each RNN step can only take a tensor-represented batch of data as input,
some preprocessing has to be applied to the inputs, such as sorting the sentences by length in descending order, then cutting the words at each time step and packing them into new batches.

Such cut-like operations can be embedded into `TensorArray` as general methods called `unpack` and `pack`.

With these two methods, an RNN over variable-length sentences can be implemented like

```c++
// input is the variable-length data
LoDTensor sentence_input(xxx);
TensorArray ta;
Tensor indice_map;
Tensor boot_state = xxx; // to initialize rnn's first state
TensorArray::unpack(sentence_input, 1/*level*/, true/*sort_by_length*/, &ta, &indice_map);
TensorArray step_outputs;
TensorArray states;
states.write(0, boot_state, true/*data_shared*/);

for (int step = 0; step < ta.size(); step++) {
  auto state = states.read(step);
  // rnnstep is a function which acts like a step of RNN
  auto step_input = ta.read(step);
  auto step_output = rnnstep(step_input, state);
  // the step output also serves as the next state here
  states.write(step + 1, step_output, true/*data_shared*/);
  step_outputs.write(step, step_output, true/*data_shared*/);
}

// rnn_output is the final output of the rnn
LoDTensor rnn_output = step_outputs.pack(1/*level*/, indice_map);
```

The code above shows that by embedding the LoDTensor-related preprocessing operations into `TensorArray`,
the implementation of an RNN that supports variable-length sentences is far more concise than `RecurrentGradientMachine`, because the latter mixes all the code together, which makes it hard to read and extend.

Some details are as follows.

### unpack(level, sort_by_length)
Split the LoDTensor at some `level` and generate batches; if `sort_by_length` is set, sort the sequences by length first.

Returns:

- a new `TensorArray`, whose values are LoDTensors that represent batches of data.
- an int32 Tensor, which stores the map from the indices of the new batches to the original LoDTensor.

### pack(level, indices_map)
Recover the original LoD-arranged LoDTensor from the values in a `TensorArray`, given `level` and `indices_map`.

@@ -0,0 +1,146 @@

## How to use Eigen in Paddle

Essentially, a neural network is a compute graph. The data needed for the computation is stored in `Tensor`s and its computation procedure is described by `Operator`s. An `Operator` calls the `Compute` interface in its corresponding `OpKernel` and operates on the `Tensor`.

### Eigen Tensor Module

The Eigen Tensor module supports powerful element-wise computation. In addition, a piece of code written using it can be run on both the CPU and the GPU.

Note that Eigen Tensor is still being actively developed, so its tests are not completely covered and its documentation may be sparse.

For details on the Eigen Tensor module, please see [doc 1](https://github.com/RLovelett/eigen/blob/master/unsupported/Eigen/CXX11/src/Tensor/README.md) and [doc 2](https://bitbucket.org/eigen/eigen/src/default/unsupported/Eigen/CXX11/src/Tensor/README.md).
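
As a standalone illustration (this is plain Eigen, not Paddle code; the shapes and values are made up), an element-wise expression with the Eigen Tensor module looks like:

```cpp
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  Eigen::Tensor<float, 2> a(2, 3);
  Eigen::Tensor<float, 2> b(2, 3);
  a.setConstant(1.0f);
  b.setConstant(2.0f);

  // Element-wise computation; the same expression can target a GPU by
  // assigning through .device(gpu_device) instead of a plain assignment.
  Eigen::Tensor<float, 2> c = a + b * 0.5f;
  return 0;
}
```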

### paddle::framework::Tensor

Paddle's `Tensor` is defined in the `framework` directory with the following interface:

```cpp
class Tensor {
 public:
  /*! Return a pointer to mutable memory block. */
  template <typename T>
  inline T* data();

  /**
   * @brief  Return a pointer to mutable memory block.
   * @note   If not exist, then allocation.
   */
  template <typename T>
  inline T* mutable_data(platform::Place place);

  /**
   * @brief  Return a pointer to mutable memory block.
   *
   * @param[in] dims   The dimensions of the memory block.
   * @param[in] place  The place of the memory block.
   *
   * @note   If not exist, then allocation.
   */
  template <typename T>
  inline T* mutable_data(DDim dims, platform::Place place);

  /*! Resize the dimensions of the memory block. */
  inline Tensor& Resize(const DDim& dims);

  /*! Return the dimensions of the memory block. */
  inline const DDim& dims() const;

 private:
  /*! holds the memory block if allocated. */
  std::shared_ptr<Placeholder> holder_;

  /*! points to dimensions of memory block. */
  DDim dim_;
};
```

`Placeholder` is used to delay memory allocation; that is, we can first define a tensor, use `Resize` to configure its shape, and then call `mutable_data` to allocate the actual memory.

```cpp
paddle::framework::Tensor t;
paddle::platform::CPUPlace place;
// set size first
t.Resize({2, 3});
// allocate memory on CPU later
t.mutable_data<float>(place);
```

### paddle::framework::Tensor Usage
`AddOp` demonstrates Tensor's usage.

- InferShape

When computing a neural network's compute graph, first call every `Operator`'s `InferShape` method, and use `Resize` to configure the size of the output tensor.

```cpp
void InferShape(const framework::InferShapeContext &ctx) const override {
  PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("X")->dims(),
                    ctx.Input<Tensor>("Y")->dims(),
                    "Two input of Add Op's dimension must be same.");
  ctx.Output<Tensor>("Out")->Resize(ctx.Input<Tensor>("X")->dims());
}
```

- Run

```cpp
void Compute(const framework::ExecutionContext& context) const override {
  auto* input0 = context.Input<Tensor>("X");
  auto* input1 = context.Input<Tensor>("Y");
  auto* output = context.Output<Tensor>("Out");

  output->mutable_data<T>(context.GetPlace());

  auto x = EigenVector<T>::Flatten(*input0);
  auto y = EigenVector<T>::Flatten(*input1);
  auto z = EigenVector<T>::Flatten(*output);

  auto place = context.GetEigenDevice<Place>();

  z.device(place) = x + y;
}
```

### Converting paddle::framework::Tensor to EigenTensor

As shown above, in actual computation we need to transform the input and output `Tensor`s into formats Eigen supports. We provide some functions in [eigen.h](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/eigen.h) to implement the transformation from `paddle::framework::Tensor` to `EigenTensor/EigenMatrix/EigenVector/EigenScalar`.

Using EigenTensor as an example:

```cpp
Tensor t;
float* p = t.mutable_data<float>(make_ddim({1, 2, 3}), platform::CPUPlace());
for (int i = 0; i < 1 * 2 * 3; i++) {
  p[i] = static_cast<float>(i);
}

EigenTensor<float, 3>::Type et = EigenTensor<float, 3>::From(t);
```

`From` is an interface method provided by the EigenTensor template, which implements the transformation from a `paddle::framework::Tensor` object to an EigenTensor. Since `rank` is a template parameter, it needs to be explicitly specified at the time of the transformation.

In Eigen, tensors with different ranks are different types, with `Vector` being a rank-1 instance. Note that `EigenVector<T>::From` transforms a 1-dimensional Paddle tensor into a 1-dimensional Eigen tensor, while `EigenVector<T>::Flatten` reshapes a Paddle tensor of any rank and flattens it into a 1-dimensional Eigen tensor. Both resulting tensors are still typed EigenVector.
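
As a short sketch of the difference (the shapes are hypothetical; the calls mirror those shown above from `eigen.h`):

```cpp
// A rank-1 tensor can be viewed directly with From.
Tensor vec;
vec.mutable_data<float>(make_ddim({6}), platform::CPUPlace());
auto ev = EigenVector<float>::From(vec);     // 1-D view of a 1-D tensor

// A rank-2 tensor must be flattened instead; Flatten reshapes it to 1-D.
Tensor mat;
mat.mutable_data<float>(make_ddim({2, 3}), platform::CPUPlace());
auto em = EigenVector<float>::Flatten(mat);  // 1-D view over all 6 elements
```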

For more transformations, see the [unit tests](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/eigen_test.cc) in the `eigen_test.cc` file.

### Implementing Computation

When computing, the device interface must be supplied on the EigenTensor on the left-hand side of the assignment. Note that the computation between EigenTensors only changes the data originally in the Tensor and does not change the shape information associated with the Tensor.

```cpp
auto x = EigenVector<T>::Flatten(*input0);
auto y = EigenVector<T>::Flatten(*input1);
auto z = EigenVector<T>::Flatten(*output);
auto place = context.GetEigenDevice<Place>();
z.device(place) = x + y;
```

In this code segment, input0/input1/output can be Tensors of arbitrary dimension. We call Flatten from EigenVector to transform a tensor of any dimension into a 1-dimensional EigenVector. After the computation completes, input0/input1/output retain the same shape information, and they can be resized using the `Resize` interface.

Because the Eigen Tensor module is under-documented, please also refer to the `OpKernel` computation code in TensorFlow's [kernel module](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/kernels).

@@ -0,0 +1,89 @@

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/block_desc.h"
#include "paddle/framework/program_desc.h"

namespace paddle {
namespace framework {

VarDescBind *BlockDescBind::NewVar(const std::string &name) {
  need_update_ = true;
  auto it = vars_.find(name);
  PADDLE_ENFORCE(it == vars_.end(), "Duplicated variable %s", name);
  auto var = new VarDescBind(name);
  vars_[name].reset(var);
  return var;
}

VarDescBind *BlockDescBind::Var(const std::string &name) const {
  auto it = vars_.find(name);
  PADDLE_ENFORCE(it != vars_.end(),
                 "Can not find variable %s in current block.", name);
  return it->second.get();
}

std::vector<VarDescBind *> BlockDescBind::AllVars() const {
  std::vector<VarDescBind *> res;
  for (const auto &p : vars_) {
    res.push_back(p.second.get());
  }
  return res;
}

OpDescBind *BlockDescBind::AppendOp() {
  need_update_ = true;
  ops_.emplace_back(new OpDescBind());
  return ops_.back().get();
}

OpDescBind *BlockDescBind::PrependOp() {
  need_update_ = true;
  ops_.emplace_front(new OpDescBind());
  return ops_.front().get();
}

std::vector<OpDescBind *> BlockDescBind::AllOps() const {
  std::vector<OpDescBind *> res;
  for (const auto &op : ops_) {
    res.push_back(op.get());
  }
  return res;
}
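
// Flush the locally cached operator descriptors into the underlying protobuf
// BlockDesc when there are pending local changes.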
void BlockDescBind::Sync() {
  if (need_update_) {
    auto &op_field = *this->desc_->mutable_ops();
    op_field.Clear();
    op_field.Reserve(static_cast<int>(ops_.size()));
    for (auto &op_desc : ops_) {
      op_field.AddAllocated(op_desc->Proto());
    }
    need_update_ = false;
  }
}

BlockDescBind *BlockDescBind::ParentBlock() const {
  if (this->desc_->parent_idx() == -1) {
    return nullptr;
  }
  return prog_->Block(static_cast<size_t>(this->desc_->parent_idx()));
}

void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) {
  BlockDesc *desc = block.RawPtr();
  this->attrs_[name] = desc;
}
}  // namespace framework
}  // namespace paddle

@@ -0,0 +1,71 @@

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <deque>
#include <memory>
#include <unordered_map>
#include <vector>
#include "paddle/framework/op_desc.h"
#include "paddle/framework/var_desc.h"

namespace paddle {
namespace framework {

class ProgramDescBind;

// For each protobuf message, we provide a XXXBind class to optimize read/write
// speed. Local changes are synchronized into the protobuf message only when it
// is requested (via the `Sync` method).

class BlockDescBind {
 public:
  BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
      : prog_(prog), desc_(desc), need_update_(false) {}

  BlockDescBind(const BlockDescBind &o) = delete;
  BlockDescBind &operator=(const BlockDescBind &o) = delete;

  int32_t ID() const { return desc_->idx(); }

  int32_t Parent() const { return desc_->parent_idx(); }

  VarDescBind *NewVar(const std::string &name_bytes);

  VarDescBind *Var(const std::string &name_bytes) const;

  std::vector<VarDescBind *> AllVars() const;

  BlockDescBind *ParentBlock() const;

  OpDescBind *AppendOp();

  OpDescBind *PrependOp();

  std::vector<OpDescBind *> AllOps() const;

  void Sync();

  BlockDesc *RawPtr() { return desc_; }

 private:
  ProgramDescBind *prog_;  // not owned
  BlockDesc *desc_;        // not owned
  bool need_update_;

  std::deque<std::unique_ptr<OpDescBind>> ops_;
  std::unordered_map<std::string, std::unique_ptr<VarDescBind>> vars_;
};
}  // namespace framework
}  // namespace paddle

@@ -0,0 +1,36 @@

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <typeindex>
#include "paddle/framework/framework.pb.h"

namespace paddle {
namespace framework {
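
// Map a C++ runtime type to the corresponding protobuf DataType enum value.
// Only float, double and int are supported so far.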
inline DataType ToDataType(std::type_index type) {
  if (typeid(float).hash_code() == type.hash_code()) {
    return DataType::FP32;
  } else if (typeid(double).hash_code() == type.hash_code()) {
    return DataType::FP64;
  } else if (typeid(int).hash_code() == type.hash_code()) {
    return DataType::INT32;
  } else {
    PADDLE_THROW("Not supported");
    return static_cast<DataType>(-1);
  }
}

}  // namespace framework
}  // namespace paddle

@@ -0,0 +1,144 @@

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/op_desc.h"
#include "paddle/framework/block_desc.h"

namespace paddle {
namespace framework {

OpDesc *OpDescBind::Proto() {
  Sync();
  return &op_desc_;
}

const std::vector<std::string> &OpDescBind::Input(
    const std::string &name) const {
  auto it = inputs_.find(name);
  PADDLE_ENFORCE(it != inputs_.end(), "Input %s cannot be found in Op %s", name,
                 Type());
  return it->second;
}

std::vector<std::string> OpDescBind::InputNames() const {
  std::vector<std::string> retv;
  retv.reserve(this->inputs_.size());
  for (auto &ipt : this->inputs_) {
    retv.push_back(ipt.first);
  }
  return retv;
}

void OpDescBind::SetInput(const std::string &param_name,
                          const std::vector<std::string> &args) {
  need_update_ = true;
  inputs_[param_name] = args;
}

const std::vector<std::string> &OpDescBind::Output(
    const std::string &name) const {
  auto it = outputs_.find(name);
  PADDLE_ENFORCE(it != outputs_.end(), "Output %s cannot be found in Op %s",
                 name, Type());
  return it->second;
}

std::vector<std::string> OpDescBind::OutputNames() const {
  std::vector<std::string> retv;
  retv.reserve(this->outputs_.size());
  for (auto &ipt : this->outputs_) {
    retv.push_back(ipt.first);
  }
  return retv;
}

void OpDescBind::SetOutput(const std::string &param_name,
                           const std::vector<std::string> &args) {
  need_update_ = true;
  this->outputs_[param_name] = args;
}

AttrType OpDescBind::GetAttrType(const std::string &name) const {
  auto it = attrs_.find(name);
  PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
  return static_cast<AttrType>(it->second.which() - 1);
}

std::vector<std::string> OpDescBind::AttrNames() const {
  std::vector<std::string> retv;
  retv.reserve(attrs_.size());
  for (auto &attr : attrs_) {
    retv.push_back(attr.first);
  }
  return retv;
}

void OpDescBind::SetAttr(const std::string &name, const Attribute &v) {
  this->attrs_[name] = v;
  need_update_ = true;
}

void OpDescBind::SetAttrMap(
    const std::unordered_map<std::string, Attribute> &attr_map) {
  attrs_ = attr_map;
  need_update_ = true;
}

Attribute OpDescBind::GetAttr(const std::string &name) const {
  auto it = attrs_.find(name);
  PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
  return it->second;
}

int OpDescBind::GetBlockAttr(const std::string &name) const {
  auto it = attrs_.find(name);
  PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
  return boost::get<BlockDesc *>(it->second)->idx();
}

const std::unordered_map<std::string, Attribute> &OpDescBind::GetAttrMap()
    const {
  return attrs_;
}
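
// Synchronize the locally cached inputs, outputs and attributes into the
// underlying protobuf OpDesc message when there are pending local changes.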
void OpDescBind::Sync() {
  if (need_update_) {
    this->op_desc_.mutable_inputs()->Clear();
    for (auto &ipt : inputs_) {
      auto *input = op_desc_.add_inputs();
      input->set_parameter(ipt.first);
      VectorToRepeated(ipt.second, input->mutable_arguments());
    }

    this->op_desc_.mutable_outputs()->Clear();
    for (auto &opt : outputs_) {
      auto *output = op_desc_.add_outputs();
      output->set_parameter(opt.first);
      VectorToRepeated(opt.second, output->mutable_arguments());
    }

    this->op_desc_.mutable_attrs()->Clear();
    for (auto &attr : attrs_) {
      auto *attr_desc = op_desc_.add_attrs();
      attr_desc->set_name(attr.first);
      attr_desc->set_type(
          static_cast<framework::AttrType>(attr.second.which() - 1));
      boost::apply_visitor(SetAttrDescVisitor(attr_desc), attr.second);
    }

    need_update_ = false;
  }
}
}  // namespace framework
}  // namespace paddle

@@ -0,0 +1,112 @@

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <unordered_map>
#include <vector>
#include "paddle/framework/attribute.h"
#include "paddle/framework/var_desc.h"

namespace paddle {
namespace framework {

class BlockDescBind;

class OpDescBind {
 public:
  OpDesc *Proto();

  std::string Type() const { return op_desc_.type(); }

  void SetType(const std::string &type) { op_desc_.set_type(type); }

  const std::vector<std::string> &Input(const std::string &name) const;

  std::vector<std::string> InputNames() const;

  void SetInput(const std::string &param_name,
                const std::vector<std::string> &args);

  const std::vector<std::string> &Output(const std::string &name) const;

  std::vector<std::string> OutputNames() const;

  void SetOutput(const std::string &param_name,
                 const std::vector<std::string> &args);

  std::string DebugString() { return this->Proto()->DebugString(); }

  bool HasAttr(const std::string &name) const {
    return attrs_.find(name) != attrs_.end();
  }

  AttrType GetAttrType(const std::string &name) const;

  std::vector<std::string> AttrNames() const;

  void SetAttr(const std::string &name, const Attribute &v);

  void SetBlockAttr(const std::string &name, BlockDescBind &block);

  // Only used in C++.
  void SetAttrMap(const std::unordered_map<std::string, Attribute> &attr_map);

  Attribute GetAttr(const std::string &name) const;

  int GetBlockAttr(const std::string &name) const;

  // Only used in C++.
  const std::unordered_map<std::string, Attribute> &GetAttrMap() const;

 private:
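  // Visitor that writes each Attribute variant alternative into the matching
  // field of the protobuf OpDesc::Attr message.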
  struct SetAttrDescVisitor : public boost::static_visitor<void> {
    explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {}
    mutable OpDesc::Attr *attr_;
    void operator()(int v) const { attr_->set_i(v); }
    void operator()(float v) const { attr_->set_f(v); }
    void operator()(const std::string &v) const { attr_->set_s(v); }
    void operator()(bool b) const { attr_->set_b(b); }

    void operator()(const std::vector<int> &v) const {
      VectorToRepeated(v, attr_->mutable_ints());
    }
    void operator()(const std::vector<float> &v) const {
      VectorToRepeated(v, attr_->mutable_floats());
    }
    void operator()(const std::vector<std::string> &v) const {
      VectorToRepeated(v, attr_->mutable_strings());
    }
    void operator()(const std::vector<bool> &v) const {
      VectorToRepeated(v, attr_->mutable_bools());
    }
    void operator()(BlockDesc *desc) const {
      attr_->set_block_idx(desc->idx());
    }
    void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
  };

  void Sync();

  OpDesc op_desc_;
  std::unordered_map<std::string, std::vector<std::string>> inputs_;
  std::unordered_map<std::string, std::vector<std::string>> outputs_;
  std::unordered_map<std::string, Attribute> attrs_;

  // need_update_ indicates that there are local changes that have not been
  // synchronized into op_desc_ yet; it is set whenever a setter modifies the
  // cached state.
  bool need_update_{false};
};
}  // namespace framework
}  // namespace paddle

@@ -0,0 +1,60 @@

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/program_desc.h"
#include "paddle/framework/block_desc.h"

namespace paddle {
namespace framework {

using ProgDescMap =
    std::unordered_map<ProgramDesc *, std::unique_ptr<ProgramDescBind>>;
static ProgDescMap *g_bind_map = nullptr;
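
// Return the ProgramDescBind bound to `prog`, creating it lazily on first use.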
ProgramDescBind &ProgramDescBind::Instance(ProgramDesc *prog) {
  if (g_bind_map == nullptr) {
    g_bind_map = new ProgDescMap();
  }
  auto &map = *g_bind_map;
  auto &ptr = map[prog];

  if (ptr == nullptr) {
    ptr.reset(new ProgramDescBind(prog));
  }
  return *ptr;
}

BlockDescBind *ProgramDescBind::AppendBlock(const BlockDescBind &parent) {
  auto *b = prog_->add_blocks();
  b->set_parent_idx(parent.ID());
  b->set_idx(prog_->blocks_size() - 1);
  blocks_.emplace_back(new BlockDescBind(this, b));
  return blocks_.back().get();
}

ProgramDesc *ProgramDescBind::Proto() {
  for (auto &block : blocks_) {
    block->Sync();
  }
  return prog_;
}

ProgramDescBind::ProgramDescBind(ProgramDesc *prog) {
  prog_ = prog;
  for (auto &block : *prog->mutable_blocks()) {
    blocks_.emplace_back(new BlockDescBind(this, &block));
  }
}
}  // namespace framework
}  // namespace paddle