tonyyang-svail-feed-op-desgin
commit
fb6a48c62d
@ -0,0 +1,146 @@
|
||||
## How to use Eigen in Paddle
|
||||
|
||||
Essentially, a neural network is a compute graph. T data needed for the computation is stored in `Tensor`s and its computation procedure is described by `Operator`s. An `Operator` calls the `Compute` interface in its corresponding `OpKernel` and operates on the `Tensor`.
|
||||
|
||||
|
||||
### Eigen Tensor Module
|
||||
|
||||
The Eigen Tensor module supports powerful element-wise computation. In addition, a piece of code written using it can be run on both the CPU and the GPU.
|
||||
|
||||
Note that Eigen Tensor is still being actively developed, so its tests are not completely covered and its documentation may be sparse.
|
||||
|
||||
For details on Eigen Tensor module, please see [doc 1](https://github.com/RLovelett/eigen/blob/master/unsupported/Eigen/CXX11/src/Tensor/README.md) and [doc 2](https://bitbucket.org/eigen/eigen/src/default/unsupported/Eigen/CXX11/src/Tensor/README.md).
|
||||
|
||||
|
||||
### paddle::framework::Tensor
|
||||
|
||||
Paddle Tensor's is defined in the framework directory with the following interface:
|
||||
|
||||
```cpp
|
||||
class Tensor {
|
||||
public:
|
||||
/*! Return a pointer to mutable memory block. */
|
||||
template <typename T>
|
||||
inline T* data();
|
||||
|
||||
/**
|
||||
* @brief Return a pointer to mutable memory block.
|
||||
* @note If not exist, then allocation.
|
||||
*/
|
||||
template <typename T>
|
||||
inline T* mutable_data(platform::Place place);
|
||||
|
||||
/**
|
||||
* @brief Return a pointer to mutable memory block.
|
||||
*
|
||||
* @param[in] dims The dimensions of the memory block.
|
||||
* @param[in] place The place of the memory block.
|
||||
*
|
||||
* @note If not exist, then allocation.
|
||||
*/
|
||||
template <typename T>
|
||||
inline T* mutable_data(DDim dims, platform::Place place);
|
||||
|
||||
/*! Resize the dimensions of the memory block. */
|
||||
inline Tensor& Resize(const DDim& dims);
|
||||
|
||||
/*! Return the dimensions of the memory block. */
|
||||
inline const DDim& dims() const;
|
||||
|
||||
private:
|
||||
/*! holds the memory block if allocated. */
|
||||
std::shared_ptr<Placeholder> holder_;
|
||||
|
||||
/*! points to dimensions of memory block. */
|
||||
DDim dim_;
|
||||
};
|
||||
```
|
||||
|
||||
`Placeholder` is used to delay memory allocation; that is, we can first define a tensor, using `Resize` to configure its shape, and then call `mutuable_data` to allocate the actual memory.
|
||||
|
||||
```cpp
|
||||
paddle::framework::Tensor t;
|
||||
paddle::platform::CPUPlace place;
|
||||
// set size first
|
||||
t.Resize({2, 3});
|
||||
// allocate memory on CPU later
|
||||
t.mutable_data(place);
|
||||
```
|
||||
|
||||
### paddle::framework::Tensor Usage
|
||||
`AddOp` demonstrates Tensor's usage.
|
||||
|
||||
- InferShape
|
||||
|
||||
When computing a neural network's compute graph, first call every `Operator`'s `InferShape` method, and use `Resize` to configure the size of the output tensor.
|
||||
|
||||
```cpp
|
||||
void InferShape(const framework::InferShapeContext &ctx) const override {
|
||||
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("X")->dims(),
|
||||
ctx.Input<Tensor>("Y")->dims(),
|
||||
"Two input of Add Op's dimension must be same.");
|
||||
ctx.Output<Tensor>("Out")->Resize(ctx.Input<Tensor>("X")->dims());
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
- Run
|
||||
|
||||
```cpp
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
auto* input0 = context.Input<Tensor>("X");
|
||||
auto* input1 = context.Input<Tensor>("Y");
|
||||
auto* output = context.Output<Tensor>("Out");
|
||||
|
||||
output->mutable_data<T>(context.GetPlace());
|
||||
|
||||
auto x = EigenVector<T>::Flatten(*input0);
|
||||
auto y = EigenVector<T>::Flatten(*input1);
|
||||
auto z = EigenVector<T>::Flatten(*output);
|
||||
|
||||
auto place = context.GetEigenDevice<Place>();
|
||||
|
||||
z.device(place) = x + y;
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
### paddle::framework::Tensor到EigenTensor的转换
|
||||
|
||||
As shown above, in actual computation, we need to transform the input and output `Tensor`s into formats Eigen supports. We show some functions in [eigen.h](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/eigen.h) to implement the transformation from `paddle::framework::Tensor`to `EigenTensor/EigenMatrix/EigenVector/EigenScalar`.
|
||||
|
||||
Using EigenTensor as an example:
|
||||
|
||||
```cpp
|
||||
Tensor t;
|
||||
float* p = t.mutable_data<float>(make_ddim({1, 2, 3}), platform::CPUPlace());
|
||||
for (int i = 0; i < 1 * 2 * 3; i++) {
|
||||
p[i] = static_cast<float>(i);
|
||||
}
|
||||
|
||||
EigenTensor<float, 3>::Type et = EigenTensor<float, 3>::From(t);
|
||||
```
|
||||
|
||||
`From` is an interfacing method provided by the EigenTensor template, which implements the transformation from a `paddle::framework::Tensor` object to an EigenTensor. Since `rank` is a template parameter, it needs to be explicitly specified at the time of the transformation.
|
||||
|
||||
In Eigen, tensors with different ranks are different types, with `Vector` bring a rank-1 instance. Note that `EigenVector<T>::From` uses a transformation from an 1-dimensional Paddle tensor to a 1-dimensional Eigen tensor while `EigenVector<T>::Flatten` reshapes a paddle tensor and flattens it into a 1-dimensional Eigen tensor. Both resulting tensors are still typed EigenVector.
|
||||
|
||||
For more transformations, see the [unit tests](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/eigen_test.cc) in the `eigen_test.cc` file.
|
||||
|
||||
|
||||
|
||||
### Implementing Computation
|
||||
|
||||
While computing, the device interface is needed from the EigenTensors on the left hand side of the assignments. Note that the computation between EigenTensors only changes the data originally inthe Tensor and does not change all the shape information associated with the Tensor.
|
||||
|
||||
```cpp
|
||||
auto x = EigenVector<T>::Flatten(*input0);
|
||||
auto y = EigenVector<T>::Flatten(*input1);
|
||||
auto z = EigenVector<T>::Flatten(*output);
|
||||
auto place = context.GetEigenDevice<Place>();
|
||||
z.device(place) = x + y;
|
||||
```
|
||||
|
||||
In this code segment, input0/input1/output can be Tensors of arbitrary dimension. We are calling Flatten from EigenVector, transforming a tensor of any dimension into a 1-dimensional EigenVector. After completing computation, input0/input1/output will retain the same shape information, and they can be resized using the `Resize` interface.
|
||||
|
||||
Because the Eigen Tensor module is under-documented, please refer to `OpKernel`'s computation code in TensorFlow's [kernel module documentation](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/kernels).
|
@ -0,0 +1,89 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/framework/block_desc.h"
|
||||
#include "paddle/framework/program_desc.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
VarDescBind *BlockDescBind::NewVar(const std::string &name) {
|
||||
need_update_ = true;
|
||||
auto it = vars_.find(name);
|
||||
PADDLE_ENFORCE(it == vars_.end(), "Duplicated variable %s", name);
|
||||
auto var = new VarDescBind(name);
|
||||
vars_[name].reset(var);
|
||||
return var;
|
||||
}
|
||||
|
||||
VarDescBind *BlockDescBind::Var(const std::string &name) const {
|
||||
auto it = vars_.find(name);
|
||||
PADDLE_ENFORCE(it != vars_.end(),
|
||||
"Can not find variable %s in current block.", name);
|
||||
return it->second.get();
|
||||
}
|
||||
|
||||
std::vector<VarDescBind *> BlockDescBind::AllVars() const {
|
||||
std::vector<VarDescBind *> res;
|
||||
for (const auto &p : vars_) {
|
||||
res.push_back(p.second.get());
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
OpDescBind *BlockDescBind::AppendOp() {
|
||||
need_update_ = true;
|
||||
ops_.emplace_back(new OpDescBind());
|
||||
return ops_.back().get();
|
||||
}
|
||||
|
||||
OpDescBind *BlockDescBind::PrependOp() {
|
||||
need_update_ = true;
|
||||
ops_.emplace_front(new OpDescBind());
|
||||
return ops_.front().get();
|
||||
}
|
||||
|
||||
std::vector<OpDescBind *> BlockDescBind::AllOps() const {
|
||||
std::vector<OpDescBind *> res;
|
||||
for (const auto &op : ops_) {
|
||||
res.push_back(op.get());
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
void BlockDescBind::Sync() {
|
||||
if (need_update_) {
|
||||
auto &op_field = *this->desc_->mutable_ops();
|
||||
op_field.Clear();
|
||||
op_field.Reserve(static_cast<int>(ops_.size()));
|
||||
for (auto &op_desc : ops_) {
|
||||
op_field.AddAllocated(op_desc->Proto());
|
||||
}
|
||||
need_update_ = false;
|
||||
}
|
||||
}
|
||||
|
||||
BlockDescBind *BlockDescBind::ParentBlock() const {
|
||||
if (this->desc_->parent_idx() == -1) {
|
||||
return nullptr;
|
||||
}
|
||||
return prog_->Block(static_cast<size_t>(this->desc_->parent_idx()));
|
||||
}
|
||||
|
||||
void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) {
|
||||
BlockDesc *desc = block.RawPtr();
|
||||
this->attrs_[name] = desc;
|
||||
}
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,71 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <deque>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include "paddle/framework/op_desc.h"
|
||||
#include "paddle/framework/var_desc.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
class ProgramDescBind;
|
||||
|
||||
// Each Protobuf Message, we provide a XXXBind class. In that class, we optimize
|
||||
// read/write speed. Only when we want the protobuf message, the local changes
|
||||
// will be synchronized (by `Sync` method).
|
||||
|
||||
class BlockDescBind {
|
||||
public:
|
||||
BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
|
||||
: prog_(prog), desc_(desc), need_update_(false) {}
|
||||
|
||||
BlockDescBind(const BlockDescBind &o) = delete;
|
||||
BlockDescBind &operator=(const BlockDescBind &o) = delete;
|
||||
|
||||
int32_t ID() const { return desc_->idx(); }
|
||||
|
||||
int32_t Parent() const { return desc_->parent_idx(); }
|
||||
|
||||
VarDescBind *NewVar(const std::string &name_bytes);
|
||||
|
||||
VarDescBind *Var(const std::string &name_bytes) const;
|
||||
|
||||
std::vector<VarDescBind *> AllVars() const;
|
||||
|
||||
BlockDescBind *ParentBlock() const;
|
||||
|
||||
OpDescBind *AppendOp();
|
||||
|
||||
OpDescBind *PrependOp();
|
||||
|
||||
std::vector<OpDescBind *> AllOps() const;
|
||||
|
||||
void Sync();
|
||||
|
||||
BlockDesc *RawPtr() { return desc_; }
|
||||
|
||||
private:
|
||||
ProgramDescBind *prog_; // not_own
|
||||
BlockDesc *desc_; // not_own
|
||||
bool need_update_;
|
||||
|
||||
std::deque<std::unique_ptr<OpDescBind>> ops_;
|
||||
std::unordered_map<std::string, std::unique_ptr<VarDescBind>> vars_;
|
||||
};
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,133 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/framework/op_desc.h"
|
||||
#include "paddle/framework/block_desc.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
OpDesc *OpDescBind::Proto() {
|
||||
Sync();
|
||||
return &op_desc_;
|
||||
}
|
||||
|
||||
const std::vector<std::string> &OpDescBind::Input(
|
||||
const std::string &name) const {
|
||||
auto it = inputs_.find(name);
|
||||
PADDLE_ENFORCE(it != inputs_.end(), "Input %s cannot be found in Op %s", name,
|
||||
Type());
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::vector<std::string> OpDescBind::InputNames() const {
|
||||
std::vector<std::string> retv;
|
||||
retv.reserve(this->inputs_.size());
|
||||
for (auto &ipt : this->inputs_) {
|
||||
retv.push_back(ipt.first);
|
||||
}
|
||||
return retv;
|
||||
}
|
||||
|
||||
void OpDescBind::SetInput(const std::string ¶m_name,
|
||||
const std::vector<std::string> &args) {
|
||||
need_update_ = true;
|
||||
inputs_[param_name] = args;
|
||||
}
|
||||
|
||||
const std::vector<std::string> &OpDescBind::Output(
|
||||
const std::string &name) const {
|
||||
auto it = outputs_.find(name);
|
||||
PADDLE_ENFORCE(it != outputs_.end(), "Output %s cannot be found in Op %s",
|
||||
name, Type());
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::vector<std::string> OpDescBind::OutputNames() const {
|
||||
std::vector<std::string> retv;
|
||||
retv.reserve(this->outputs_.size());
|
||||
for (auto &ipt : this->outputs_) {
|
||||
retv.push_back(ipt.first);
|
||||
}
|
||||
return retv;
|
||||
}
|
||||
|
||||
void OpDescBind::SetOutput(const std::string ¶m_name,
|
||||
const std::vector<std::string> &args) {
|
||||
need_update_ = true;
|
||||
this->outputs_[param_name] = args;
|
||||
}
|
||||
|
||||
AttrType OpDescBind::GetAttrType(const std::string &name) const {
|
||||
auto it = attrs_.find(name);
|
||||
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
|
||||
return static_cast<AttrType>(it->second.which() - 1);
|
||||
}
|
||||
|
||||
std::vector<std::string> OpDescBind::AttrNames() const {
|
||||
std::vector<std::string> retv;
|
||||
retv.reserve(attrs_.size());
|
||||
for (auto &attr : attrs_) {
|
||||
retv.push_back(attr.first);
|
||||
}
|
||||
return retv;
|
||||
}
|
||||
|
||||
void OpDescBind::SetAttr(const std::string &name, const Attribute &v) {
|
||||
this->attrs_[name] = v;
|
||||
need_update_ = true;
|
||||
}
|
||||
|
||||
Attribute OpDescBind::GetAttr(const std::string &name) const {
|
||||
auto it = attrs_.find(name);
|
||||
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
|
||||
return it->second;
|
||||
}
|
||||
|
||||
int OpDescBind::GetBlockAttr(const std::string &name) const {
|
||||
auto it = attrs_.find(name);
|
||||
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
|
||||
return boost::get<BlockDesc *>(it->second)->idx();
|
||||
}
|
||||
|
||||
void OpDescBind::Sync() {
|
||||
if (need_update_) {
|
||||
this->op_desc_.mutable_inputs()->Clear();
|
||||
for (auto &ipt : inputs_) {
|
||||
auto *input = op_desc_.add_inputs();
|
||||
input->set_parameter(ipt.first);
|
||||
VectorToRepeated(ipt.second, input->mutable_arguments());
|
||||
}
|
||||
|
||||
this->op_desc_.mutable_outputs()->Clear();
|
||||
for (auto &opt : outputs_) {
|
||||
auto *output = op_desc_.add_outputs();
|
||||
output->set_parameter(opt.first);
|
||||
VectorToRepeated(opt.second, output->mutable_arguments());
|
||||
}
|
||||
|
||||
this->op_desc_.mutable_attrs()->Clear();
|
||||
for (auto &attr : attrs_) {
|
||||
auto *attr_desc = op_desc_.add_attrs();
|
||||
attr_desc->set_name(attr.first);
|
||||
attr_desc->set_type(
|
||||
static_cast<framework::AttrType>(attr.second.which() - 1));
|
||||
boost::apply_visitor(SetAttrDescVisitor(attr_desc), attr.second);
|
||||
}
|
||||
|
||||
need_update_ = false;
|
||||
}
|
||||
}
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,106 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include "paddle/framework/attribute.h"
|
||||
#include "paddle/framework/var_desc.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
class BlockDescBind;
|
||||
|
||||
class OpDescBind {
|
||||
public:
|
||||
OpDesc *Proto();
|
||||
|
||||
std::string Type() const { return op_desc_.type(); }
|
||||
|
||||
void SetType(const std::string &type) { op_desc_.set_type(type); }
|
||||
|
||||
const std::vector<std::string> &Input(const std::string &name) const;
|
||||
|
||||
std::vector<std::string> InputNames() const;
|
||||
|
||||
void SetInput(const std::string ¶m_name,
|
||||
const std::vector<std::string> &args);
|
||||
|
||||
const std::vector<std::string> &Output(const std::string &name) const;
|
||||
|
||||
std::vector<std::string> OutputNames() const;
|
||||
|
||||
void SetOutput(const std::string ¶m_name,
|
||||
const std::vector<std::string> &args);
|
||||
|
||||
std::string DebugString() { return this->Proto()->DebugString(); }
|
||||
|
||||
bool HasAttr(const std::string &name) const {
|
||||
return attrs_.find(name) != attrs_.end();
|
||||
}
|
||||
|
||||
AttrType GetAttrType(const std::string &name) const;
|
||||
|
||||
std::vector<std::string> AttrNames() const;
|
||||
|
||||
void SetAttr(const std::string &name, const Attribute &v);
|
||||
|
||||
void SetBlockAttr(const std::string &name, BlockDescBind &block);
|
||||
|
||||
Attribute GetAttr(const std::string &name) const;
|
||||
|
||||
int GetBlockAttr(const std::string &name) const;
|
||||
|
||||
private:
|
||||
struct SetAttrDescVisitor : public boost::static_visitor<void> {
|
||||
explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {}
|
||||
mutable OpDesc::Attr *attr_;
|
||||
void operator()(int v) const { attr_->set_i(v); }
|
||||
void operator()(float v) const { attr_->set_f(v); }
|
||||
void operator()(const std::string &v) const { attr_->set_s(v); }
|
||||
void operator()(bool b) const { attr_->set_b(b); }
|
||||
|
||||
void operator()(const std::vector<int> &v) const {
|
||||
VectorToRepeated(v, attr_->mutable_ints());
|
||||
}
|
||||
void operator()(const std::vector<float> &v) const {
|
||||
VectorToRepeated(v, attr_->mutable_floats());
|
||||
}
|
||||
void operator()(const std::vector<std::string> &v) const {
|
||||
VectorToRepeated(v, attr_->mutable_strings());
|
||||
}
|
||||
void operator()(const std::vector<bool> &v) const {
|
||||
VectorToRepeated(v, attr_->mutable_bools());
|
||||
}
|
||||
void operator()(BlockDesc *desc) const {
|
||||
attr_->set_block_idx(desc->idx());
|
||||
}
|
||||
void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
|
||||
};
|
||||
|
||||
void Sync();
|
||||
|
||||
OpDesc op_desc_;
|
||||
std::unordered_map<std::string, std::vector<std::string>> inputs_;
|
||||
std::unordered_map<std::string, std::vector<std::string>> outputs_;
|
||||
std::unordered_map<std::string, Attribute> attrs_;
|
||||
|
||||
// need_update_ indicate there some local changes not be synchronized. If
|
||||
// local changes should be synchronized, need_update_ should be set to true.
|
||||
bool need_update_{false};
|
||||
};
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,60 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/framework/program_desc.h"
|
||||
#include "paddle/framework/block_desc.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
using ProgDescMap =
|
||||
std::unordered_map<ProgramDesc *, std::unique_ptr<ProgramDescBind>>;
|
||||
static ProgDescMap *g_bind_map = nullptr;
|
||||
|
||||
ProgramDescBind &ProgramDescBind::Instance(ProgramDesc *prog) {
|
||||
if (g_bind_map == nullptr) {
|
||||
g_bind_map = new ProgDescMap();
|
||||
}
|
||||
auto &map = *g_bind_map;
|
||||
auto &ptr = map[prog];
|
||||
|
||||
if (ptr == nullptr) {
|
||||
ptr.reset(new ProgramDescBind(prog));
|
||||
}
|
||||
return *ptr;
|
||||
}
|
||||
|
||||
BlockDescBind *ProgramDescBind::AppendBlock(const BlockDescBind &parent) {
|
||||
auto *b = prog_->add_blocks();
|
||||
b->set_parent_idx(parent.ID());
|
||||
b->set_idx(prog_->blocks_size() - 1);
|
||||
blocks_.emplace_back(new BlockDescBind(this, b));
|
||||
return blocks_.back().get();
|
||||
}
|
||||
|
||||
ProgramDesc *ProgramDescBind::Proto() {
|
||||
for (auto &block : blocks_) {
|
||||
block->Sync();
|
||||
}
|
||||
return prog_;
|
||||
}
|
||||
|
||||
ProgramDescBind::ProgramDescBind(ProgramDesc *prog) {
|
||||
prog_ = prog;
|
||||
for (auto &block : *prog->mutable_blocks()) {
|
||||
blocks_.emplace_back(new BlockDescBind(this, &block));
|
||||
}
|
||||
}
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,51 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include "paddle/framework/framework.pb.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
class BlockDescBind;
|
||||
|
||||
class ProgramDescBind {
|
||||
public:
|
||||
static ProgramDescBind &Instance(ProgramDesc *prog);
|
||||
|
||||
ProgramDescBind(const ProgramDescBind &o) = delete;
|
||||
ProgramDescBind &operator=(const ProgramDescBind &o) = delete;
|
||||
|
||||
BlockDescBind *AppendBlock(const BlockDescBind &parent);
|
||||
|
||||
BlockDescBind *Block(size_t idx) { return blocks_[idx].get(); }
|
||||
|
||||
std::string DebugString() { return Proto()->DebugString(); }
|
||||
|
||||
size_t Size() const { return blocks_.size(); }
|
||||
|
||||
ProgramDesc *Proto();
|
||||
|
||||
private:
|
||||
explicit ProgramDescBind(ProgramDesc *prog);
|
||||
|
||||
// Not owned
|
||||
ProgramDesc *prog_;
|
||||
|
||||
std::vector<std::unique_ptr<BlockDescBind>> blocks_;
|
||||
};
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,36 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/framework/var_desc.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
void VarDescBind::SetShape(const std::vector<int64_t> &dims) {
|
||||
VectorToRepeated(dims, desc_.mutable_lod_tensor()->mutable_dims());
|
||||
}
|
||||
|
||||
void VarDescBind::SetDataType(DataType data_type) {
|
||||
desc_.mutable_lod_tensor()->set_data_type(data_type);
|
||||
}
|
||||
|
||||
std::vector<int64_t> VarDescBind::Shape() const {
|
||||
return RepeatedToVector(desc_.lod_tensor().dims());
|
||||
}
|
||||
|
||||
DataType VarDescBind::GetDataType() const {
|
||||
return desc_.lod_tensor().data_type();
|
||||
}
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,73 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include "paddle/framework/framework.pb.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
// convert between std::vector and protobuf repeated.
|
||||
template <typename T>
|
||||
inline std::vector<T> RepeatedToVector(
|
||||
const google::protobuf::RepeatedField<T> &repeated_field) {
|
||||
std::vector<T> ret;
|
||||
ret.reserve(repeated_field.size());
|
||||
std::copy(repeated_field.begin(), repeated_field.end(),
|
||||
std::back_inserter(ret));
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T, typename RepeatedField>
|
||||
inline void VectorToRepeated(const std::vector<T> &vec,
|
||||
RepeatedField *repeated_field) {
|
||||
repeated_field->Reserve(vec.size());
|
||||
for (const auto &elem : vec) {
|
||||
*repeated_field->Add() = elem;
|
||||
}
|
||||
}
|
||||
|
||||
// Specialize vector<bool>.
|
||||
template <typename RepeatedField>
|
||||
inline void VectorToRepeated(const std::vector<bool> &vec,
|
||||
RepeatedField *repeated_field) {
|
||||
repeated_field->Reserve(vec.size());
|
||||
for (auto elem : vec) {
|
||||
*repeated_field->Add() = elem;
|
||||
}
|
||||
}
|
||||
|
||||
class VarDescBind {
|
||||
public:
|
||||
explicit VarDescBind(const std::string &name) { desc_.set_name(name); }
|
||||
|
||||
VarDesc *Proto() { return &desc_; }
|
||||
|
||||
std::string Name() const { return desc_.name(); }
|
||||
|
||||
void SetShape(const std::vector<int64_t> &dims);
|
||||
|
||||
void SetDataType(DataType data_type);
|
||||
|
||||
std::vector<int64_t> Shape() const;
|
||||
|
||||
DataType GetDataType() const;
|
||||
|
||||
private:
|
||||
VarDesc desc_;
|
||||
};
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,20 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/operators/concat_op.h"
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_GPU_KERNEL(concat,
|
||||
ops::ConcatKernel<paddle::platform::GPUPlace, float>);
|
||||
REGISTER_OP_GPU_KERNEL(
|
||||
concat_grad, ops::ConcatGradKernel<paddle::platform::GPUPlace, float>);
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue