Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into unsqueeze_op

guochaorong-patch-1
chenweihang 7 years ago
commit 79333fa7b8

@ -65,6 +65,7 @@ option(REPLACE_ENFORCE_GLOG "Replace PADDLE_ENFORCE with glog/CHECK for better d
option(WITH_ANAKIN "Compile with Anakin library" OFF)
option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE})
option(WITH_BRPC_RDMA "Use brpc rdma as the rpc protocol" OFF)
option(WITH_SYSTEM_BLAS "Use system blas library" OFF)
# CMAKE_BUILD_TYPE
if(NOT CMAKE_BUILD_TYPE)

@ -18,6 +18,8 @@ learning to many products at Baidu.
Our vision is to enable deep learning for everyone via PaddlePaddle.
Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddle/releases) to track the latest feature of PaddlePaddle.
### Latest PaddlePaddle Version: [Fluid](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/fluid)
## Features
- **Flexibility**

@ -83,6 +83,7 @@ else()
set(REFERENCE_CBLAS_LIB_SEARCH_PATHS ${REFERENCE_CBLAS_ROOT}/lib)
endif()
if(WITH_SYSTEM_BLAS)
find_path(REFERENCE_CBLAS_INCLUDE_DIR NAMES cblas.h PATHS
${REFERENCE_CBLAS_INCLUDE_SEARCH_PATHS})
find_library(REFERENCE_CBLAS_LIBRARY NAMES cblas PATHS
@ -96,6 +97,7 @@ if(REFERENCE_CBLAS_INCLUDE_DIR AND REFERENCE_CBLAS_LIBRARY)
add_definitions(-DPADDLE_USE_REFERENCE_CBLAS)
message(STATUS "Found reference-cblas (include: ${CBLAS_INC_DIR}, library: ${CBLAS_LIBRARIES})")
endif()
endif()
if(IOS_USE_VECLIB_FOR_BLAS AND VECLIB_FOUND)
set(CBLAS_FOUND ON)

@ -52,7 +52,7 @@ In `trainer_internal.cpp:L93 trainOneBatch`:
When doing actual network forward and backward, at the beginning of each batch, the trainer will try to download one row of data from pserver.
In `trainer/RemoteParameterUpdater.cpp`: `parameterUpdater_->getParametersRemote();`:
In `legacy/trainer/RemoteParameterUpdater.cpp`: `parameterUpdater_->getParametersRemote();`:
```c++
if (fullSize) {

@ -339,7 +339,7 @@ If you are creating a new file for the test, such as :code:`paddle/legacy/gserve
Implement Python Wrapper
========================
Implementing Python wrapper allows us to use the added layer in configuration files. All the Python wrappers are in file :code:`python/paddle/trainer/config_parser.py`. An example of the Python wrapper for fully connected layer is listed below. It has the following steps:
Implementing Python wrapper allows us to use the added layer in configuration files. All the Python wrappers are in file :code:`python/paddle/legacy/trainer/config_parser.py`. An example of the Python wrapper for fully connected layer is listed below. It has the following steps:
- Use :code:`@config_layer('fc')` at the decorator for all the Python wrapper class. :code:`fc` is the identifier of the layer.
- Implements :code:`__init__` constructor function.

@ -1,24 +1,24 @@
if(NOT WITH_FLUID_ONLY)
add_subdirectory(legacy/cuda)
add_subdirectory(legacy/function)
add_subdirectory(utils)
add_subdirectory(legacy/utils)
add_subdirectory(legacy/math)
add_subdirectory(legacy/gserver)
add_subdirectory(legacy/parameter)
if(MOBILE_INFERENCE)
add_subdirectory(capi)
add_subdirectory(legacy/capi)
else()
add_subdirectory(legacy/pserver)
add_subdirectory(trainer)
add_subdirectory(legacy/trainer)
add_subdirectory(scripts)
if(WITH_C_API)
add_subdirectory(capi)
add_subdirectory(legacy/capi)
endif()
if(WITH_SWIG_PY)
add_subdirectory(api)
add_subdirectory(legacy/api)
endif()
endif()
endif()

@ -124,16 +124,10 @@ void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
#ifdef PADDLE_WITH_CUDA
if (!events_.empty()) { // Use event
std::function<void()> method = callback;
// NOTE(zcd): device context must be ordered here because RecordEvent
// will use a mutex to ensure thread safety.
std::map<platform::DeviceContext *, platform::Place> ordered_ctxes;
for (auto &p : dev_ctxes_) {
ordered_ctxes.emplace(p.second, p.first);
}
for (auto &p : ordered_ctxes) {
method = [method, p, this]() {
static_cast<platform::CUDADeviceContext *>(p.first)->RecordEvent(
events_.at(boost::get<platform::CUDAPlace>(p.second).device),
static_cast<platform::CUDADeviceContext *>(p.second)->RecordEvent(
events_.at(boost::get<platform::CUDAPlace>(p.first).device),
method);
};
}

@ -13,9 +13,9 @@
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/macros.h"
@ -92,9 +92,7 @@ class OpHandleBase {
std::vector<VarHandleBase *> inputs_;
std::vector<VarHandleBase *> outputs_;
std::unordered_map<platform::Place, platform::DeviceContext *,
platform::PlaceHash>
dev_ctxes_;
std::map<platform::Place, platform::DeviceContext *> dev_ctxes_;
#ifdef PADDLE_WITH_CUDA
std::unordered_map<int, cudaEvent_t> events_;

@ -54,8 +54,7 @@ struct ReduceLoDTensor {
inline void GatherSelectedRows(
const std::vector<const SelectedRows *> &src_selecte_rows_,
const std::vector<platform::Place> &in_places,
const std::unordered_map<platform::Place, platform::DeviceContext *,
platform::PlaceHash> &dev_ctxes,
const std::map<platform::Place, platform::DeviceContext *> &dev_ctxes,
const platform::Place &out_place, SelectedRows *dst_selecte_rows) {
PADDLE_ENFORCE(!src_selecte_rows_.empty());

@ -46,9 +46,16 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
Executor::Executor(const platform::Place& place) : place_(place) {}
#ifdef PADDLE_WITH_DISTRIBUTE
void Executor::Complete() {
::paddle::operators::distributed::RPCClient::GetInstance<RPCCLIENT_T>()
->SendComplete();
void Executor::BeginPass() {
::paddle::operators::distributed::RPCClient::GetInstance<
::paddle::operators::distributed::GRPCClient>()
->SendBeginPass();
}
void Executor::EndPass() {
::paddle::operators::distributed::RPCClient::GetInstance<
::paddle::operators::distributed::GRPCClient>()
->SendEndPass();
}
#endif

@ -46,9 +46,14 @@ class Executor {
#ifdef PADDLE_WITH_DISTRIBUTE
/*
* Sending signal to pserver to mark current trainer stop.
* Send a signal to the pserver to mark that the current pass has started.
*/
void Complete();
void BeginPass();
/*
* Send a signal to the pserver to mark that the current pass has finished.
*/
void EndPass();
#endif
/* @Brief

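The removed Complete() call is replaced by a BeginPass()/EndPass() pair. Below is a minimal, hedged sketch of how a distributed trainer loop might bracket each pass with the new calls; the helper name, program/scope handling, and batch loop are illustrative assumptions, not part of this diff.

```c++
#include "paddle/fluid/framework/executor.h"

#ifdef PADDLE_WITH_DISTRIBUTE
// Hypothetical trainer-side helper: every pass is bracketed by the new
// signals so the pserver can track how many trainers are inside a pass.
void RunOnePass(paddle::framework::Executor* executor,
                const paddle::framework::ProgramDesc& program,
                paddle::framework::Scope* scope, int num_batches) {
  executor->BeginPass();  // pserver increases its pass barrier count
  for (int i = 0; i < num_batches; ++i) {
    // Run one mini-batch of the main block (block 0).
    executor->Run(program, scope, /*block_id=*/0);
  }
  executor->EndPass();    // pserver decreases its pass barrier count
}
#endif
```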
@ -76,13 +76,9 @@ class OpRegistry {
template <typename PlaceType, bool at_end, size_t I, typename... KernelType>
struct OpKernelRegistrarFunctor;
template <typename PlaceType, size_t I, typename... KernelTypes>
struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
using KERNEL_TYPE =
typename std::tuple_element<I, std::tuple<KernelTypes...>>::type;
void operator()(const char* op_type, const char* library_type) const {
using T = typename KERNEL_TYPE::ELEMENT_TYPE;
template <typename PlaceType, typename T, typename Func>
inline void RegisterKernelClass(const char* op_type, const char* library_type,
Func func) {
std::string library(library_type);
std::string data_layout = "ANYLAYOUT";
if (library == "MKLDNN") {
@ -91,8 +87,20 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(),
StringToDataLayout(data_layout),
StringToLibraryType(library_type));
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KERNEL_TYPE);
OperatorWithKernel::AllOpKernels()[op_type][key] = func;
}
template <typename PlaceType, size_t I, typename... KernelTypes>
struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
using KERNEL_TYPE =
typename std::tuple_element<I, std::tuple<KernelTypes...>>::type;
void operator()(const char* op_type, const char* library_type) const {
using T = typename KERNEL_TYPE::ELEMENT_TYPE;
RegisterKernelClass<PlaceType, T>(
op_type, library_type, [](const framework::ExecutionContext& ctx) {
KERNEL_TYPE().Compute(ctx);
});
constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...>
func;
@ -116,6 +124,47 @@ class OpKernelRegistrar : public Registrar {
}
};
template <typename PlaceType, bool at_end, size_t I, typename... KernelType>
struct OpKernelRegistrarFunctorEx;
template <typename PlaceType, typename... DataTypeAndKernelType>
class OpKernelRegistrarEx : public Registrar {
public:
explicit OpKernelRegistrarEx(const char* op_type, const char* library_type) {
OpKernelRegistrarFunctorEx<PlaceType, false, 0, DataTypeAndKernelType...>
func;
func(op_type, library_type);
}
};
template <typename PlaceType, size_t I, typename... DataTypeAndKernelType>
struct OpKernelRegistrarFunctorEx<PlaceType, true, I,
DataTypeAndKernelType...> {
void operator()(const char* op_type, const char* library_type) const {}
};
template <typename PlaceType, size_t I, typename... DataTypeAndKernelType>
struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
DataTypeAndKernelType...> {
using Functor =
typename std::tuple_element<I + 1,
std::tuple<DataTypeAndKernelType...>>::type;
using T =
typename std::tuple_element<I,
std::tuple<DataTypeAndKernelType...>>::type;
void operator()(const char* op_type, const char* library_type) const {
RegisterKernelClass<PlaceType, T>(op_type, library_type, Functor());
constexpr auto size =
std::tuple_size<std::tuple<DataTypeAndKernelType...>>::value;
OpKernelRegistrarFunctorEx<PlaceType, I + 2 >= size, I + 2,
DataTypeAndKernelType...>
func;
func(op_type, library_type);
}
};
/**
* check if MACRO is used in GLOBAL NAMESPACE.
*/
@ -174,6 +223,25 @@ class OpKernelRegistrar : public Registrar {
#define REGISTER_OP_CPU_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
#define REGISTER_OP_KERNEL_EX(op_type, library_type, place_class, ...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op_kernel_##op_type##_##library_type##__, \
"REGISTER_OP_KERNEL_EX must be called in global namespace"); \
static ::paddle::framework::OpKernelRegistrarEx<place_class, __VA_ARGS__> \
__op_kernel_registrar_##op_type##_##library_type##__(#op_type, \
#library_type); \
int TouchOpKernelRegistrar_##op_type##_##library_type() { \
__op_kernel_registrar_##op_type##_##library_type##__.Touch(); \
return 0; \
}
#define REGISTER_OP_CUDA_KERNEL_FUNCTOR(op_type, ...) \
REGISTER_OP_KERNEL_EX(op_type, CUDA, ::paddle::platform::CUDAPlace, \
__VA_ARGS__)
#define REGISTER_OP_CPU_KERNEL_FUNCTOR(op_type, ...) \
REGISTER_OP_KERNEL_EX(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
/**
* Macro to mark what Operator and Kernel
* we will use and tell the compiler to

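The new REGISTER_OP_*_KERNEL_FUNCTOR macros take an alternating list of data types and kernel functors instead of OpKernel subclasses; RegisterKernelClass stores each functor as an OpKernelFunc (a std::function) in AllOpKernels(). A minimal sketch of the pattern, with a hypothetical operator name and kernel class (the reshape operator later in this diff is registered the same way):

```c++
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {
// A functor-style kernel: a plain class whose operator() takes an
// ExecutionContext, instead of deriving from framework::OpKernel<T>.
class MyOpKernel {
 public:
  void operator()(const framework::ExecutionContext& ctx) const {
    // read inputs via ctx.Input<...>(), write outputs via ctx.Output<...>()
  }
};
}  // namespace operators
}  // namespace paddle

// Arguments alternate <data type, kernel functor>; each pair becomes one
// OpKernelType -> OpKernelFunc entry for the CPU place.
REGISTER_OP_CPU_KERNEL_FUNCTOR(my_op, float, paddle::operators::MyOpKernel,
                               double, paddle::operators::MyOpKernel);
```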
@ -651,7 +651,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
dev_ctx = pool.Get(expected_kernel_key.place_);
}
kernel_iter->second->Compute(ExecutionContext(*this, exec_scope, *dev_ctx));
kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx));
if (!transfered_inplace_vars.empty()) {
// there are inplace variables that have been transferred.

@ -347,9 +347,9 @@ class OpKernel : public OpKernelBase {
class OperatorWithKernel : public OperatorBase {
public:
using OpKernelFunc = std::function<void(const ExecutionContext&)>;
using OpKernelMap =
std::unordered_map<OpKernelType, std::unique_ptr<OpKernelBase>,
OpKernelType::Hash>;
std::unordered_map<OpKernelType, OpKernelFunc, OpKernelType::Hash>;
OperatorWithKernel(const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap& attrs)

@ -35,10 +35,20 @@ void GRPCClient::InitEventLoop() {
client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
}
void GRPCClient::SendComplete() {
void GRPCClient::SendBeginPass() {
for (auto& it : channels_) {
this->AsyncSendComplete(it.first);
VLOG(3) << "send begin pass to: " << it.first;
this->AsyncSendBeginPass(it.first);
}
this->Wait();
}
void GRPCClient::SendEndPass() {
for (auto& it : channels_) {
VLOG(3) << "send end pass to " << it.first;
this->AsyncSendEndPass(it.first);
}
this->Wait();
}
GRPCClient::~GRPCClient() {
@ -226,19 +236,32 @@ void GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
req_count_++;
}
void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) {
void GRPCClient::AsyncSendBeginPass(const std::string& ep, int64_t time_out) {
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
s->Prepare(time_out);
sendrecv::VariableMessage req;
req.set_varname(COMPLETE_MESSAGE);
req.set_varname(BEGIN_PASS_MESSAGE);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
}
void GRPCClient::AsyncSendEndPass(const std::string& ep, int64_t time_out) {
const auto ch = GetChannel(ep);
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
s->Prepare(time_out);
sendrecv::VariableMessage req;
req.set_varname(END_PASS_MESSAGE);
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
}
void GRPCClient::AsyncCheckpointNotify(const std::string& ep,
const std::string& dir,
int64_t time_out) {

@ -77,12 +77,13 @@ class BaseProcessor {
context_.reset(new grpc::ClientContext());
var_h_ = var_info;
context_->set_wait_for_ready(true);
if (time_out) {
std::chrono::system_clock::time_point deadline =
std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);
std::chrono::system_clock::now() +
std::chrono::milliseconds(time_out);
context_->set_deadline(deadline);
}
}
virtual void Prepare(int64_t time_out) {
context_.reset(new grpc::ClientContext());
@ -214,9 +215,17 @@ class GRPCClient : public RPCClient {
void AsyncCheckpointNotify(const std::string& ep, const std::string& dir,
int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncSendBeginPass(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncSendEndPass(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) override;
void Wait() override;
void SendComplete() override;
void SendBeginPass() override;
void SendEndPass() override;
protected:
void InitImpl() override;
@ -227,9 +236,6 @@ class GRPCClient : public RPCClient {
void Proceed();
void AsyncSendComplete(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline);
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
private:

@ -37,11 +37,14 @@ constexpr char kRequestSend[] = "RequestSend";
constexpr char kRequestGet[] = "RequestGet";
constexpr char kRequestPrefetch[] = "RequestPrefetch";
constexpr char kRequestCheckpoint[] = "RequestCheckpoint";
constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV"
#define BEGIN_PASS_MESSAGE "BEGIN_PASS@RECV"
#define END_PASS_MESSAGE "END_PASS@RECV"
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"

@ -55,14 +55,14 @@ bool RequestSendHandler::Handle(const std::string& varname,
if (varname == BATCH_BARRIER_MESSAGE) {
VLOG(3) << "sync: recv batch barrier message";
rpc_server_->IncreaseBatchBarrier(kRequestSend);
} else if (varname == COMPLETE_MESSAGE) {
VLOG(3) << "sync: recv complete message";
rpc_server_->DecreaseClientNum();
} else if (varname == BEGIN_PASS_MESSAGE) {
VLOG(3) << "sync: recv begin pass message";
rpc_server_->WaitCond(kRequestSend);
rpc_server_->BeginPass();
} else {
VLOG(3) << "sync: received var_name: " << varname;
if (sync_mode_) {
rpc_server_->WaitCond(kRequestSend);
}
VLOG(3) << "sync: processing received var: " << varname;
if (invar == nullptr) {
LOG(ERROR) << "sync: Can not find server side var: " << varname;
@ -91,21 +91,21 @@ bool RequestGetHandler::Handle(const std::string& varname,
framework::Variable** outvar,
const std::string& out_var_name) {
VLOG(4) << "RequestGetHandler:" << varname;
if (varname != FETCH_BARRIER_MESSAGE) {
if (sync_mode_) {
if (varname == FETCH_BARRIER_MESSAGE) {
VLOG(3) << "sync: recv fetch barrier message";
rpc_server_->IncreaseBatchBarrier(kRequestGet);
} else if (varname == END_PASS_MESSAGE) {
rpc_server_->EndPass();
} else {
rpc_server_->WaitCond(kRequestGet);
*outvar = scope_->FindVar(varname);
}
} else {
if (varname != FETCH_BARRIER_MESSAGE && varname != END_PASS_MESSAGE) {
*outvar = scope_->FindVar(varname);
return true;
}
// FETCH_BARRIER_MESSAGE
if (sync_mode_) {
VLOG(3) << "sync: recv fetch barrier message";
rpc_server_->IncreaseBatchBarrier(kRequestGet);
}
return true;
}

@ -60,10 +60,17 @@ class RPCClient {
const std::string& dir,
int64_t time_out = FLAGS_rpc_deadline) = 0;
// SendComplete tells all the server that current trainer have no more data
// to train, so that the pserver can reduce it's barrier count, and continue
// to train with other trainers.
virtual void SendComplete() = 0;
virtual void AsyncSendBeginPass(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual void AsyncSendEndPass(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) = 0;
// BeginPass/EndPass tell all the pservers that a pass has started/ended, so
// that each pserver can increase/reduce its barrier count and continue to
// train with the other trainers.
virtual void SendBeginPass() = 0;
virtual void SendEndPass() = 0;
virtual void Wait() = 0;

@ -44,7 +44,8 @@ void RPCServer::SavePort() const {
void RPCServer::WaitBarrier(const std::string& rpc_name) {
std::unique_lock<std::mutex> lock(this->mutex_);
barrier_cond_.wait(lock, [this, &rpc_name] {
return (barrier_counter_[rpc_name] >= client_num_ || exit_flag_.load());
return ((barrier_counter_[rpc_name] == client_num_ && client_num_ != 0) ||
exit_flag_.load());
});
VLOG(3) << "batch_barrier_: " << rpc_name << " "
@ -63,10 +64,25 @@ void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
}
}
void RPCServer::DecreaseClientNum() {
void RPCServer::BeginPass() {
VLOG(4) << "RPCServer begin increase pass barrier";
{
std::unique_lock<std::mutex> lock(mutex_);
client_num_++;
VLOG(4) << "increase client_num to: " << client_num_;
}
barrier_cond_.notify_all();
}
void RPCServer::EndPass() {
VLOG(4) << "RPCServer begin increase pass barrier";
{
std::unique_lock<std::mutex> lock(mutex_);
client_num_--;
VLOG(4) << "decrease client_num to: " << client_num_;
if (cur_cond_.load() == rpc_cond_map_[kRequestGet]) {
barrier_counter_[kRequestGet]--;
}
}
barrier_cond_.notify_all();
}

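WaitBarrier now requires the counter to exactly match a non-zero client_num_, and BeginPass/EndPass adjust client_num_ as trainers enter and leave a pass. A hedged walk-through with hypothetical numbers (illustrative only, not taken from the diff):

```c++
// Two trainers, both participating in a pass:
//   1. Each calls SendBeginPass -> BeginPass() runs twice,
//      client_num_ goes 0 -> 1 -> 2.
//   2. During the pass, barrier messages raise barrier_counter_[rpc_name];
//      WaitBarrier() releases only when the counter equals the non-zero
//      client_num_ (the condition changed from ">=" to "==").
//   3. Trainer A finishes and calls SendEndPass -> EndPass() decrements
//      client_num_ to 1; if the server is currently waiting on the
//      kRequestGet barrier, that barrier's counter is decremented too,
//      keeping it consistent with the reduced client_num_ so the remaining
//      trainer can still release the barrier.
```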
@ -43,6 +43,9 @@ class RPCServer {
bool IsExit() { return exit_flag_.load(); }
int GetSelectedPort() const { return selected_port_; }
int GetClientNum() const;
void SavePort() const;
// RegisterRPC, register the rpc method name to a handler
@ -60,7 +63,10 @@ class RPCServer {
void SetCond(const std::string& rpc_name);
void WaitCond(const std::string& rpc_name);
void IncreaseBatchBarrier(const std::string rpc_name);
void DecreaseClientNum();
void BeginPass();
void EndPass();
void ResetBarrierCounter();
protected:

@ -115,6 +115,7 @@ class MKLDNNMemory {
template <typename T>
class FCMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");

@ -14,7 +14,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/utils/Logging.h"
#include "paddle/legacy/utils/Logging.h"
namespace paddle {
namespace operators {

@ -12,14 +12,108 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/reshape_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class ReshapeOp : public framework::OperatorWithKernel {
public:
ReshapeOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ReshapeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ReshapeOp should not be null.");
const std::vector<int> &shape = ctx->Attrs().Get<std::vector<int>>("shape");
PADDLE_ENFORCE(!shape.empty(),
"The shape information must be set by Attr(shape).");
if (ctx->HasInput("Shape") && ctx->IsRuntime()) {
// If true, set the shape of Output(Out) according to Input(Shape) in
// ReshapeKernel with ExecutionContext. Also check LoD in ReshapeKernel.
ctx->ShareLoD("X", /*->*/ "Out");
return;
}
auto x_dims = ctx->GetInputDim("X");
auto out_dims = ValidateShape(shape, x_dims);
ctx->SetOutputDim("Out", out_dims);
if (x_dims[0] == out_dims[0]) {
// Only pass LoD when the first dimension of output and Input(X)
// are the same.
ctx->ShareLoD("X", /*->*/ "Out");
}
}
static framework::DDim ValidateShape(const std::vector<int> shape,
const framework::DDim &in_dims) {
const int64_t in_size = framework::product(in_dims);
// only one dimension can be set to -1, whose size will be automatically
// inferred.
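// Worked example (hypothetical values, not taken from this diff):
//   in_dims = [6, 4] (in_size = 24), shape = [0, -1, 2]
//   0 copies in_dims[0] = 6, so capacity = 6 * (-1) * 2 = -12;
//   the -1 entry is inferred as -in_size / capacity = -24 / -12 = 2,
//   giving a validated output shape of [6, 2, 2].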
const int64_t unk_dim_val = -1;
const int64_t copy_dim_val = 0;
std::vector<int64_t> output_shape(shape.size(), 0);
int64_t capacity = 1;
int unk_dim_idx = -1;
for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] == unk_dim_val) {
PADDLE_ENFORCE(
unk_dim_idx == -1,
"Only one input dimension of Attr(shape) can be unknown.");
unk_dim_idx = i;
} else if (shape[i] == copy_dim_val) {
PADDLE_ENFORCE(
static_cast<int>(i) < in_dims.size(),
"The index of dimension to copy from input shape must be less "
"than the size of input shape.");
} else {
PADDLE_ENFORCE(
shape[i] > 0,
"Each input dimension of Attr(shape) must not be negtive except "
"one unknown dimension.");
}
capacity *= (shape[i] ? shape[i] : in_dims[i]);
output_shape[i] =
(shape[i] ? static_cast<int64_t>(shape[i]) : in_dims[i]);
}
if (unk_dim_idx != -1) {
if (in_size > 0) {
// in_size < 0 means it cannot be determined at compile time, so skip the check;
// for example, in_dims = [-1, 8, 1, 1], shape = [-1, 3, 8],
// capacity = -24, in_size = -8, output_shape[0] = 0
// the following check will fail.
output_shape[unk_dim_idx] = -in_size / capacity;
PADDLE_ENFORCE_EQ(output_shape[unk_dim_idx] * capacity, -in_size,
"Invalid shape is given.");
} else {
output_shape[unk_dim_idx] = -1;
}
} else {
PADDLE_ENFORCE_EQ(capacity, in_size, "Invalid shape is given.");
}
return framework::make_ddim(output_shape);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
}
};
class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
@ -107,19 +201,93 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
}
};
class ReshapeKernel {
public:
void operator()(const framework::ExecutionContext &ctx) const {
auto *out = ctx.Output<framework::LoDTensor>("Out");
auto *in = ctx.Input<framework::LoDTensor>("X");
auto *shape_tensor = ctx.HasInput("Shape")
? ctx.Input<framework::LoDTensor>("Shape")
: nullptr;
framework::DDim out_dims = out->dims();
if (shape_tensor) {
auto *shape_data = shape_tensor->data<int>();
framework::Tensor cpu_shape_tensor;
if (platform::is_gpu_place(ctx.GetPlace())) {
TensorCopySync(*shape_tensor, platform::CPUPlace(), &cpu_shape_tensor);
shape_data = cpu_shape_tensor.data<int>();
}
auto shape =
std::vector<int>(shape_data, shape_data + shape_tensor->numel());
out_dims = ReshapeOp::ValidateShape(shape, in->dims());
}
if (!in->lod().empty()) {
PADDLE_ENFORCE_EQ(
out_dims[0], in->dims()[0],
"Reshape operator cannot reshape an input sequence batch "
"into an output sequence batch that has a different "
"number of time steps. Please consider using "
"sequence_reshape op.");
}
bool inplace = ctx.Attr<bool>("inplace");
out->Resize(out_dims);
if (!inplace) {
out->mutable_data(ctx.GetPlace(), in->type());
framework::TensorCopySync(*in, ctx.GetPlace(), out);
out->Resize(out_dims);
} else {
out->ShareDataWith(*in);
out->Resize(out_dims);
}
}
};
class ReshapeGradKernel {
public:
void operator()(const framework::ExecutionContext &ctx) const {
auto *d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto *d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
d_x->mutable_data(ctx.GetPlace(), d_out->type());
bool inplace = ctx.Attr<bool>("inplace");
auto in_dims = d_x->dims();
if (!inplace) {
framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x);
ctx.device_context().Wait();
d_x->Resize(in_dims);
} else {
d_x->ShareDataWith(*d_out);
d_x->Resize(in_dims);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp);
REGISTER_OP_CPU_KERNEL(reshape, ops::ReshapeKernel<CPU, float>,
ops::ReshapeKernel<CPU, double>,
ops::ReshapeKernel<CPU, int>,
ops::ReshapeKernel<CPU, int64_t>);
REGISTER_OP_CPU_KERNEL(reshape_grad, ops::ReshapeGradKernel<CPU, float>,
ops::ReshapeGradKernel<CPU, double>,
ops::ReshapeGradKernel<CPU, int>,
ops::ReshapeGradKernel<CPU, int64_t>);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel,
int64_t, ops::ReshapeKernel);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
double, ops::ReshapeGradKernel, int,
ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel);
#ifdef PADDLE_WITH_CUDA
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel,
int64_t, ops::ReshapeKernel);
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
double, ops::ReshapeGradKernel, int,
ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel);
#endif

@ -1,26 +0,0 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/reshape_op.h"
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(reshape, paddle::operators::ReshapeKernel<CUDA, float>,
paddle::operators::ReshapeKernel<CUDA, double>,
paddle::operators::ReshapeKernel<CUDA, int>,
paddle::operators::ReshapeKernel<CUDA, int64_t>);
REGISTER_OP_CUDA_KERNEL(reshape_grad,
paddle::operators::ReshapeGradKernel<CUDA, float>,
paddle::operators::ReshapeGradKernel<CUDA, double>,
paddle::operators::ReshapeGradKernel<CUDA, int>,
paddle::operators::ReshapeGradKernel<CUDA, int64_t>);

Some files were not shown because too many files have changed in this diff.