You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

263 lines
8.4 KiB

7 years ago
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
7 years ago
7 years ago
#pragma once
7 years ago
7 years ago
#include "paddle/fluid/framework/op_registry.h"
7 years ago
#include "paddle/fluid/platform/for_range.h"
#ifdef __NVCC__
#include <thrust/device_vector.h>
#include "paddle/fluid/framework/array.h"
7 years ago
namespace paddle {
namespace operators {
// Operator definition for stack: validates the X inputs and infers Y's shape.
// NOTE(review): this copy of the class is missing structural lines (closing
// braces and the opener of the axis range check); restore them from the
// upstream file before building.
class StackOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
// Checks the inputs/outputs and sets Y's dims to X[0]'s dims with one extra
// dimension, whose size is the number of X inputs, inserted at `axis`.
void InferShape(framework::InferShapeContext *ctx) const override {
// There must be at least one X, and Y must be present.
PADDLE_ENFORCE_GT(ctx->Inputs("X").size(), 0,
"Number of Inputs(X) must be larger than 0");
PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) must exist.");
auto input_dims = ctx->GetInputsDim("X");
// Every input is compared against the first one, so the loop starts at 1.
for (size_t i = 1; i < input_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(input_dims[i], input_dims[0],
"Dims of all Inputs(X) must be the same");
// Only lod of X[0] would be shared with Y
ctx->ShareLoD("X", /*->*/ "Y");
int axis = ctx->Attrs().Get<int>("axis");
int rank = input_dims[0].size();
// NOTE(review): the next two lines are the arguments of a check whose opening
// line is missing here - presumably PADDLE_ENFORCE(; confirm against upstream.
axis >= -(rank + 1) && axis < rank + 1,
"Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d", rank);
// A negative axis is offset by rank + 1, which is the rank of the output.
if (axis < 0) axis += (rank + 1);
auto vec = framework::vectorize2int(input_dims[0]);
// Insert the number of stacked inputs as the new dimension at `axis`.
vec.insert(vec.begin() + axis, input_dims.size());
ctx->SetOutputDim("Y", framework::make_ddim(vec));
// Declares the stack op's interface: an input "X" (marked AsDuplicable) and
// an output "Y", followed by the operator's description text.
// NOTE(review): the attribute declaration that the "axis" description string
// below belongs to, and the delimiters around the operator description text,
// are not present in this copy - restore them from the upstream file.
class StackOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
AddInput("X", "The input of stack op.").AsDuplicable();
AddOutput("Y", "The output of stack op.");
"The axis along which all of the Inputs(X) should be stacked.")
Stack Operator.
Stack all of the Inputs(X) into one tensor along Attr(axis). The dims of all Inputs(X) must be the same.
// Element-wise functor for the forward stack: each call writes exactly one
// element of the stacked output.
//
// The output y is addressed as a flat [pre, n, post] array and every input as
// a flat [pre, post] array. VecXType must support x[k][j] indexing, yielding
// element j of the k-th input.
template <typename VecXType, typename T>
struct StackFunctor {
  // x:    indexable collection of the n input buffers
  // y:    output buffer, written at the index passed to operator()
  // n:    number of stacked inputs
  // post: number of contiguous elements each input contributes per leading
  //       index
  HOSTDEVICE StackFunctor(const VecXType &x, T *y, int n, int post)
      : x_(x), y_(y), n_(n), post_(post) {}

  // Copies the source element for flat output position idx into y[idx].
  HOSTDEVICE void operator()(int idx) {
    // Position along the leading (pre) part of the output.
    int i = idx / (n_ * post_);
    // Which of the n inputs supplies this element.
    int which_x = idx / post_ - i * n_;
    // Flat offset of the element inside that input.
    int x_index = i * post_ + idx % post_;
    y_[idx] = x_[which_x][x_index];
  }

  VecXType x_;
  T *y_;
  int n_;
  int post_;
};
7 years ago
// Element-wise functor for the backward stack: each call routes one element
// of the output gradient to the input gradient it belongs to.
//
// dy is addressed as a flat [pre, n, post] array and every dx[k] as a flat
// [pre, post] array. VecDxType must support dx[k][j] indexing, yielding a
// writable reference to element j of the k-th input gradient.
template <typename VecDxType, typename T>
struct StackGradFunctor {
  // dx:   indexable collection of the n input-gradient buffers (written)
  // dy:   output-gradient buffer (read)
  // n:    number of stacked inputs
  // post: number of contiguous elements each input owns per leading index
  HOSTDEVICE StackGradFunctor(const VecDxType &dx, const T *dy, int n, int post)
      : dx_(dx), dy_(dy), n_(n), post_(post) {}

  // Stores dy[idx] into the matching slot of the matching input gradient.
  HOSTDEVICE void operator()(int idx) {
    // Position along the leading (pre) part of dy.
    int i = idx / (n_ * post_);
    // Which of the n input gradients receives this element.
    int which_x = idx / post_ - i * n_;
    // Flat offset of the element inside that input gradient.
    int x_index = i * post_ + idx % post_;
    dx_[which_x][x_index] = dy_[idx];
  }

  VecDxType dx_;
  const T *dy_;
  int n_;
  int post_;
};
// Builds a StackFunctor over (x, y, n, post) and hands it to a
// platform::ForRange constructed from the device context and total_num.
// Presumably ForRange invokes the functor once per index in [0, total_num) -
// confirm against platform::ForRange.
//
// ctx:       device context the range runs on
// x:         indexable collection of the n input buffers
// y:         output buffer
// total_num: number of output elements to produce
// n, post:   stacking geometry, forwarded unchanged to StackFunctor
template <typename DeviceContext, typename VecXType, typename T>
static inline void StackFunctorForRange(const DeviceContext &ctx,
                                        const VecXType &x, T *y, int total_num,
                                        int n, int post) {
  platform::ForRange<DeviceContext> for_range(ctx, total_num);
  for_range(StackFunctor<VecXType, T>(x, y, n, post));
}
7 years ago
// Builds a StackGradFunctor over (dx, dy, n, post) and hands it to a
// platform::ForRange constructed from the device context and total_num.
// Presumably ForRange invokes the functor once per index in [0, total_num) -
// confirm against platform::ForRange.
//
// ctx:       device context the range runs on
// dx:        indexable collection of the n input-gradient buffers
// dy:        output-gradient buffer
// total_num: number of elements of dy to scatter
// n, post:   stacking geometry, forwarded unchanged to StackGradFunctor
template <typename DeviceContext, typename VecDxType, typename T>
static inline void StackGradFunctorForRange(const DeviceContext &ctx,
                                            const VecDxType &dx, const T *dy,
                                            int total_num, int n, int post) {
  platform::ForRange<DeviceContext> for_range(ctx, total_num);
  for_range(StackGradFunctor<VecDxType, T>(dx, dy, n, post));
}
// Forward kernel for stack: combines the tensors of input "X" into output "Y"
// along attribute "axis".
// NOTE(review): this copy of the class is missing lines - the initializers of
// both `x_data_arr` declarations, the preprocessor directives that separate
// and close the __NVCC__ branch, the call announced by the Wait() comment,
// and the closing braces. Restore them from the upstream file before building.
template <typename DeviceContext, typename T>
7 years ago
class StackKernel : public framework::OpKernel<T> {
using Tensor = framework::LoDTensor;
void Compute(const framework::ExecutionContext &ctx) const override {
auto x = ctx.MultiInput<Tensor>("X");
auto *y = ctx.Output<Tensor>("Y");
int axis = ctx.Attr<int>("axis");
// A negative axis is offset by the input rank + 1.
if (axis < 0) axis += (x[0]->dims().size() + 1);
// n is the number of tensors being stacked.
int n = static_cast<int>(x.size());
auto *y_data = y->mutable_data<T>(ctx.GetPlace());
// Gather the data pointer of every input tensor.
std::vector<const T *> x_datas(n);
for (int i = 0; i < n; i++) x_datas[i] = x[i]->data<T>();
// pre  = product of X[0]'s dims before `axis`;
// post = product of X[0]'s dims from `axis` onward.
int pre = 1, post = 1;
auto &dim = x[0]->dims();
for (auto i = 0; i < axis; ++i) pre *= dim[i];
for (auto i = axis; i < dim.size(); ++i) post *= dim[i];
7 years ago
#ifdef __NVCC__
// Total number of output elements handed to the functor range.
int total_num = pre * n * post;
auto &dev_ctx = ctx.template device_context<DeviceContext>();
// Copy of the host-side pointer table held in a thrust::device_vector.
thrust::device_vector<const T *> device_x_vec(x_datas);
// NOTE(review): the initializer of x_data_arr is missing in this copy.
auto x_data_arr =;
StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post);
// Wait() must be called because device_x_vec may be destructed before
// kernel ends
7 years ago
// NOTE(review): the initializer of this second x_data_arr is also missing, as
// is the directive that separates it from the __NVCC__ branch above.
auto x_data_arr =;
size_t x_offset = 0;
size_t y_offset = 0;
// Copies post-sized runs from x_data_arr[j] + x_offset into y, advancing
// y_offset by post after each copy.
// NOTE(review): the loop closing braces are missing, so the nesting level at
// which x_offset advances cannot be read from this copy - confirm upstream.
for (int i = 0; i < pre; i++) {
for (int j = 0; j < n; j++) {
std::memcpy(y_data + y_offset, x_data_arr[j] + x_offset,
post * sizeof(T));
y_offset += post;
x_offset += post;
7 years ago
7 years ago
// Gradient operator definition for stack. InferShape reads the dims of
// Y@GRAD, removes dimension `axis`, and builds dy_dim[axis] copies of the
// resulting dims.
// NOTE(review): this copy of the class is missing lines - the openers of the
// checks that own the "Input(Y@Grad) must exist." and "Number of
// Outputs(X@Grad) is wrong" messages, the tail of the axis check, the call
// that receives the vector of dims, and the closing braces. Restore them from
// the upstream file before building.
class StackOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
"Input(Y@Grad) must exist.");
int axis = ctx->Attrs().Get<int>("axis");
auto dy_dim = ctx->GetInputDim(framework::GradVarName("Y"));
int rank = dy_dim.size();
// Here rank is the rank of Y@GRAD, so the valid axis range is [-rank, rank).
PADDLE_ENFORCE(axis >= -rank && axis < rank,
"Attr(axis) must be inside [-rank, rank), where rank = %d",
// A negative axis is offset by the rank of Y@GRAD.
if (axis < 0) axis += rank;
"Number of Outputs(X@Grad) is wrong");
auto vec = framework::vectorize2int(dy_dim);
// Drop the stacked dimension to get the dims of a single input gradient.
vec.erase(vec.begin() + axis);
// One DDim per stacked input: dy_dim[axis] copies of the reduced dims.
std::vector<framework::DDim>(dy_dim[axis], framework::make_ddim(vec)));
7 years ago
// Builds the OpDesc of the gradient op for stack: its input Y@GRAD is taken
// from OutputGrad("Y") and its output X@GRAD from InputGrad("X", false).
// NOTE(review): in the lines visible here no op type or attributes are set on
// `op`, and the closing braces are absent - this copy appears to be missing
// lines; confirm against the upstream file.
class StackGradOpDescMaker : public framework::SingleGradOpDescMaker {
7 years ago
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
7 years ago
7 years ago
// Creates a fresh OpDesc, wires the gradient input/output, and returns it.
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X", false));
return op;
7 years ago
// Backward kernel for stack: distributes Y@GRAD across the X@GRAD outputs via
// StackGradFunctorForRange.
// NOTE(review): this copy of the class is missing lines - the initializers of
// both `dx_data_arr` declarations, the preprocessor directives that separate
// and close the __NVCC__ branches, the call announced by the Wait() comment,
// and the closing braces. Restore them from the upstream file before building.
template <typename DeviceContext, typename T>
7 years ago
class StackGradKernel : public framework::OpKernel<T> {
using Tensor = framework::LoDTensor;
void Compute(const framework::ExecutionContext &ctx) const override {
auto *dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto dx = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
int axis = ctx.Attr<int>("axis");
// A negative axis is offset by the rank of Y@GRAD.
if (axis < 0) axis += dy->dims().size();
// n is the size of Y@GRAD along `axis`; one dx entry is filled per index.
int n = dy->dims()[axis];
std::vector<T *> dx_datas(n); // NOLINT
7 years ago
// Allocate every X@GRAD tensor and record its data pointer.
for (int i = 0; i < n; i++) {
7 years ago
dx_datas[i] = dx[i]->mutable_data<T>(ctx.GetPlace());
7 years ago
7 years ago
auto dy_data = dy->data<T>();
// pre = product of Y@GRAD's dims before `axis`.
int pre = 1;
for (int i = 0; i < axis; ++i) pre *= dy->dims()[i];
7 years ago
int total_num = dy->numel();
// post = number of elements left per (pre, n) pair.
int post = total_num / (n * pre);
auto &dev_ctx = ctx.template device_context<DeviceContext>();
#ifdef __NVCC__
// Copy of the host-side pointer table held in a thrust::device_vector.
thrust::device_vector<T *> device_dx_vec(dx_datas);
// NOTE(review): the initializer of dx_data_arr is missing in this copy.
auto dx_data_arr =;
7 years ago
// NOTE(review): the initializer of this second dx_data_arr is also missing, as
// is the directive that separates it from the __NVCC__ branch above.
auto dx_data_arr =;
7 years ago
StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n, post);
7 years ago
#ifdef __NVCC__
// Wait() must be called because device_dx_vec may be destructed before
// kernel ends
7 years ago
7 years ago
} // namespace operators
} // namespace paddle