commit
052d1d16ee
@ -1,6 +1,8 @@
|
||||
cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf)
|
||||
target_link_libraries(paddle_go_optimizer stdc++ m)
|
||||
go_library(paddle_pserver_cclient STATIC DEPS paddle_go_optimizer)
|
||||
if(WITH_TESTING)
|
||||
# TODO: add unit test
|
||||
#add_subdirectory(test)
|
||||
# FIXME: this test requires pserver which is not managed by the test
|
||||
# we need some kind of e2e testing mechanism.
|
||||
# add_subdirectory(test)
|
||||
endif()
|
||||
|
||||
@ -0,0 +1,119 @@
|
||||
#pragma once
|
||||
|
||||
#include <boost/variant.hpp>
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include "paddle/framework/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
typedef boost::variant<boost::blank, int, float, std::string, std::vector<int>,
|
||||
std::vector<float>, std::vector<std::string>>
|
||||
Attribute;
|
||||
typedef std::unordered_map<std::string, Attribute> AttributeMap;
|
||||
|
||||
// Value checker that accepts an attribute value only when it is strictly
// greater than a lower bound fixed at construction time.
//
// T must support `operator>`. Instances are callable as `void(T&)`, which is
// the signature TypedAttrChecker stores its checkers under.
template <typename T>
class LargerThanChecker {
 public:
  // Takes the bound by const reference so that it is copied exactly once
  // (into the member) instead of once into the parameter and again into the
  // member.
  LargerThanChecker(const T& lower_bound) : lower_bound_(lower_bound) {}

  // Fails the PADDLE_ENFORCE with "larger_than check fail" unless
  // `value > lower_bound_`. The value itself is never modified.
  void operator()(T& value) const {
    PADDLE_ENFORCE(value > lower_bound_, "larger_than check fail");
  }

 private:
  T lower_bound_;
};
|
||||
|
||||
// More general-purpose checkers, such as 'LessThanChecker' or
// 'BetweenChecker', can be added here.
|
||||
|
||||
// Functor that overwrites a value with a default captured at construction.
//
// Instances are callable as `void(T&)`; TypedAttrChecker stores one of these
// to fill in an attribute the user did not provide.
template <typename T>
class DefaultValueSetter {
 public:
  // Takes the default by const reference so that it is copied exactly once
  // (into the member) instead of once into the parameter and again into the
  // member.
  DefaultValueSetter(const T& default_value) : default_value_(default_value) {}

  // Assigns the stored default to `value`.
  void operator()(T& value) const { value = default_value_; }

 private:
  T default_value_;
};
|
||||
|
||||
// check whether a certain attribute fit its limits
|
||||
// an attribute can have more than one limits
|
||||
template <typename T>
|
||||
class TypedAttrChecker {
|
||||
typedef std::function<void(T&)> ValueChecker;
|
||||
|
||||
public:
|
||||
TypedAttrChecker(const std::string& attr_name) : attr_name_(attr_name) {}
|
||||
|
||||
TypedAttrChecker& LargerThan(const T& lower_bound) {
|
||||
value_checkers_.push_back(LargerThanChecker<T>(lower_bound));
|
||||
return *this;
|
||||
}
|
||||
|
||||
// we can add more common limits, like LessThan(), Between()...
|
||||
|
||||
TypedAttrChecker& SetDefault(const T& default_value) {
|
||||
PADDLE_ENFORCE(default_value_setter_.empty(),
|
||||
"%s can't have more than one default value!", attr_name_);
|
||||
default_value_setter_.push_back(DefaultValueSetter<T>(default_value));
|
||||
return *this;
|
||||
}
|
||||
|
||||
// allow users provide their own checker
|
||||
TypedAttrChecker& AddCustomChecker(const ValueChecker& checker) {
|
||||
value_checkers_.push_back(checker);
|
||||
return *this;
|
||||
}
|
||||
|
||||
void operator()(AttributeMap& attr_map) const {
|
||||
if (!attr_map.count(attr_name_)) {
|
||||
// user do not set this attr
|
||||
PADDLE_ENFORCE(!default_value_setter_.empty(),
|
||||
"Attribute '%s' is required!", attr_name_);
|
||||
// default_value_setter_ has no more than one element
|
||||
T val;
|
||||
(default_value_setter_[0])(val);
|
||||
attr_map[attr_name_] = val;
|
||||
}
|
||||
Attribute& attr = attr_map.at(attr_name_);
|
||||
T& attr_value = boost::get<T>(attr);
|
||||
for (const auto& checker : value_checkers_) {
|
||||
checker(attr_value);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::string attr_name_;
|
||||
std::vector<ValueChecker> value_checkers_;
|
||||
std::vector<ValueChecker> default_value_setter_;
|
||||
};
|
||||
|
||||
// check whether op's all attributes fit their own limits
|
||||
class OpAttrChecker {
|
||||
typedef std::function<void(AttributeMap&)> AttrChecker;
|
||||
|
||||
public:
|
||||
template <typename T>
|
||||
TypedAttrChecker<T>& AddAttrChecker(const std::string& attr_name) {
|
||||
attr_checkers_.push_back(TypedAttrChecker<T>(attr_name));
|
||||
AttrChecker& checker = attr_checkers_.back();
|
||||
return *(checker.target<TypedAttrChecker<T>>());
|
||||
}
|
||||
|
||||
void Check(AttributeMap& attr_map) const {
|
||||
for (const auto& checker : attr_checkers_) {
|
||||
checker(attr_map);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<AttrChecker> attr_checkers_;
|
||||
};
|
||||
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,122 @@
|
||||
#include "paddle/framework/op_registry.h"
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
// Creates a "cos_sim" op with an explicit float attribute scale = 3.3 and
// checks that Run() returns the expected debug string.
TEST(OpRegistry, CreateOp) {
  namespace fw = paddle::framework;

  fw::OpDesc desc;
  desc.set_type("cos_sim");
  desc.add_inputs("aa");
  desc.add_outputs("bb");

  auto scale_attr = desc.mutable_attrs()->Add();
  scale_attr->set_name("scale");
  scale_attr->set_type(fw::AttrType::FLOAT);
  scale_attr->set_f(3.3);

  fw::OpBase* op = fw::OpRegistry::CreateOp(desc);
  std::string actual = op->Run();
  std::string expected = "CosineOp runs! scale = " + std::to_string(3.3);

  // Compare lengths first so the per-character loop never reads past the
  // end of `expected`.
  ASSERT_EQ(expected.size(), actual.size());
  for (size_t pos = 0; pos < actual.length(); ++pos) {
    ASSERT_EQ(actual[pos], expected[pos]);
  }
}
|
||||
|
||||
// Creates a "cos_sim" op with scale = -2.0 and expects CreateOp to throw
// EnforceNotMet with a message that starts with "larger_than check fail".
TEST(OpRegistry, IllegalAttr) {
  paddle::framework::OpDesc op_desc;
  op_desc.set_type("cos_sim");
  op_desc.add_inputs("aa");
  op_desc.add_outputs("bb");

  auto attr = op_desc.mutable_attrs()->Add();
  attr->set_name("scale");
  attr->set_type(paddle::framework::AttrType::FLOAT);
  attr->set_f(-2.0);

  bool caught = false;
  try {
    paddle::framework::OpBase* op __attribute__((unused)) =
        paddle::framework::OpRegistry::CreateOp(op_desc);
  } catch (const paddle::framework::EnforceNotMet& err) {
    // Caught by const reference: avoids copying (and potentially slicing)
    // the exception object.
    caught = true;
    std::string msg = "larger_than check fail";
    // Only the prefix is compared; what() may carry extra text after it.
    const char* err_msg = err.what();
    for (size_t i = 0; i < msg.length(); ++i) {
      ASSERT_EQ(err_msg[i], msg[i]);
    }
  }
  ASSERT_TRUE(caught);
}
|
||||
|
||||
// Creates a "cos_sim" op without setting any attribute and checks that the
// debug string returned by Run() reports scale = 1.0.
TEST(OpRegistry, DefaultValue) {
  namespace fw = paddle::framework;

  fw::OpDesc desc;
  desc.set_type("cos_sim");
  desc.add_inputs("aa");
  desc.add_outputs("bb");

  fw::OpBase* op = fw::OpRegistry::CreateOp(desc);
  std::string actual = op->Run();

  float default_value = 1.0;
  std::string expected =
      "CosineOp runs! scale = " + std::to_string(default_value);

  // Compare lengths first so the per-character loop never reads past the
  // end of `expected`.
  ASSERT_EQ(expected.size(), actual.size());
  for (size_t pos = 0; pos < actual.length(); ++pos) {
    ASSERT_EQ(actual[pos], expected[pos]);
  }
}
|
||||
|
||||
// Exercises "my_test_op" in three steps:
//   1. 'test_attr' missing   -> EnforceNotMet, "... is required!"
//   2. 'test_attr' = 3 (odd) -> EnforceNotMet, "... must be even!"
//   3. 'test_attr' = 4       -> op is created and Run() reports the value.
TEST(OpRegistry, CustomChecker) {
  paddle::framework::OpDesc op_desc;
  op_desc.set_type("my_test_op");
  op_desc.add_inputs("ii");
  op_desc.add_outputs("oo");

  // attr 'test_attr' is not set
  bool caught = false;
  try {
    paddle::framework::OpBase* op __attribute__((unused)) =
        paddle::framework::OpRegistry::CreateOp(op_desc);
  } catch (const paddle::framework::EnforceNotMet& err) {
    // Caught by const reference: avoids copying (and potentially slicing)
    // the exception object.
    caught = true;
    std::string msg = "Attribute 'test_attr' is required!";
    // Only the prefix is compared; what() may carry extra text after it.
    const char* err_msg = err.what();
    for (size_t i = 0; i < msg.length(); ++i) {
      ASSERT_EQ(err_msg[i], msg[i]);
    }
  }
  ASSERT_TRUE(caught);

  // set 'test_attr' to an illegal (odd) value
  auto attr = op_desc.mutable_attrs()->Add();
  attr->set_name("test_attr");
  attr->set_type(paddle::framework::AttrType::INT);
  attr->set_i(3);
  caught = false;
  try {
    paddle::framework::OpBase* op __attribute__((unused)) =
        paddle::framework::OpRegistry::CreateOp(op_desc);
  } catch (const paddle::framework::EnforceNotMet& err) {
    caught = true;
    std::string msg = "'test_attr' must be even!";
    const char* err_msg = err.what();
    for (size_t i = 0; i < msg.length(); ++i) {
      ASSERT_EQ(err_msg[i], msg[i]);
    }
  }
  ASSERT_TRUE(caught);

  // set 'test_attr' to a legal value; the attribute list is rebuilt from
  // scratch so only the new value is present
  op_desc.mutable_attrs()->Clear();
  attr = op_desc.mutable_attrs()->Add();
  attr->set_name("test_attr");
  attr->set_type(paddle::framework::AttrType::INT);
  attr->set_i(4);
  paddle::framework::OpBase* op =
      paddle::framework::OpRegistry::CreateOp(op_desc);
  std::string debug_str = op->Run();
  std::string str = "MyTestOp runs! test_attr = " + std::to_string(4);
  ASSERT_EQ(str.size(), debug_str.size());
  for (size_t i = 0; i < debug_str.length(); ++i) {
    ASSERT_EQ(debug_str[i], str[i]);
  }
}
|
||||
@ -0,0 +1,159 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/framework/enforce.h"
|
||||
#ifndef PADDLE_ONLY_CPU
|
||||
#include "paddle/platform/cuda.h"
|
||||
#include "paddle/platform/dynload/cublas.h"
|
||||
#include "paddle/platform/dynload/cudnn.h"
|
||||
#include "paddle/platform/dynload/curand.h"
|
||||
#define EIGEN_USE_GPU
|
||||
#endif
|
||||
#include "paddle/platform/place.h"
|
||||
#include "unsupported/Eigen/CXX11/Tensor"
|
||||
|
||||
namespace paddle {
|
||||
namespace platform {
|
||||
|
||||
// Polymorphic base for per-device contexts. It declares nothing but a
// virtual destructor, so derived contexts can be destroyed through a
// DeviceContext pointer.
class DeviceContext {
 public:
  virtual ~DeviceContext() = default;
};
|
||||
|
||||
// Device context for the CPU; adds no members beyond DeviceContext.
class CPUDeviceContext : public DeviceContext {
};
|
||||
|
||||
#ifndef PADDLE_ONLY_CPU
|
||||
class GPUPlaceGuard {
|
||||
public:
|
||||
explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) {
|
||||
if (previous_ != new_place) {
|
||||
paddle::platform::SetDeviceId(new_place.device);
|
||||
}
|
||||
}
|
||||
|
||||
~GPUPlaceGuard() { paddle::platform::SetDeviceId(previous_.device); }
|
||||
|
||||
private:
|
||||
GPUPlace previous_;
|
||||
};
|
||||
|
||||
class CUDADeviceContext : public DeviceContext {
|
||||
public:
|
||||
explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) {
|
||||
GPUPlaceGuard guard(gpu_place_);
|
||||
paddle::platform::throw_on_error(cudaStreamCreate(&stream_),
|
||||
"cudaStreamCreate failed");
|
||||
eigen_stream_ = new Eigen::CudaStreamDevice(&stream_);
|
||||
eigen_device_ = new Eigen::GpuDevice(eigen_stream_);
|
||||
}
|
||||
|
||||
void Wait() {
|
||||
paddle::platform::throw_on_error(cudaStreamSynchronize(stream_),
|
||||
"cudaStreamSynchronize failed");
|
||||
}
|
||||
|
||||
cudaStream_t stream() { return stream_; }
|
||||
|
||||
Eigen::GpuDevice eigen_device() { return *eigen_device_; }
|
||||
|
||||
cublasHandle_t cublas_handle() {
|
||||
if (!blas_handle_) {
|
||||
GPUPlaceGuard guard(gpu_place_);
|
||||
PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_) ==
|
||||
CUBLAS_STATUS_SUCCESS,
|
||||
"cublasCreate failed");
|
||||
PADDLE_ENFORCE(paddle::platform::dynload::cublasSetStream(
|
||||
blas_handle_, stream_) == CUBLAS_STATUS_SUCCESS,
|
||||
"cublasSetStream failed");
|
||||
}
|
||||
return blas_handle_;
|
||||
}
|
||||
|
||||
cudnnHandle_t cudnn_handle() {
|
||||
if (!dnn_handle_) {
|
||||
GPUPlaceGuard guard(gpu_place_);
|
||||
PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_) ==
|
||||
CUDNN_STATUS_SUCCESS,
|
||||
"cudnnCreate failed");
|
||||
PADDLE_ENFORCE(paddle::platform::dynload::cudnnSetStream(
|
||||
dnn_handle_, stream_) == CUDNN_STATUS_SUCCESS,
|
||||
"cudnnSetStream failed");
|
||||
}
|
||||
return dnn_handle_;
|
||||
}
|
||||
|
||||
curandGenerator_t curand_generator() {
|
||||
if (!rand_generator_) {
|
||||
GPUPlaceGuard guard(gpu_place_);
|
||||
PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator(
|
||||
&rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) ==
|
||||
CURAND_STATUS_SUCCESS,
|
||||
"curandCreateGenerator failed");
|
||||
PADDLE_ENFORCE(
|
||||
paddle::platform::dynload::curandSetPseudoRandomGeneratorSeed(
|
||||
rand_generator_, random_seed_) == CURAND_STATUS_SUCCESS,
|
||||
"curandSetPseudoRandomGeneratorSeed failed");
|
||||
PADDLE_ENFORCE(paddle::platform::dynload::curandSetStream(
|
||||
rand_generator_, stream_) == CURAND_STATUS_SUCCESS,
|
||||
"curandSetStream failed");
|
||||
}
|
||||
return rand_generator_;
|
||||
}
|
||||
|
||||
~CUDADeviceContext() {
|
||||
Wait();
|
||||
if (blas_handle_) {
|
||||
PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_) ==
|
||||
CUBLAS_STATUS_SUCCESS,
|
||||
"cublasDestroy failed");
|
||||
}
|
||||
|
||||
if (dnn_handle_) {
|
||||
PADDLE_ENFORCE(paddle::platform::dynload::cudnnDestroy(dnn_handle_) ==
|
||||
CUDNN_STATUS_SUCCESS,
|
||||
"cudnnDestroy failed");
|
||||
}
|
||||
|
||||
if (rand_generator_) {
|
||||
PADDLE_ENFORCE(paddle::platform::dynload::curandDestroyGenerator(
|
||||
rand_generator_) == CURAND_STATUS_SUCCESS,
|
||||
"curandDestroyGenerator failed");
|
||||
}
|
||||
|
||||
delete eigen_stream_;
|
||||
delete eigen_device_;
|
||||
|
||||
paddle::platform::throw_on_error(cudaStreamDestroy(stream_),
|
||||
"cudaStreamDestroy failed");
|
||||
}
|
||||
|
||||
private:
|
||||
GPUPlace gpu_place_;
|
||||
cudaStream_t stream_;
|
||||
|
||||
Eigen::CudaStreamDevice* eigen_stream_;
|
||||
Eigen::GpuDevice* eigen_device_;
|
||||
|
||||
cublasHandle_t blas_handle_{nullptr};
|
||||
|
||||
cudnnHandle_t dnn_handle_{nullptr};
|
||||
|
||||
int random_seed_;
|
||||
curandGenerator_t rand_generator_{nullptr};
|
||||
};
|
||||
#endif
|
||||
} // namespace platform
|
||||
} // namespace paddle
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue