Merge branch 'develop' into feature/middle_level_net_api

cblas_new
Yu Yang 8 years ago
commit 0467cd2dfe

@ -75,10 +75,11 @@ snapshot to a model will be a TODO for future.
### Trainer Election
One trainer will be elected as the one to save the model. When using
etcd, trainer ID is a randomly generated UUID, we will utilize etcd to
elect one trainer. When not using etcd, unique trainer IDs will be
given by the administrator, the trainer whose ID is "0" is elected to
save the model.
etcd, trainer ID is a randomly generated UUID, the trainer will
contact the master server requesting to save the model, and find out
if itself is elected. When the master server is not used, unique
trainer IDs will be given by the administrator, the trainer whose ID
is "0" is elected to save the model.
### Model Save Path

@ -49,6 +49,7 @@ message AttrProto {
message VarProto {
required string name = 1;
required string comment = 2;
required bool is_tensor = 3;
};
message OpProto {

@ -311,3 +311,13 @@ Paddle二进制在运行时捕获了浮点数异常只要出现浮点数异
* 训练数据有问题,导致参数收敛到了一些奇异的情况。或者输入数据尺度过大,有些特征的取值达到数百万,这时进行矩阵乘法运算就可能导致浮点数溢出。
主要的解决办法是减小学习律或者对数据进行归一化处理。
15. 编译安装后执行 import paddle.v2 as paddle 报ImportError: No module named v2
------------------------------------------------------------------------
先查看一下是否曾经安装过paddle v1版本有的话需要先卸载
pip uninstall py_paddle paddle
然后安装paddle的python环境, 在build目录下执行
pip install python/dist/paddle*.whl && pip install ../paddle/dist/py_paddle*.whl

@ -59,9 +59,13 @@ func main() {
cp, err = pserver.NewCheckpointFromFile(*checkpointPath, idx, e)
if err != nil {
if err == pserver.ErrCheckpointNotFound {
log.Infof("Could not find the pserver checkpoint.")
} else {
log.Errorf("Fetch checkpoint failed, %s", err)
}
}
}
s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp)
candy.Must(err)

@ -22,6 +22,9 @@ package main
#define PADDLE_MASTER_OK 0
#define PADDLE_MASTER_ERROR -1
#define PADDLE_SAVE_MODEL_OK 1
#define PADDLE_SAVE_MODEL_SKIP 0
typedef int paddle_master_client;
*/
import "C"
@ -33,7 +36,6 @@ import (
"unsafe"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/coreos/etcd/clientv3"
log "github.com/sirupsen/logrus"
)
@ -65,32 +67,32 @@ func remove(client C.paddle_master_client) *master.Client {
}
//export paddle_new_etcd_master_client
//
// bufSize is the record buffer size.
func paddle_new_etcd_master_client(etcdEndpoints *C.char, timeout int, bufSize int) C.paddle_master_client {
p := C.GoString(etcdEndpoints)
cli, err := clientv3.New(clientv3.Config{
Endpoints: strings.Split(p, ","),
DialTimeout: time.Second * time.Duration(timeout),
})
endpoints := strings.Split(p, ",")
c, err := master.NewClient(
master.WithEtcd(endpoints, time.Duration(timeout)*time.Second),
master.WithBuffer(bufSize),
)
if err != nil {
panic(err)
}
ch := make(chan string, 1)
a, err := master.GetKey(cli, master.DefaultAddrPath, timeout)
if err != nil {
panic(err)
}
ch <- a
go master.WatchKey(cli, master.DefaultAddrPath, ch)
c := master.NewClient(ch, bufSize)
return add(c)
}
//export paddle_new_master_client
//
// bufSize is the record buffer size.
func paddle_new_master_client(addr *C.char, bufSize int) C.paddle_master_client {
a := C.GoString(addr)
ch := make(chan string, 1)
ch <- a
c := master.NewClient(ch, bufSize)
c, err := master.NewClient(master.WithAddr(a), master.WithBuffer(bufSize))
if err != nil {
panic(err)
}
return add(c)
}
@ -117,9 +119,10 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
return C.PADDLE_MASTER_OK
}
// return value:
// 0:ok
// -1:error
// paddle_next_record gets the nexts training record.
//
// returns number of bytes of the records if success, -1 if failed.
//
//export paddle_next_record
func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
c := get(client)
@ -143,6 +146,29 @@ func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
return C.int(size)
}
// paddle_request_save_model requests the master server to approve the
// caller to save the model.
//
// returns 1 if the save the model request is approved, 0 if the
// request is rejected because other trainer is saving the model, -1
// if error happened.
//
//export paddle_request_save_model
func paddle_request_save_model(client C.paddle_master_client, trainerID string, blockMS int) C.int {
c := get(client)
need, err := c.RequestSaveModel(trainerID, time.Duration(blockMS)*time.Millisecond)
if err != nil {
log.Errorln(err)
return C.PADDLE_MASTER_ERROR
}
if need {
return C.PADDLE_SAVE_MODEL_OK
}
return C.PADDLE_SAVE_MODEL_SKIP
}
//export mem_free
func mem_free(p unsafe.Pointer) {
// "free" may be a better name for this function, but doing so

@ -16,10 +16,12 @@ package master
import (
"os"
"sync"
"time"
"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio"
"github.com/coreos/etcd/clientv3"
log "github.com/sirupsen/logrus"
)
@ -27,6 +29,7 @@ import (
type Client struct {
conn *connection.Conn
ch chan record
initChOnce sync.Once
}
type record struct {
@ -34,24 +37,83 @@ type record struct {
err error
}
// NewClient creates a new Client.
// WithBuffer sets the client to buffer the training record.
//
// bufSize is the record buffer size. NextRecord will read from this
// buffer.
func NewClient(addrCh <-chan string, bufSize int) *Client {
c := &Client{}
c.conn = connection.New()
func WithBuffer(bufSize int) func(*Client) error {
return func(c *Client) error {
if bufSize <= 0 {
return nil
}
c.initChOnce.Do(func() {
c.ch = make(chan record, bufSize)
go c.monitorMaster(addrCh)
go c.getRecords()
return c
})
return nil
}
}
// WithAddr sets the client to use fixed master address.
func WithAddr(addr string) func(c *Client) error {
return func(c *Client) error {
ch := make(chan string, 1)
ch <- addr
go c.monitorMaster(ch)
return nil
}
}
// WithEtcd sets the client to use etcd for master discovery.
func WithEtcd(endpoints []string, timeout time.Duration) func(*Client) error {
return func(c *Client) error {
cli, err := clientv3.New(clientv3.Config{
Endpoints: endpoints,
DialTimeout: timeout,
})
if err != nil {
return err
}
ch := make(chan string, 1)
a, err := GetKey(cli, DefaultAddrPath, timeout)
if err != nil {
return err
}
if a != "" {
// Master is registered, send to the master address
// channel.
ch <- a
}
go watchKey(cli, DefaultAddrPath, ch)
go c.monitorMaster(ch)
return nil
}
}
// NewClient creates a new Client.
func NewClient(opts ...func(*Client) error) (*Client, error) {
c := &Client{}
c.conn = connection.New()
for _, opt := range opts {
err := opt(c)
if err != nil {
return nil, err
}
}
return c, nil
}
func (c *Client) getRecords() {
for {
t, err := c.getTask()
if err != nil {
// getTask call.
log.Errorf("Get task failed, sleep 3 seconds and continue, %s", err)
time.Sleep(3 * time.Second)
continue
@ -146,6 +208,20 @@ func (c *Client) taskFailed(meta TaskMeta) error {
// NextRecord will block until the next record is available. It is
// thread-safe.
func (c *Client) NextRecord() ([]byte, error) {
c.initChOnce.Do(func() {
// initialize with in case WithBuffer is not used.
c.ch = make(chan record, 0)
go c.getRecords()
})
r := <-c.ch
return r.r, r.err
}
// RequestSaveModel requests the master server to approve the caller
// to save the model.
func (c *Client) RequestSaveModel(trainerID string, blockDur time.Duration) (bool, error) {
var need bool
err := c.conn.Call("Service.RequestSaveModel", SaveModelRequest{TrainerID: trainerID, BlockDur: blockDur}, &need)
return need, err
}

@ -87,9 +87,11 @@ func TestNextRecord(t *testing.T) {
panic(err)
}
curAddr := make(chan string, 1)
curAddr <- fmt.Sprintf(":%d", p)
c := master.NewClient(curAddr, 10)
c, err := master.NewClient(master.WithAddr(fmt.Sprintf(":%d", p)), master.WithBuffer(10))
if err != nil {
panic(err)
}
err = c.SetDataset([]string{path})
if err != nil {
panic(err)

@ -158,8 +158,8 @@ func (e *EtcdClient) Load() ([]byte, error) {
}
// GetKey gets the value by the specify key.
func GetKey(c *clientv3.Client, key string, timeout int) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout))
func GetKey(c *clientv3.Client, key string, timeout time.Duration) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
resp, err := c.Get(ctx, key)
cancel()
if err != nil {
@ -173,8 +173,8 @@ func GetKey(c *clientv3.Client, key string, timeout int) (string, error) {
return string(v), nil
}
// WatchKey watches the specify key and send to valChan if there is some event.
func WatchKey(c *clientv3.Client, key string, valChan chan<- string) {
// watchKey watches the specify key and send to valChan if there is some event.
func watchKey(c *clientv3.Client, key string, valChan chan<- string) {
rch := c.Watch(context.Background(), key)
for wresp := range rch {
for _, ev := range wresp.Events {

@ -81,6 +81,7 @@ type Service struct {
mu sync.Mutex
initDone bool
taskQueues taskQueues
savingTrainer string
}
func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
@ -246,7 +247,7 @@ func readChunks(globPaths []string) ([]Chunk, error) {
//
// SetDataset can be call multiple times. But only the first call will
// be honored.
func (s *Service) SetDataset(globPaths []string, dummy *int) error {
func (s *Service) SetDataset(globPaths []string, _ *int) error {
if len(globPaths) == 0 {
return errors.New("no dataset specified")
}
@ -330,7 +331,7 @@ func (s *Service) logFields() log.Fields {
}
// GetTask gets a new task from the service.
func (s *Service) GetTask(dummy int, task *Task) error {
func (s *Service) GetTask(_ int, task *Task) error {
select {
case <-s.ready:
}
@ -380,7 +381,7 @@ func (s *Service) GetTask(dummy int, task *Task) error {
}
// TaskFinished tell the service that a task is finished.
func (s *Service) TaskFinished(taskID int, dummy *int) error {
func (s *Service) TaskFinished(taskID int, _ *int) error {
select {
case <-s.ready:
}
@ -415,7 +416,7 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
}
// TaskFailed tells the service that a task is failed.
func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
func (s *Service) TaskFailed(meta TaskMeta, _ *int) error {
select {
case <-s.ready:
}
@ -432,3 +433,42 @@ func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
s.processFailedTask(t, meta.Epoch)
return nil
}
// SaveModelRequest is the request for saving model
type SaveModelRequest struct {
TrainerID string
BlockDur time.Duration
}
// RequestSaveModel requests the master server to approve the caller
// to save the model.
func (s *Service) RequestSaveModel(req SaveModelRequest, need *bool) error {
s.mu.Lock()
defer s.mu.Unlock()
if req.TrainerID == "" {
return errors.New("trainer id is empty")
}
if s.savingTrainer == "" {
*need = true
} else {
if req.TrainerID == s.savingTrainer {
// save trainer asked to save model again
*need = true
} else {
*need = false
}
}
if *need {
s.savingTrainer = req.TrainerID
time.AfterFunc(req.BlockDur, func() {
s.mu.Lock()
s.savingTrainer = ""
s.mu.Unlock()
})
}
return nil
}

@ -127,13 +127,19 @@ func paddle_pserver_client_release(client C.paddle_pserver_client) {
remove(client)
}
// paddle_begin_init_params tells trainer if it needs to init the
// parameters.
//
// returns 1 if the trainer needs to init the parameters. 0 if the
// trainer does not need to init the parameters.
//
//export paddle_begin_init_params
func paddle_begin_init_params(client C.paddle_pserver_client) C.int {
c := get(client)
if selected := c.BeginInitParams(); selected {
return 1
}
return C.PSERVER_OK
return 0
}
//export paddle_init_param
@ -256,17 +262,4 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
return C.PSERVER_OK
}
//export paddle_save_model
func paddle_save_model(client C.paddle_pserver_client, path *C.char) C.int {
p := C.GoString(path)
c := get(client)
err := c.Save(p)
if err != nil {
log.Errorln(err)
return C.PSERVER_ERROR
}
return C.PSERVER_OK
}
func main() {} // Required but ignored

@ -111,9 +111,5 @@ retry:
getParams(c);
}
if (paddle_save_model(c, "/tmp/")) {
fail();
}
return 0;
}

@ -219,32 +219,6 @@ func (c *Client) GetParams(names []string) ([]pserver.Parameter, error) {
return ps, nil
}
// Save indicates parameters to save the parameter to the given path.
func (c *Client) Save(path string) error {
errCh := make(chan error, len(c.pservers))
for _, p := range c.pservers {
err := p.Call("Service.Save", path, nil)
errCh <- err
}
recv := 0
for err := range errCh {
if err != nil {
return err
}
recv++
if recv == len(c.pservers) {
break
}
}
// TODO(helin): there will be many files under path, need to
// merge them into a single file.
return nil
}
func strHash(s string) uint32 {
h := fnv.New32a()
_, _ = h.Write([]byte(s))

@ -36,6 +36,10 @@ import (
// ElementType is the type of elements of a Parameter.
type ElementType int
// ErrCheckpointNotFound indicates that the pserver checkpoint could
// not be found.
var ErrCheckpointNotFound = errors.New("checkpoint not found")
// RPC error message.
const (
AlreadyInitialized = "pserver already initialized"
@ -103,6 +107,10 @@ func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (Checkpoint, e
return nil, err
}
if len(v) == 0 {
return nil, ErrCheckpointNotFound
}
var cpMeta checkpointMeta
if err = json.Unmarshal(v, &cpMeta); err != nil {
return nil, err
@ -156,7 +164,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient
}
// InitParam initializes a parameter.
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error {
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, _ *int) error {
select {
case <-s.initialized:
return errors.New(AlreadyInitialized)
@ -177,7 +185,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er
// FinishInitParams tells the parameter server that the parameter
// initialization has finished.
func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
func (s *Service) FinishInitParams(_ int, _ *int) error {
select {
case <-s.initialized:
return errors.New(AlreadyInitialized)
@ -190,7 +198,7 @@ func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
// SendGrad sends gradient to parameter servers for parameter
// optimization.
func (s *Service) SendGrad(g Gradient, dummy *int) error {
func (s *Service) SendGrad(g Gradient, _ *int) error {
select {
case <-s.initialized:
default:

@ -19,8 +19,10 @@ cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator)
cc_library(grad_op_creator SRCS grad_op_creator.cc DEPS op_proto operator)
cc_library(op_registry SRCS op_registry.cc DEPS op_desc grad_op_creator)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
cc_test(grad_op_creator_test SRCS grad_op_creator_test.cc DEPS grad_op_creator op_registry add_op)
py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)
# Generate an empty __init__.py to make framework_py_proto as a valid python module.
@ -28,5 +30,6 @@ add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch
add_dependencies(framework_py_proto framework_py_proto_init)
proto_library(net_proto SRCS net_proto.proto DEPS op_proto)
# cc_library(net SRCS net.cc DEPS operator net_proto op_registry fc_op)
cc_library(net SRCS net.cc DEPS operator net_proto op_registry)
cc_test(net_op_test SRCS net_op_test.cc DEPS net)
cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op)

@ -0,0 +1,115 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/grad_op_creator.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace framework {
OperatorBase* GradOpCreator::Create() {
BuildOpInOutArgList();
OperatorBase* grad_op = OpRegistry::grad_creators().at(op_->type_)();
CompleteGradOp(grad_op);
return grad_op;
}
OpInOutArg* GradOpCreator::BuildArg(const VarProto& var,
const VarIndexMap& var_map,
const std::vector<int>& format,
InOutType type) {
int idx = var_map.at(var.name());
int begin_idx = format.empty() ? idx : format.at(idx);
int end_idx = format.empty() ? idx + 1 : format.at(idx + 1);
return new OpInOutArg(var.name(), type, !var.ignore_gradient(), begin_idx,
end_idx);
}
void GradOpCreator::BuildOpInOutArgList() {
const OpProto& op_proto = OpRegistry::protos().at(op_->type_);
const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_->type_));
const std::vector<int>& in_format =
op_->attrs_.count("input_format")
? op_->GetAttr<std::vector<int>>("input_format")
: std::vector<int>();
const std::vector<int>& out_format =
op_->attrs_.count("output_format")
? op_->GetAttr<std::vector<int>>("output_format")
: std::vector<int>();
for (const auto& var : op_proto.inputs()) {
arg_list_.emplace_back(
std::shared_ptr<OpInOutArg>(BuildArg(var, var_map, in_format, IN)));
}
for (const auto& var : op_proto.outputs()) {
arg_list_.emplace_back(
std::shared_ptr<OpInOutArg>(BuildArg(var, var_map, out_format, OUT)));
}
}
void GradOpCreator::AddArgIntoGradOp(const OpInOutArg* arg,
std::vector<std::string>& in_out,
std::vector<int>& format,
VarIndexMap* varmap, int& idx,
bool is_grad) const {
std::string var_name = arg->proto_name_;
if (is_grad) {
var_name += OperatorBase::GRAD_VAR_SUFFIX();
}
(*varmap)[var_name] = idx++;
size_t pre_sz = in_out.size();
auto base_it =
arg->type_ == IN ? op_->inputs_.begin() : op_->outputs_.begin();
std::copy(base_it + arg->begin_idx_, base_it + arg->end_idx_,
std::back_inserter(in_out));
if (is_grad) {
for (size_t i = pre_sz; i < in_out.size(); ++i) {
in_out[i] += OperatorBase::GRAD_VAR_SUFFIX();
}
}
format.push_back(in_out.size());
}
void GradOpCreator::CompleteGradOp(OperatorBase* grad_op) const {
grad_op->type_ = op_->type_ + "@GRAD"; // not necessary
grad_op->attrs_ = op_->attrs_;
grad_op->attrs_.erase("input_format");
grad_op->attrs_.erase("output_format");
VarIndexMap* grad_varmap = new VarIndexMap();
int in_idx = 0;
int out_idx = 0;
std::vector<int> in_format({0});
std::vector<int> out_format({0});
for (const auto& arg : arg_list_) {
// op_'s inputs_ and outputs_
if (arg->needed_in_grad_) {
AddArgIntoGradOp(arg.get(), grad_op->inputs_, in_format, grad_varmap,
in_idx, false);
}
if (arg->type_ == IN) {
// gradients of op_'s inputs_
AddArgIntoGradOp(arg.get(), grad_op->outputs_, out_format, grad_varmap,
out_idx, true);
} else {
// gradients of op_'s outputs_
AddArgIntoGradOp(arg.get(), grad_op->inputs_, in_format, grad_varmap,
in_idx, true);
}
}
grad_op->attrs_["input_format"] = in_format;
grad_op->attrs_["output_format"] = out_format;
grad_op->in_out_idxs_.reset(grad_varmap);
}
} // namespace framework
} // namespace paddle

@ -0,0 +1,48 @@
#pragma once
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace framework {
class OpRegistry;
enum InOutType { IN, OUT };
struct OpInOutArg {
OpInOutArg(const std::string& proto_name, const InOutType& type,
bool needed_in_grad, size_t begin_idx, size_t end_idx)
: proto_name_(proto_name),
type_(type),
needed_in_grad_(needed_in_grad),
begin_idx_(begin_idx),
end_idx_(end_idx) {}
std::string proto_name_;
InOutType type_;
bool needed_in_grad_;
size_t begin_idx_;
size_t end_idx_;
};
class GradOpCreator {
using VarIndexMap = std::unordered_map<std::string, int>;
public:
GradOpCreator(const OperatorBase* op) : op_(op) {}
OperatorBase* Create();
private:
OpInOutArg* BuildArg(const VarProto& var, const VarIndexMap& var_map,
const std::vector<int>& format, InOutType type);
void BuildOpInOutArgList();
void AddArgIntoGradOp(const OpInOutArg* arg, std::vector<std::string>& in_out,
std::vector<int>& format, VarIndexMap* varmap, int& idx,
bool is_grad) const;
void CompleteGradOp(OperatorBase* grad_op) const;
const OperatorBase* op_;
std::vector<std::shared_ptr<OpInOutArg>> arg_list_;
};
} // namespace framework
} // namespace paddle

@ -0,0 +1,26 @@
#include "paddle/framework/grad_op_creator.h"
#include <gtest/gtest.h>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
USE_OP(add_two);
namespace paddle {
namespace framework {
TEST(GradOpCreator, AddTwo) {
std::shared_ptr<OperatorBase> add_op(
OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {}));
std::shared_ptr<OperatorBase> grad_add_op = OpRegistry::CreateGradOp(add_op);
EXPECT_EQ(static_cast<int>(grad_add_op->inputs_.size()), 4);
EXPECT_EQ(static_cast<int>(grad_add_op->outputs_.size()), 2);
EXPECT_EQ(grad_add_op->Input("X"), "x");
EXPECT_EQ(grad_add_op->Input("Y"), "y");
EXPECT_EQ(grad_add_op->Input("Out"), "out");
EXPECT_EQ(grad_add_op->Input("Out@GRAD"), "out@GRAD");
EXPECT_EQ(grad_add_op->Output("X@GRAD"), "x@GRAD");
EXPECT_EQ(grad_add_op->Output("Y@GRAD"), "y@GRAD");
}
} // namespace framework
} // namespace paddle

@ -15,14 +15,24 @@
*/
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace framework {
std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps) {
auto grad_ops = std::make_shared<PlainNet>();
for (auto& op : ForwardOps->ops_) {
auto op_grad = OpRegistry::CreateGradOp(op);
grad_ops->AddOp(op_grad);
}
grad_ops->CompleteAddOp();
return grad_ops;
}
void PlainNet::CompleteAddOp(bool calc) {
add_op_done_ = true;
if (!calc) return;
std::unordered_set<std::string> input_set;
std::unordered_set<std::string> output_set;
std::unordered_set<std::string> temp_output;

@ -100,5 +100,7 @@ class PlainNet : public Net {
}
};
std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps);
} // namespace framework
} // namespace paddle

@ -3,17 +3,24 @@
#include <paddle/framework/op_registry.h>
#include <paddle/framework/operator.h>
namespace pd = paddle::framework;
USE_OP(add_two);
USE_OP(mul);
USE_OP(sigmoid);
USE_OP(softmax);
namespace paddle {
namespace framework {
static int infer_shape_cnt = 0;
static int run_cnt = 0;
class TestOp : public pd::OperatorBase {
class TestOp : public OperatorBase {
public:
void InferShape(const std::shared_ptr<pd::Scope>& scope) const override {
void InferShape(
const std::shared_ptr<framework::Scope>& scope) const override {
++infer_shape_cnt;
}
void Run(const std::shared_ptr<pd::Scope>& scope,
void Run(const std::shared_ptr<framework::Scope>& scope,
const paddle::platform::DeviceContext& dev_ctx) const override {
++run_cnt;
}
@ -33,7 +40,7 @@ void AssertSameVectorWithoutOrder(const std::vector<T>& expected,
}
TEST(OpKernel, all) {
auto net = std::make_shared<paddle::framework::PlainNet>();
auto net = std::make_shared<PlainNet>();
ASSERT_NE(net, nullptr);
auto op1 = std::make_shared<TestOp>();
@ -55,13 +62,37 @@ TEST(OpKernel, all) {
ASSERT_EQ(1UL, tmp_idx.size());
ASSERT_EQ("y", net->outputs_[tmp_idx[0]]);
auto scope = std::make_shared<pd::Scope>();
paddle::platform::CPUDeviceContext dev_ctx;
auto scope = std::make_shared<Scope>();
platform::CPUDeviceContext dev_ctx;
net->InferShape(scope);
net->Run(scope, dev_ctx);
ASSERT_EQ(2, infer_shape_cnt);
ASSERT_EQ(2, run_cnt);
ASSERT_THROW(net->AddOp(op2), std::runtime_error);
}
TEST(AddBackwardOp, TestGradOp) {
auto net = std::make_shared<PlainNet>();
ASSERT_NE(net, nullptr);
net->AddOp(framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {}));
net->AddOp(
framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {}));
net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""}, {}));
auto grad_ops = AddBackwardOp(net);
for (auto& op : grad_ops->ops_) {
op->DebugString();
}
}
// TODO(zhihong): add fc grad without registering.
// TEST(AddBackwardOp, TestNoGradOp) {
// auto net = std::make_shared<PlainNet>();
// ASSERT_NE(net, nullptr);
// net->AddOp(framework::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Y"},
// {})); auto grad_ops = AddBackwardOp(net); for (auto& op : grad_ops->ops_) {
// op->DebugString();
// }
// }
} // namespace framework
} // namespace paddle

@ -1,24 +0,0 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
#include <gtest/gtest.h>
namespace paddle {
namespace framework {
class FakeFC : public Operator {}
} // namespace framework
} // namespace paddle

@ -84,6 +84,11 @@ message VarProto {
// "temporary_index": [1]
// }
optional bool temporary = 4 [default=false];
// The gradient of operator can be ignored immediately
// e.g. operator AddOp, y = x1 + x2, the gradient of dy/dx1, dy/dx2
// can be ignored for the future optimized on graph.
optional bool ignore_gradient = 6;
}
// Op protocol message for 3rd-party language binding.
@ -105,4 +110,5 @@ message OpProto {
// The type of that Op.
required string type = 5;
}

@ -1,3 +1,17 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
@ -6,9 +20,9 @@
#include <unordered_map>
#include <unordered_set>
#include "paddle/framework/attr_checker.h"
#include "paddle/framework/grad_op_creator.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
namespace paddle {
namespace framework {
@ -73,25 +87,29 @@ class OpProtoAndCheckerMaker {
protected:
void AddInput(const std::string& name, const std::string& comment,
bool multiple = false) {
bool multiple = false, bool ignore_gradient = false) {
auto input = proto_->mutable_inputs()->Add();
*input->mutable_name() = name;
*input->mutable_comment() = comment;
input->set_ignore_gradient(ignore_gradient);
input->set_multiple(multiple);
if (multiple) {
SetHasMultipleInput();
}
}
void AddInputs(const std::string& name, const std::string& comment) {
AddInput(name, comment, true);
void AddInputs(const std::string& name, const std::string& comment,
bool ignore_gradient = false) {
AddInput(name, comment, true, ignore_gradient);
}
void AddOutput(const std::string& name, const std::string& comment,
bool temporary = false, bool multiple = false) {
bool temporary = false, bool multiple = false,
bool ignore_gradient = false) {
auto output = proto_->mutable_outputs()->Add();
*output->mutable_name() = name;
*output->mutable_comment() = comment;
output->set_ignore_gradient(ignore_gradient);
output->set_multiple(multiple);
if (multiple) {
SetHasMultipleOutput();
@ -103,8 +121,8 @@ class OpProtoAndCheckerMaker {
}
void AddOutputs(const std::string& name, const std::string& comment,
bool temporary = false) {
AddOutput(name, comment, temporary, true);
bool temporary = false, bool ignore_gradient = false) {
AddOutput(name, comment, temporary, true, ignore_gradient);
}
template <typename T>
@ -205,8 +223,8 @@ class OpRegistry {
template <typename OpType, typename ProtoMakerType>
static void RegisterOp(const std::string& op_type) {
creators()[op_type] = [] { return new OpType; };
OpProto& op_proto = protos()[op_type];
OpAttrChecker& op_checker = op_checkers()[op_type];
OpProto& op_proto = protos()[op_type];
auto maker = ProtoMakerType(&op_proto, &op_checker);
maker.Validate();
*op_proto.mutable_type() = op_type;
@ -227,18 +245,24 @@ class OpRegistry {
}
}
template <typename OpType>
static void RegisterGradOp(const std::string& op_type) {
grad_creators()[op_type] = [] { return new OpType; };
}
static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
const VarNameList& inputs,
const VarNameList& outputs,
const AttributeMap& attrs) {
auto op_create_it = creators().find(type);
PADDLE_ENFORCE(op_create_it != creators().end(),
"Operator %s cannot be found", type);
"Operator %s cannot be found.", type);
auto op = op_create_it->second();
op->type_ = type;
op->inputs_ = inputs;
op->outputs_ = outputs;
op->attrs_ = attrs;
op_checkers().at(type).Check(op->attrs_);
@ -274,18 +298,41 @@ class OpRegistry {
return CreateOp(op_desc.type(), inputs, outputs, attrs);
}
static std::shared_ptr<OperatorBase> CreateGradOp(
std::shared_ptr<OperatorBase> op) {
GradOpCreator creator(op.get());
std::shared_ptr<OperatorBase> grad_op(creator.Create());
grad_op->Init();
return grad_op;
}
static std::unordered_map<std::string, OpProto>& protos() {
static std::unordered_map<std::string, OpProto> protos_;
return protos_;
};
private:
static std::unordered_map<std::string, OpCreator>& grad_creators() {
static std::unordered_map<std::string, OpCreator> grad_creators_;
return grad_creators_;
}
static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>>&
VarIndexMaps() {
static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>> maps_;
return maps_;
}
private:
static std::unordered_map<std::string, OpCreator>& creators() {
static std::unordered_map<std::string, OpCreator> creators_;
return creators_;
}
static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
return op_checkers_;
};
static void GenerateTempVariableName(OperatorBase* op) {
static std::atomic<size_t> gUniqId(0UL);
for (auto& outname : op->outputs_) {
@ -296,16 +343,6 @@ class OpRegistry {
}
}
}
static std::unordered_map<std::string, OpCreator>& creators() {
static std::unordered_map<std::string, OpCreator> creators_;
return creators_;
}
static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
return op_checkers_;
};
};
template <typename OpType, typename ProtoMakerType>
@ -316,6 +353,14 @@ class OpRegisterHelper {
}
};
template <typename OpType>
class GradOpRegisterHelper {
public:
GradOpRegisterHelper(const char* op_type) {
OpRegistry::RegisterGradOp<OpType>(op_type);
}
};
/**
* check if MACRO is used in GLOBAL NAMESPACE.
*/
@ -335,6 +380,17 @@ class OpRegisterHelper {
__op_register_##__op_type##__(#__op_type); \
int __op_register_##__op_type##_handle__() { return 0; }
/**
* Macro to Register Gradient Operator.
*/
#define REGISTER_GRADIENT_OP(__op_type, __op_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_gradient_op__##__op_type, \
"REGISTER_GRADIENT_OP must be in global namespace"); \
static ::paddle::framework::GradOpRegisterHelper<__op_class> \
__op_gradient_register_##__op_type##__(#__op_type); \
int __op_gradient_register_##__op_type##_handle__() { return 0; }
/**
* Macro to Register OperatorKernel.
*/

@ -62,6 +62,11 @@ class OperatorBase {
/// but it will be convert to a unique name in scope after OpCreator.
static std::string TMP_VAR_NAME() { return "@TEMP@"; }
/// If a variable's name has a certain suffix, it means that the
/// variable is the gradient of another varibale.
/// e.g. Variable "x@GRAD" is the gradient of varibale "x".
static std::string GRAD_VAR_SUFFIX() { return "@GRAD"; }
virtual ~OperatorBase() {}
template <typename T>

@ -49,9 +49,22 @@ The equation is: Out = X + Y
)DOC");
}
};
class AddOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override {}
std::string DebugString() const override {
LOG(INFO) << "AddOpGrad";
return "";
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker);
REGISTER_GRADIENT_OP(add_two, paddle::operators::AddOpGrad);
REGISTER_OP_CPU_KERNEL(
add_two, paddle::operators::AddKernel<paddle::platform::CPUPlace, float>);

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save