Merge remote-tracking branch 'baidu/develop' into tensor_to_EigenTensor

cblas_new
qijun 8 years ago
commit 71e2a94310

3
.gitignore vendored

@ -22,3 +22,6 @@ cmake-build-*
# generated while compiling
python/paddle/v2/framework/core.so
CMakeFiles
cmake_install.cmake

@ -140,6 +140,10 @@ endif(USE_NNPACK)
add_subdirectory(proto)
# "add_subdirectory(go)" should be placed after the following loine,
# because it depends on paddle/optimizer.
add_subdirectory(paddle/optimizer)
# "add_subdirectory(paddle)" and "add_subdirectory(python)" should be
# placed after this block, because they depends on it.
if(WITH_GOLANG)

@ -93,7 +93,7 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR})
if(NOT APPLE)
find_package(Threads REQUIRED)
link_libraries(${CMAKE_THREAD_LIBS_INIT})
set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -ldl")
set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -ldl -lrt")
endif(NOT APPLE)
function(merge_static_libs TARGET_NAME)
@ -301,7 +301,7 @@ function(go_library TARGET_NAME)
file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go")
string(REPLACE "${PADDLE_GO_PATH}/" "" CMAKE_CURRENT_SOURCE_REL_DIR ${CMAKE_CURRENT_SOURCE_DIR})
# FIXME: link path
add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
COMMAND rm "${${TARGET_NAME}_LIB_PATH}"
# Golang build source code
@ -309,7 +309,7 @@ function(go_library TARGET_NAME)
-o "${${TARGET_NAME}_LIB_PATH}"
"./${CMAKE_CURRENT_SOURCE_REL_DIR}/${GO_SOURCE}"
# must run under GOPATH
WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go")
WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go")
add_dependencies(${TARGET_NAME} go_vendor)
endfunction(go_library)
@ -320,14 +320,11 @@ function(go_binary TARGET_NAME)
cmake_parse_arguments(go_binary "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
string(REPLACE "${PADDLE_GO_PATH}/" "" CMAKE_CURRENT_SOURCE_REL_DIR ${CMAKE_CURRENT_SOURCE_DIR})
# FIXME: link path
add_custom_command(OUTPUT ${TARGET_NAME}_timestamp
COMMAND env LIBRARY_PATH=${CMAKE_BINARY_DIR}/go/pserver/client/c/:$ENV{LIBRARY_PATH}
GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build
-o "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}"
"./${CMAKE_CURRENT_SOURCE_REL_DIR}/${go_binary_SRCS}"
WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go")
# TODO: don't know what ${TARGET_NAME}_link does
add_custom_target(${TARGET_NAME} ALL DEPENDS go_vendor ${TARGET_NAME}_timestamp ${go_binary_DEPS})
install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME} DESTINATION bin)
endfunction(go_binary)
@ -335,15 +332,18 @@ endfunction(go_binary)
function(go_test TARGET_NAME)
set(options OPTIONAL)
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS)
set(multiValueArgs DEPS)
cmake_parse_arguments(go_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_custom_command(OUTPUT ${TARGET_NAME}_timestamp
string(REPLACE "${PADDLE_GO_PATH}" "" CMAKE_CURRENT_SOURCE_REL_DIR ${CMAKE_CURRENT_SOURCE_DIR})
add_custom_target(${TARGET_NAME} ALL DEPENDS go_vendor ${go_test_DEPS})
add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} test
-c -o "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}"
${go_test_SRCS}
".${CMAKE_CURRENT_SOURCE_REL_DIR}"
WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go")
add_test(NAME ${TARGET_NAME}
COMMAND ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
add_custom_target(${TARGET_NAME} ALL DEPENDS ${TARGET_NAME}_timestamp ${go_test_DEPS})
add_test(${TARGET_NAME} ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME})
endfunction(go_test)
function(proto_library TARGET_NAME)

@ -17,3 +17,7 @@ add_subdirectory(pserver/client/c)
add_subdirectory(cmd/pserver)
add_subdirectory(cmd/master)
add_subdirectory(master/c)
add_subdirectory(master)
add_subdirectory(pserver)
add_subdirectory(pserver/client)
add_subdirectory(utils/networkhelper)

@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
go_binary(master SRC master.go DEPS paddle_go_optimizer)
go_binary(master SRC master.go)

@ -8,6 +8,7 @@ import (
"time"
"github.com/namsral/flag"
"github.com/topicai/candy"
"github.com/PaddlePaddle/Paddle/go/pserver"
log "github.com/sirupsen/logrus"
@ -18,53 +19,47 @@ func main() {
index := flag.Int("index", -1, "index of this pserver, should be larger or equal than 0")
etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
"comma separated endpoint string for pserver to connect to etcd")
etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls")
etcdTimeout := flag.Duration("etcd-timeout", 5*time.Second, "timeout for etcd calls")
numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job")
checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path")
checkpointInterval := flag.Int("checkpoint-interval", 600, "save checkpoint per interval seconds")
checkpointInterval := flag.Duration("checkpoint-interval", 600*time.Second, "save checkpoint per interval seconds")
logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warning, error, fatal, panic")
flag.Parse()
level, err := log.ParseLevel(*logLevel)
if err != nil {
panic(err)
}
candy.Must(err)
log.SetLevel(level)
var idx int
var cp pserver.Checkpoint
var cp *pserver.Checkpoint
var e *pserver.EtcdClient
if *index >= 0 {
idx = *index
} else {
timeout := time.Second * time.Duration((*etcdTimeout))
e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout)
e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, *etcdTimeout)
idx, err = e.Register()
candy.Must(err)
cp, err = pserver.NewCheckpointFromFile(*checkpointPath, idx, e)
if err != nil {
panic(err)
log.Errorf("Fetch checkpoint failed, %s", err)
}
}
s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp)
if err != nil {
panic(err)
}
candy.Must(err)
err = rpc.Register(s)
if err != nil {
panic(err)
}
candy.Must(err)
rpc.HandleHTTP()
l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
if err != nil {
panic(err)
}
candy.Must(err)
log.Infof("start pserver at port %d", *port)
err = http.Serve(l, nil)
if err != nil {
panic(err)
}
candy.Must(err)
}

14
go/glide.lock generated

@ -1,8 +1,8 @@
hash: b8f18ce6784bd3fadd9fed0b8443e7b658234ea785ae1f220723ae2c1f652aa7
updated: 2017-06-27T14:05:48.925262819+08:00
hash: a8faea3a363468a88917ddeb3b1c9ea36886fb2c622acbad42604fa9cb4d3855
updated: 2017-07-11T10:04:40.786745417+08:00
imports:
- name: github.com/coreos/etcd
version: 61fc123e7a8b14a0a258aa3f5c4159861b1ec2e7
version: cb2a496c4ddd1c87a9f280e116649b599999ec79
subpackages:
- auth/authpb
- clientv3
@ -22,7 +22,9 @@ imports:
- name: github.com/PaddlePaddle/recordio
version: edfb82af0739c84f241c87390ec5649c7b28c129
- name: github.com/sirupsen/logrus
version: 202f25545ea4cf9b191ff7f846df5d87c9382c2b
version: 7f976d3a76720c4c27af2ba716b85d2e0a7e38b1
- name: github.com/topicai/candy
version: 1b9030d056fa9f8c4b1f9c91b52fe4b8ab4cd8cc
- name: golang.org/x/net
version: c8c74377599bd978aee1cf3b9b63a8634051cec2
subpackages:
@ -34,11 +36,11 @@ imports:
- lex/httplex
- trace
- name: golang.org/x/sys
version: f7928cfef4d09d1b080aa2b6fd3ca9ba1567c733
version: abf9c25f54453410d0c6668e519582a9e1115027
subpackages:
- unix
- name: golang.org/x/text
version: 4e9ab9ee170f2a39bd66c92b3e0a47ff47a4bc77
version: cfdf022e86b4ecfb646e1efbd7db175dd623a8fa
subpackages:
- secure/bidirule
- transform

@ -10,3 +10,4 @@ import:
version: ^1.7.4-pre
- package: github.com/sirupsen/logrus
version: ^1.0.0
- package: github.com/topicai/candy

@ -0,0 +1,3 @@
if(WITH_TESTING)
go_test(master_test)
endif()

@ -0,0 +1,3 @@
if(WITH_TESTING)
go_test(pserver_test DEPS paddle_go_optimizer)
endif()

@ -0,0 +1,3 @@
if(WITH_TESTING)
go_test(pserver_client_test DEPS paddle_go_optimizer)
endif()

@ -0,0 +1 @@
libpaddle_go_optimizer.a

@ -1,5 +1,13 @@
cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf)
target_link_libraries(paddle_go_optimizer stdc++ m)
# Copy library to the required place.
# See: go/pserver/optimizer.go:
# // #cgo LDFLAGS: ${SRCDIR}/client/c/libpaddle_go_optimizer.a -lstdc++ -lm
add_custom_command(TARGET paddle_go_optimizer POST_BUILD
COMMAND cp "${CMAKE_CURRENT_BINARY_DIR}/libpaddle_go_optimizer.a" "${CMAKE_CURRENT_SOURCE_DIR}"
)
go_library(paddle_pserver_cclient STATIC DEPS paddle_go_optimizer)
if(WITH_TESTING)
# FIXME: this test requires pserver which is not managed by the test

@ -16,7 +16,7 @@ import (
const (
// PsDesired is etcd path for store desired pserver count
PsDesired = "/ps_desired"
// PsAddr is the base dir for pserver to store their addr
// PsPath is the base dir for pserver to store their addr
PsPath = "/ps/"
// PsCheckpoint is the etcd path for store checkpoints information
PsCheckpoint = "/checkpoints/"
@ -189,9 +189,25 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
return idx, nil
}
// GetKey gets the value by the specified key
func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
resp, err := e.etcdClient.Get(ctx, key)
cancel()
if err != nil {
return []byte{}, err
}
kvs := resp.Kvs
if len(kvs) == 0 {
return []byte{}, nil
}
v := kvs[0].Value
return v, nil
}
// PutKey put into etcd with value by key specified
func (e *EtcdClient) PutKey(key string, value []byte, timeout int) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout))
func (e *EtcdClient) PutKey(key string, value []byte, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
_, err := e.etcdClient.Put(ctx, key, string(value))
cancel()
if err != nil {

@ -1,8 +1,7 @@
package pserver
// #cgo CFLAGS: -I ../../
// //FIXME: ldflags contain "build" path
// #cgo LDFLAGS: ${SRCDIR}/../../build/go/pserver/client/c/libpaddle_go_optimizer.a -lstdc++ -lm
// #cgo LDFLAGS: ${SRCDIR}/client/c/libpaddle_go_optimizer.a -lstdc++ -lm
// #include "paddle/optimizer/optimizer.h"
// #include <stdlib.h>
// #include <string.h>

@ -9,6 +9,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"strconv"
@ -21,14 +22,14 @@ import (
// ElementType is the type of elements of a Parameter.
type ElementType int
// RPC error message.
const (
// AlreadyInitialized is true if pserver is initialized
AlreadyInitialized = "pserver already initialized"
// Uninitialized is true if pserver not fully initialized
Uninitialized = "pserver not fully initialized"
AlreadyInitialized = "pserver already initialized"
Uninitialized = "pserver not fully initialized"
CheckpointMD5Failed = "checkpoint file MD5 validation failed"
)
// Supported element types
// Supported element types.
const (
Int32 ElementType = iota
UInt32
@ -51,21 +52,15 @@ type ParameterWithConfig struct {
Config []byte // parameter configuration in Proto Buffer format
}
// ParameterCheckpoint is Parameter and State checkpoint
type ParameterCheckpoint struct {
ParamConfig ParameterWithConfig
State []byte
}
// checkpoint signature
// checkpointMeta saves checkpoint metadata
type checkpointMeta struct {
UUID string `json:"uuid"`
Md5sum string `json:"md5sum"`
Timestamp string `json:"timestamp"`
MD5 string `json:"md5"`
Timestamp int64 `json:"timestamp"`
}
// Checkpoint is the pserver shard persist in file
type Checkpoint []ParameterCheckpoint
type Checkpoint []parameterCheckpoint
// Gradient is the gradient of the parameter.
type Gradient Parameter
@ -81,12 +76,53 @@ type Service struct {
optMap map[string]*optimizer
}
// parameterCheckpoint saves parameter checkpoint
type parameterCheckpoint struct {
ParameterWithConfig
State []byte
}
// NewCheckpointFromFile loads parameters and state from checkpoint file
func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (*Checkpoint, error) {
v, err := e.GetKey(PsPath+string(idx), 3*time.Second)
if err != nil {
return nil, err
}
var cpMeta checkpointMeta
if err = json.Unmarshal(v, &cpMeta); err != nil {
return nil, err
}
fn := filepath.Join(cpPath, cpMeta.UUID)
if _, err = os.Stat(fn); os.IsNotExist(err) {
return nil, err
}
content, err := ioutil.ReadFile(fn)
if err != nil {
return nil, err
}
h := md5.New()
md5 := hex.EncodeToString(h.Sum(content))
if md5 != cpMeta.MD5 {
return nil, errors.New(CheckpointMD5Failed)
}
dec := gob.NewDecoder(bytes.NewReader(content))
cp := &Checkpoint{}
if err = dec.Decode(cp); err != nil {
return nil, err
}
return cp, nil
}
// NewService creates a new service, will bypass etcd registration if no
// endpoints specified.
func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkpoint) (*Service, error) {
// endpoints specified. It will recovery from checkpoint file if a exists a specified checkpoint.
func NewService(idx int, interval time.Duration, path string, client *EtcdClient, cp *Checkpoint) (*Service, error) {
s := &Service{
idx: idx,
checkpointInterval: time.Second * time.Duration(seconds),
checkpointInterval: interval,
checkpointPath: path,
client: client,
}
@ -94,10 +130,12 @@ func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkp
s.initialized = make(chan struct{})
if cp != nil {
for _, item := range cp {
p := item.ParamConfig
st := item.State
s.optMap[p.Param.Name] = newOptimizer(p, st)
for _, item := range *cp {
p := ParameterWithConfig{
Param: item.Param,
Config: item.Config,
}
s.optMap[p.Param.Name] = newOptimizer(p, item.State)
}
}
return s, nil
@ -186,13 +224,13 @@ func (s *Service) doCheckpoint() error {
s.mu.Lock()
defer s.mu.Unlock()
cp := make([]ParameterCheckpoint, 0, len(s.optMap))
cp := make([]parameterCheckpoint, len(s.optMap))
index := 0
for name, opt := range s.optMap {
var pc ParameterCheckpoint
pc.ParamConfig.Param.Name = name
pc.ParamConfig.Param.ElementType = opt.elementType
pc.ParamConfig.Param.Content = opt.GetWeights()
var pc parameterCheckpoint
pc.Param.Name = name
pc.Param.ElementType = opt.elementType
pc.Param.Content = opt.GetWeights()
pc.State = opt.GetStates()
cp[index] = pc
index++
@ -206,12 +244,12 @@ func (s *Service) doCheckpoint() error {
cpMeta := checkpointMeta{}
cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx)
cpMeta.Timestamp = time.Now().String()
cpMeta.Timestamp = time.Now().UnixNano()
h := md5.New()
cpMeta.Md5sum = hex.EncodeToString(h.Sum(buf.Bytes()))
cpMeta.MD5 = hex.EncodeToString(h.Sum(buf.Bytes()))
cpMetajson, _ := json.Marshal(cpMeta)
err = s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3)
err = s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3*time.Second)
if err != nil {
return err
}
@ -219,7 +257,11 @@ func (s *Service) doCheckpoint() error {
log.Info("checkpoint does not exists.")
} else {
err = os.Remove(cpMeta.UUID)
log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID)
if err != nil {
log.Infof("Removing checkpoint %s failed", cpMeta.UUID)
} else {
log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID)
}
}
f, err := os.Create(cpMeta.UUID)
defer f.Close()

@ -0,0 +1,3 @@
if(WITH_TESTING)
go_test(network_helper_test)
endif()

@ -8,14 +8,12 @@ add_subdirectory(gserver)
add_subdirectory(pserver)
add_subdirectory(trainer)
add_subdirectory(scripts)
add_subdirectory(optimizer)
add_subdirectory(string)
if(Boost_FOUND)
add_subdirectory(memory)
add_subdirectory(platform)
add_subdirectory(framework)
add_subdirectory(operators)
add_subdirectory(pybind)
endif()

@ -12,7 +12,7 @@ cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
proto_library(op_desc SRCS op_desc.proto DEPS attr_type)
cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
cc_library(operator SRCS operator.cc DEPS op_desc protobuf)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry place)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator)
py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)

@ -266,29 +266,6 @@ HOSTDEVICE inline bool contained(const Dim<1>& idx, const Dim<1>& size) {
return ((0 <= idx.head) && (idx.head < size.head));
}
/**
* \brief Check if a size and a stride create a Fortran order contiguous
* block of memory.
*/
template <int i>
HOST bool contiguous(const Dim<i>& size, const Dim<i>& stride, int mul = 1) {
if (product(size) == 0) return true;
int contiguous_stride = get<0>(size) == 1 ? 0 : mul;
return (get<0>(stride) == contiguous_stride &&
contiguous(size.tail, stride.tail, mul * get<0>(size)));
}
///\cond HIDDEN
// Base case of contiguous, check the nth stride is the size of
// the prefix multiply of n-1 dims.
template <>
inline bool contiguous(const Dim<1>& size, const Dim<1>& stride, int mul) {
if (get<0>(size) == 0) return true;
int contiguous_stride = get<0>(size) == 1 ? 0 : mul;
return get<0>(stride) == contiguous_stride;
}
///\endcond
/**
* \brief Compute exclusive prefix-multiply of a Dim.
*/
@ -306,31 +283,6 @@ HOSTDEVICE inline Dim<1> ex_prefix_mul(const Dim<1>& src, int mul) {
}
///\endcond
/**
* \brief Calculate strides of a contiguous array of the given size
*
* Sets the stride for any dimension with an extent of 1 to 0.
* \param size Dim object containing the size of the array.
* \param base The base stride to use.
* \return Dim object the same size as \p size with the strides.
*/
template <int i>
HOSTDEVICE Dim<i> contiguous_strides(const Dim<i>& size, int base = 1) {
int stride = size.head == 1 ? 0 : base;
return Dim<i>(stride, contiguous_strides(size.tail, base * size.head));
}
///\cond HIDDEN
// Base case of contiguous_strides
template <>
HOSTDEVICE inline Dim<1> contiguous_strides(const Dim<1>& size, int base) {
int stride = size.head == 1 ? 0 : base;
return Dim<1>(stride);
}
///\endcond
/**
* Add two dimensions together
*/

@ -58,24 +58,6 @@ TEST(Dim, Equality) {
EXPECT_EQ(paddle::framework::get<1>(c), 3);
EXPECT_EQ(paddle::framework::get<2>(c), 12);
// contiguous_strides
c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(10, 1, 10));
EXPECT_EQ(paddle::framework::get<0>(c), 1);
EXPECT_EQ(paddle::framework::get<1>(c), 0);
EXPECT_EQ(paddle::framework::get<2>(c), 10);
c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(10, 10, 1));
EXPECT_EQ(paddle::framework::get<0>(c), 1);
EXPECT_EQ(paddle::framework::get<1>(c), 10);
EXPECT_EQ(paddle::framework::get<2>(c), 0);
c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(1, 10, 10));
EXPECT_EQ(paddle::framework::get<0>(c), 0);
EXPECT_EQ(paddle::framework::get<1>(c), 1);
EXPECT_EQ(paddle::framework::get<2>(c), 10);
c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(2, 3, 4));
EXPECT_EQ(paddle::framework::get<0>(c), 1);
EXPECT_EQ(paddle::framework::get<1>(c), 2);
EXPECT_EQ(paddle::framework::get<2>(c), 6);
// generate from an index
auto size = paddle::framework::make_dim(4, 5, 2);
c = paddle::framework::Dim<3>(14, size);
@ -101,16 +83,6 @@ TEST(Dim, Bool) {
EXPECT_TRUE(a == a);
EXPECT_FALSE(a == b);
EXPECT_TRUE(a == c);
// contiguous check
int x = 4, y = 5, z = 2;
paddle::framework::Dim<3> sizef(x, y, z);
paddle::framework::Dim<3> stridea(1, x, x*y);
paddle::framework::Dim<3> strideb(2, 2*x, 2*x*y);
paddle::framework::Dim<3> stridec(1, x, 2*x*y);
EXPECT_TRUE(paddle::framework::contiguous(sizef, stridea));
EXPECT_FALSE(paddle::framework::contiguous(sizef, strideb));
EXPECT_FALSE(paddle::framework::contiguous(sizef, stridec));
}
TEST(Dim, Print) {

@ -147,13 +147,13 @@ class OpRegisterHelper {
}
};
#define REGISTER_OP(__op_class, __op_maker_class, __op_type) \
class __op_class##Register { \
private: \
const static OpRegisterHelper<__op_class, __op_maker_class> reg; \
}; \
const OpRegisterHelper<__op_class, __op_maker_class> \
__op_class##Register::reg(#__op_type);
#define REGISTER_OP(type, op_class, op_maker_class) \
class op_class##Register { \
private: \
const static OpRegisterHelper<op_class, op_maker_class> reg; \
}; \
const OpRegisterHelper<op_class, op_maker_class> op_class##Register::reg( \
#type)
} // namespace framework
} // namespace paddle

@ -1,17 +1,15 @@
#include "paddle/framework/op_registry.h"
#include <gtest/gtest.h>
#include "paddle/framework/operator.h"
#include "paddle/operators/demo_op.h"
using namespace paddle::framework;
namespace paddle {
namespace framework {
class CosineOp : public OperatorWithKernel {
class CosineOp : public OperatorBase {
public:
void Run(const OpRunContext* context) const override {
printf("%s\n", DebugString().c_str());
}
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {}
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
};
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
@ -28,14 +26,15 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
}
};
REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim)
REGISTER_OP(cos_sim, CosineOp, CosineOpProtoAndCheckerMaker);
class MyTestOp : public OperatorBase {
public:
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {}
class MyTestOp : public OperatorWithKernel {
public:
void Run(const OpRunContext* ctx) const override {
printf("%s\n", DebugString().c_str());
printf("test_attr = %d\n", ctx->op_->GetAttr<int>("test_attr"));
}
};
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
@ -54,7 +53,7 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
}
};
REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op)
REGISTER_OP(my_test_op, MyTestOp, MyTestOpProtoAndCheckerMaker);
} // namespace framework
} // namespace paddle
@ -73,8 +72,8 @@ TEST(OpRegistry, CreateOp) {
paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc);
auto scope = std::make_shared<Scope>();
auto dev_ctx = DeviceContext();
op->Run(scope, &dev_ctx);
paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx);
float scale_get = op->GetAttr<float>("scale");
ASSERT_EQ(scale_get, scale);
}
@ -116,8 +115,8 @@ TEST(OpRegistry, DefaultValue) {
paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc);
auto scope = std::make_shared<Scope>();
auto dev_ctx = DeviceContext();
op->Run(scope, &dev_ctx);
paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx);
ASSERT_EQ(op->GetAttr<float>("scale"), 1.0);
}
@ -169,9 +168,9 @@ TEST(OpRegistry, CustomChecker) {
attr->set_i(4);
paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc);
auto dev_ctx = DeviceContext();
paddle::platform::CPUDeviceContext dev_ctx;
auto scope = std::make_shared<Scope>();
op->Run(scope, &dev_ctx);
op->Run(scope, dev_ctx);
int test_attr = op->GetAttr<int>("test_attr");
ASSERT_EQ(test_attr, 4);
}

@ -39,13 +39,5 @@ std::string OperatorBase::DebugString() const {
return ss.str();
}
const Variable* OpRunContext::Input(int index) const {
return scope_->GetVariable(op_->inputs_[index]);
}
Variable* OpRunContext::Output(int index) const {
return scope_->GetVariable(op_->outputs_[index]);
}
} // namespace framework
} // namespace paddle

@ -14,44 +14,22 @@ limitations under the License. */
#pragma once
#include <paddle/framework/attr_checker.h>
#include <paddle/framework/op_desc.pb.h>
#include <paddle/framework/scope.h>
#include <paddle/platform/device_context.h>
#include <paddle/platform/place.h>
#include <paddle/utils/Error.h>
#include <boost/variant.hpp>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/framework/attr_checker.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/scope.h"
#include "paddle/utils/Error.h"
namespace paddle {
namespace framework {
class OperatorBase;
class DeviceContext {};
/**
* OpRunContext is the only parameter of Operator's Run function.
* Run will get input/output variables, state such as momentum and
* device resource such as CUDA stream, cublas handle, etc. from
* OpRunContext. User should construct it before run the Operator.
*/
class OpRunContext {
public:
OpRunContext(const OperatorBase* op, const std::shared_ptr<Scope> scope,
const DeviceContext* device_context)
: op_(op), scope_(scope), device_context_(device_context) {}
const Variable* Input(int index) const;
Variable* Output(int index) const;
public:
const OperatorBase* op_;
const std::shared_ptr<Scope> scope_;
const DeviceContext* device_context_;
};
/**
* OperatorBase has the basic element that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User
@ -77,7 +55,10 @@ class OperatorBase {
/// Net will call this function to Run an op.
virtual void Run(const std::shared_ptr<Scope>& scope,
const DeviceContext* dev_ctx) const = 0;
const platform::DeviceContext& dev_ctx) const = 0;
protected:
std::string Type() const { return desc_.type(); }
public:
OpDesc desc_;
@ -86,22 +67,84 @@ class OperatorBase {
AttributeMap attrs_;
};
class OpKernel {
public:
/**
* KernelContext is the only parameter of Kernel Run function.
* Run will get input/output variables, state such as momentum and
* device resource such as CUDA stream, cublas handle, etc. from
* KernelContext. User should construct it before run the Operator.
*/
class KernelContext {
public:
KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& device_context)
: op_(*op), scope_(scope), device_context_(device_context) {}
const Variable* Input(int index) const {
return scope_->GetVariable(op_.inputs_[index]);
}
Variable* Output(int index) const {
return scope_->GetVariable(op_.outputs_[index]);
}
const OperatorBase& op_;
const std::shared_ptr<Scope>& scope_;
const platform::DeviceContext& device_context_;
};
virtual void Compute(const KernelContext& context) const = 0;
virtual ~OpKernel() {}
};
class OperatorWithKernel : public OperatorBase {
public:
virtual ~OperatorWithKernel() {}
struct OpKernelKey {
platform::Place place_;
virtual void InferShape(const std::shared_ptr<Scope>& scope) const {}
OpKernelKey() = default;
OpKernelKey(const platform::DeviceContext& dev_ctx) {
place_ = dev_ctx.GetPlace();
}
bool operator==(const OpKernelKey& o) const { return place_ == o.place_; }
};
struct OpKernelHash {
std::hash<bool> hash_;
size_t operator()(const OpKernelKey& key) const {
return hash_(platform::is_gpu_place(key.place_));
}
};
using OpKernelMap =
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
void Run(const std::shared_ptr<Scope>& scope,
const DeviceContext* dev_ctx) const {
OpRunContext op_ctx(this, scope, dev_ctx);
Run(&op_ctx);
const platform::DeviceContext& dev_ctx) const final {
auto& opKernel = AllOpKernels().at(Type()).at(OpKernelKey(dev_ctx));
opKernel->Compute(OpKernel::KernelContext(this, scope, dev_ctx));
}
/// when implement an Op, your should implement this function.
/// this function should be moved to OpKernel later
virtual void Run(const OpRunContext* context) const = 0;
static std::unordered_map<std::string /* op_type */, OpKernelMap>&
AllOpKernels() {
static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
return g_all_op_kernels;
};
};
} // namespace framework
} // namespace paddle
#define REGISTER_OP_KERNEL(type, PlaceType, KernelType) \
struct __op_kernel_register__##type##__ { \
__op_kernel_register__##type##__() { \
::paddle::framework::OperatorWithKernel::OpKernelKey key; \
key.place_ = PlaceType(); \
::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \
.reset(new KernelType()); \
} \
}; \
static __op_kernel_register__##type##__ __reg_kernel_##type##__

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save