add partial_concat op in contrib (#22528)
* add partial_concat, test=develop * fix the grids and blocks, test=develop * fix the Paddle_Enforce, test=develop * fix the doc of op, test=develop * fix the doc, test=develop * fix the doc of the op, test=develop * replace -1 with None, test=developrevert-22710-feature/integrated_ps_api
parent
dab5e5d8bc
commit
e136661304
@ -0,0 +1,210 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/partial_concat_op.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
using Tensor = framework::Tensor;
|
||||
|
||||
class PartialConcatOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext *ctx) const override {
|
||||
PADDLE_ENFORCE_GE(
|
||||
ctx->Inputs("X").size(), 1UL,
|
||||
platform::errors::InvalidArgument(
|
||||
"Inputs(X) of Partial ConcatOp should not be empty."));
|
||||
|
||||
PADDLE_ENFORCE_EQ(
|
||||
ctx->HasOutput("Out"), true,
|
||||
platform::errors::InvalidArgument(
|
||||
"Output(Out) of Partial ConcatOp should not be null."));
|
||||
|
||||
auto inputs_dims = ctx->GetInputsDim("X");
|
||||
PADDLE_ENFORCE_EQ(inputs_dims[0].size(), 2,
|
||||
platform::errors::InvalidArgument(
|
||||
"Only supports 2-D array with batch size in the 1st "
|
||||
"dimension and data in the 2nd."));
|
||||
|
||||
const size_t inputs_num = inputs_dims.size();
|
||||
PADDLE_ENFORCE_GT(inputs_num, 0,
|
||||
platform::errors::InvalidArgument(
|
||||
"ShapeError: Input tensors count should > 0. But "
|
||||
"recevied inputs' length is 0."));
|
||||
if (inputs_num == 1) {
|
||||
VLOG(3) << "Warning: concat op have only one input, may waste memory";
|
||||
}
|
||||
|
||||
int64_t batch_size = -1;
|
||||
int64_t input_len = -1;
|
||||
for (size_t i = 0; i < inputs_num; ++i) {
|
||||
PADDLE_ENFORCE_EQ(inputs_dims[i].size(), 2,
|
||||
platform::errors::InvalidArgument(
|
||||
"It only supports two dimensions input now."));
|
||||
if (i == 0) {
|
||||
batch_size = inputs_dims[0][0];
|
||||
input_len = inputs_dims[0][1];
|
||||
} else {
|
||||
PADDLE_ENFORCE_EQ(inputs_dims[i][0], batch_size,
|
||||
platform::errors::InvalidArgument(
|
||||
"The batch size of all inputs must be same"));
|
||||
PADDLE_ENFORCE_EQ(inputs_dims[i][1], input_len,
|
||||
platform::errors::InvalidArgument(
|
||||
"The input length of all inputs must be same"));
|
||||
}
|
||||
}
|
||||
|
||||
int start_index = ComputeStartIndex(
|
||||
static_cast<int64_t>(ctx->Attrs().Get<int>("start_index")),
|
||||
inputs_dims[0][1]);
|
||||
int partial_len = ctx->Attrs().Get<int>("length");
|
||||
if (partial_len < 0) {
|
||||
partial_len = inputs_dims[0][1] - start_index;
|
||||
}
|
||||
|
||||
ctx->SetOutputDim("Out", {inputs_dims[0][0],
|
||||
static_cast<int64_t>(partial_len * inputs_num)});
|
||||
ctx->ShareLoD("X", /*->*/ "Out");
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext &ctx) const override {
|
||||
auto inputs = ctx.MultiInput<Tensor>("X");
|
||||
auto input_data_type = framework::proto::VarType::Type(0);
|
||||
bool flag = 0;
|
||||
for (auto *input : inputs) {
|
||||
if (input->IsInitialized() && input->numel() > 0) {
|
||||
input_data_type = input->type();
|
||||
flag = 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
PADDLE_ENFORCE_EQ(flag, 1, platform::errors::InvalidArgument(
|
||||
"All Inputs of PartialSum OP are Empty!"));
|
||||
return framework::OpKernelType(input_data_type, ctx.GetPlace());
|
||||
}
|
||||
|
||||
framework::OpKernelType GetKernelTypeForVar(
|
||||
const std::string &var_name, const Tensor &tensor,
|
||||
const framework::OpKernelType &expected_kernel_type) const override {
|
||||
return framework::OpKernelType(expected_kernel_type.data_type_,
|
||||
tensor.place(), tensor.layout());
|
||||
}
|
||||
};
|
||||
|
||||
class PartialConcatGradOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext *ctx) const override {
|
||||
auto in_x = "X";
|
||||
auto out_x_g_n = framework::GradVarName(in_x);
|
||||
ctx->SetOutputsDim(out_x_g_n, ctx->GetInputsDim(in_x));
|
||||
|
||||
auto in_names = ctx->Inputs(in_x);
|
||||
auto out_names = ctx->Outputs(out_x_g_n);
|
||||
|
||||
PADDLE_ENFORCE_EQ(
|
||||
in_names.size(), out_names.size(),
|
||||
platform::errors::InvalidArgument(
|
||||
"The number of arguments in %s[%d] and %s[%d] is not equal.", in_x,
|
||||
in_names.size(), out_x_g_n, out_names.size()));
|
||||
for (size_t i = 0; i < in_names.size(); ++i) {
|
||||
if (out_names[i] != framework::kEmptyVarName) {
|
||||
ctx->ShareLoD(in_x, out_x_g_n, i, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext &ctx) const override {
|
||||
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
|
||||
ctx, framework::GradVarName("Out")),
|
||||
ctx.device_context());
|
||||
}
|
||||
};
|
||||
|
||||
class PartialConcatOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
void Make() override {
|
||||
AddInput("X", "Input tensors of concat operator.").AsDuplicable();
|
||||
AddOutput("Out", "Output tensor of concat operator.");
|
||||
AddAttr<int>("start_index",
|
||||
"The start index of each instance for concatenation.")
|
||||
.SetDefault(0);
|
||||
AddAttr<int>("length",
|
||||
"The length of each instance for concatenation."
|
||||
" Negative values for all elements after start_index")
|
||||
.SetDefault(-1);
|
||||
AddComment(R"DOC(
|
||||
Partial Concat Operator.
|
||||
Partial Concatenate the input tensors along the 2nd dimension.
|
||||
Only 2-D Tensor or LodTensor input is supported.
|
||||
Slice and concat can only be performed along the second dimension.
|
||||
Examples:
|
||||
Input[0] = [[1,2],[3,4]]
|
||||
Input[1] = [[5,6],[7,8]]
|
||||
start_index = 1
|
||||
length = 1
|
||||
Output = [[2,6],
|
||||
[4,8]]
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class PartialConcatGradMaker : public framework::SingleGradOpMaker<T> {
|
||||
public:
|
||||
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
|
||||
|
||||
protected:
|
||||
std::unique_ptr<T> Apply() const override {
|
||||
std::unique_ptr<T> op(new T());
|
||||
op->SetType("partial_concat_grad");
|
||||
op->SetInput("X", this->Input("X"));
|
||||
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
|
||||
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X", false));
|
||||
op->SetAttr("start_index", this->GetAttr("start_index"));
|
||||
op->SetAttr("length", this->GetAttr("length"));
|
||||
return op;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OPERATOR(partial_concat, ops::PartialConcatOp,
|
||||
ops::PartialConcatOpMaker,
|
||||
ops::PartialConcatGradMaker<paddle::framework::OpDesc>,
|
||||
ops::PartialConcatGradMaker<paddle::imperative::OpBase>);
|
||||
|
||||
REGISTER_OPERATOR(partial_concat_grad, ops::PartialConcatGradOp);
|
||||
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
partial_concat,
|
||||
ops::PartialConcatKernel<paddle::platform::CPUDeviceContext, double>,
|
||||
ops::PartialConcatKernel<paddle::platform::CPUDeviceContext, float>,
|
||||
ops::PartialConcatKernel<paddle::platform::CPUDeviceContext, int64_t>,
|
||||
ops::PartialConcatKernel<paddle::platform::CPUDeviceContext, int>);
|
||||
|
||||
REGISTER_OP_CPU_KERNEL(partial_concat_grad,
|
||||
ops::PartialConcatGradientOpKernel<float>,
|
||||
ops::PartialConcatGradientOpKernel<int>,
|
||||
ops::PartialConcatGradientOpKernel<double>,
|
||||
ops::PartialConcatGradientOpKernel<int64_t>);
|
@ -0,0 +1,219 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include <paddle/fluid/platform/device_context.h>
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/memory/malloc.h"
|
||||
#include "paddle/fluid/operators/partial_concat_op.h"
|
||||
#include "paddle/fluid/platform/float16.h"
|
||||
|
||||
namespace plat = paddle::platform;
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
#define CEIL_DIV(x, y) (((x) + (y)-1) / (y))
|
||||
|
||||
using LoDTensor = framework::LoDTensor;
|
||||
using Tensor = framework::Tensor;
|
||||
|
||||
template <class T>
|
||||
__global__ void ConcatPartialCUDAKernel(T **in, T *out, int64_t all_length,
|
||||
int64_t in_batch_len,
|
||||
int64_t start_index,
|
||||
int64_t out_batch_len,
|
||||
int64_t part_length) {
|
||||
int id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
while (id < all_length) {
|
||||
int64_t bs_id = id / out_batch_len;
|
||||
int64_t bs_index = id % out_batch_len;
|
||||
int64_t var_id = bs_index / part_length;
|
||||
int64_t part_index = bs_index % part_length;
|
||||
int64_t in_id = start_index + part_index;
|
||||
const T *tmp = in[var_id];
|
||||
out[id] = tmp[bs_id * in_batch_len + in_id];
|
||||
id += blockDim.x * gridDim.x;
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
__global__ void ConcatPartialGradCUDAKernel(
|
||||
T **in, const T *out, int64_t all_length, int64_t in_batch_len,
|
||||
int64_t start_index, int64_t out_batch_len, int64_t part_length) {
|
||||
int id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
while (id < all_length) {
|
||||
int64_t bs_id = id / out_batch_len;
|
||||
int64_t bs_index = id % out_batch_len;
|
||||
int64_t var_id = bs_index / part_length;
|
||||
int64_t part_index = bs_index % part_length;
|
||||
int64_t in_id = start_index + part_index;
|
||||
T *tmp = in[var_id];
|
||||
tmp[bs_id * in_batch_len + in_id] = out[id];
|
||||
id += blockDim.x * gridDim.x;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class PartialConcatOpCUDAKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext &ctx) const override {
|
||||
auto in_vars = ctx.MultiInput<Tensor>("X");
|
||||
Tensor *out = ctx.Output<Tensor>("Out");
|
||||
PADDLE_ENFORCE_EQ(in_vars[0] != nullptr, true,
|
||||
platform::errors::InvalidArgument(
|
||||
"The input of partial concat should not be null."));
|
||||
|
||||
auto input_dim = in_vars[0]->dims();
|
||||
PADDLE_ENFORCE_EQ(input_dim.size(), 2,
|
||||
platform::errors::InvalidArgument(
|
||||
"Only supports 2-D array with batch size in the 1st "
|
||||
"dimension and data in the 2nd."));
|
||||
auto in_size = input_dim[1];
|
||||
// may be negative
|
||||
auto start_index = ctx.Attr<int>("start_index");
|
||||
start_index = ComputeStartIndex(start_index, in_size);
|
||||
|
||||
auto partial_len = ctx.Attr<int>("length");
|
||||
if (partial_len < 0) {
|
||||
partial_len = in_size - start_index;
|
||||
}
|
||||
|
||||
int in_num = in_vars.size();
|
||||
int batch_size = input_dim[0];
|
||||
int out_batch_len = partial_len * in_num;
|
||||
int all_length = batch_size * out_batch_len;
|
||||
|
||||
constexpr size_t theory_sm_threads = 1024;
|
||||
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
|
||||
auto stream = dev_ctx.stream();
|
||||
auto max_threads = dev_ctx.GetMaxPhysicalThreadCount();
|
||||
auto sm_count = max_threads / theory_sm_threads;
|
||||
size_t tile_size = 0;
|
||||
int grids;
|
||||
int blocks;
|
||||
auto ComputeKernelParameter = [&](size_t length) {
|
||||
if (length >= max_threads)
|
||||
tile_size = 1024;
|
||||
else if (length < max_threads && length > sm_count * 128)
|
||||
tile_size = 512;
|
||||
else if (length <= sm_count * 128)
|
||||
tile_size = 256;
|
||||
grids = CEIL_DIV(length, tile_size);
|
||||
blocks = tile_size;
|
||||
};
|
||||
|
||||
auto place = ctx.GetPlace();
|
||||
T *out_data = out->mutable_data<T>(place);
|
||||
|
||||
std::vector<const T *> in_data;
|
||||
for (int i = 0; i < in_num; ++i)
|
||||
in_data.emplace_back(in_vars[i]->data<T>());
|
||||
|
||||
auto tmp_in_array = memory::Alloc(dev_ctx, in_data.size() * sizeof(T *));
|
||||
memory::Copy(boost::get<platform::CUDAPlace>(dev_ctx.GetPlace()),
|
||||
tmp_in_array->ptr(), platform::CPUPlace(),
|
||||
reinterpret_cast<void *>(in_data.data()),
|
||||
in_data.size() * sizeof(T *), dev_ctx.stream());
|
||||
|
||||
T **in_array_data = reinterpret_cast<T **>(tmp_in_array->ptr());
|
||||
ComputeKernelParameter(all_length);
|
||||
ConcatPartialCUDAKernel<T><<<grids, blocks, 0, stream>>>(
|
||||
in_array_data, out->data<T>(), all_length, in_size, start_index,
|
||||
out_batch_len, partial_len);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class PartialConcatGradOpCUDAKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext &ctx) const override {
|
||||
auto *out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
||||
auto ins = ctx.MultiInput<LoDTensor>("X");
|
||||
auto outs = ctx.MultiOutput<LoDTensor>(framework::GradVarName("X"));
|
||||
|
||||
PADDLE_ENFORCE_EQ(ins[0] != nullptr, true,
|
||||
platform::errors::InvalidArgument(
|
||||
"The input of partial concat should not be null."));
|
||||
// all parameters
|
||||
auto batch_size = ins[0]->dims()[0];
|
||||
auto in_size = ins[0]->dims()[1];
|
||||
// may be negative
|
||||
auto start_index = ctx.Attr<int>("start_index");
|
||||
start_index = ComputeStartIndex(start_index, in_size);
|
||||
auto partial_len = ctx.Attr<int>("length");
|
||||
if (partial_len < 0) partial_len = in_size - start_index;
|
||||
|
||||
auto in_num = ins.size();
|
||||
auto grad_batch_len = partial_len * in_num;
|
||||
auto all_length = grad_batch_len * batch_size;
|
||||
// initialize
|
||||
auto &place = *ctx.template device_context<platform::CUDADeviceContext>()
|
||||
.eigen_device();
|
||||
for (size_t i = 0; i < outs.size(); ++i) {
|
||||
outs[i]->mutable_data<T>(ctx.GetPlace());
|
||||
auto dxt = framework::EigenVector<T>::Flatten(*outs[i]);
|
||||
dxt.device(place) = dxt.constant(static_cast<T>(0));
|
||||
}
|
||||
|
||||
constexpr size_t theory_sm_threads = 1024;
|
||||
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
|
||||
auto stream = dev_ctx.stream();
|
||||
auto max_threads = dev_ctx.GetMaxPhysicalThreadCount();
|
||||
auto sm_count = max_threads / theory_sm_threads;
|
||||
size_t tile_size = 0;
|
||||
int grids;
|
||||
int blocks;
|
||||
auto ComputeKernelParameter = [&](size_t length) {
|
||||
if (length >= max_threads)
|
||||
tile_size = 1024;
|
||||
else if (length < max_threads && length > sm_count * 128)
|
||||
tile_size = 512;
|
||||
else if (length <= sm_count * 128)
|
||||
tile_size = 256;
|
||||
grids = CEIL_DIV(length, tile_size);
|
||||
blocks = tile_size;
|
||||
};
|
||||
|
||||
std::vector<const T *> out_data;
|
||||
for (size_t i = 0; i < in_num; ++i) {
|
||||
out_data.emplace_back(outs[i]->data<T>());
|
||||
}
|
||||
auto tmp_out_array = memory::Alloc(dev_ctx, out_data.size() * sizeof(T *));
|
||||
|
||||
memory::Copy(boost::get<platform::CUDAPlace>(dev_ctx.GetPlace()),
|
||||
tmp_out_array->ptr(), platform::CPUPlace(),
|
||||
reinterpret_cast<void *>(out_data.data()),
|
||||
out_data.size() * sizeof(T *), dev_ctx.stream());
|
||||
|
||||
T **out_grad_data = reinterpret_cast<T **>(tmp_out_array->ptr());
|
||||
ComputeKernelParameter(all_length);
|
||||
ConcatPartialGradCUDAKernel<T><<<grids, blocks, 0, stream>>>(
|
||||
out_grad_data, out_grad->data<T>(), all_length, in_size, start_index,
|
||||
grad_batch_len, partial_len);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_CUDA_KERNEL(partial_concat, ops::PartialConcatOpCUDAKernel<float>,
|
||||
ops::PartialConcatOpCUDAKernel<double>,
|
||||
ops::PartialConcatOpCUDAKernel<int>,
|
||||
ops::PartialConcatOpCUDAKernel<int64_t>,
|
||||
ops::PartialConcatOpCUDAKernel<plat::float16>);
|
||||
|
||||
REGISTER_OP_CUDA_KERNEL(partial_concat_grad,
|
||||
ops::PartialConcatGradOpCUDAKernel<float>,
|
||||
ops::PartialConcatGradOpCUDAKernel<double>,
|
||||
ops::PartialConcatGradOpCUDAKernel<int>,
|
||||
ops::PartialConcatGradOpCUDAKernel<int64_t>,
|
||||
ops::PartialConcatGradOpCUDAKernel<plat::float16>);
|
@ -0,0 +1,127 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/math/concat_and_split.h"
|
||||
#include "paddle/fluid/operators/strided_memcpy.h"
|
||||
#include "paddle/fluid/operators/utils.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
using Tensor = framework::Tensor;
|
||||
|
||||
static inline int64_t ComputeStartIndex(int64_t start_index, int64_t size) {
|
||||
PADDLE_ENFORCE_EQ(
|
||||
start_index >= -size && start_index < size, true,
|
||||
platform::errors::InvalidArgument(
|
||||
"The start_index is expected to be in range of [%d, %d), but got %d",
|
||||
-size, size, start_index));
|
||||
if (start_index < 0) {
|
||||
start_index += size;
|
||||
}
|
||||
return start_index;
|
||||
}
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class PartialConcatKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
auto ins = ctx.MultiInput<framework::Tensor>("X");
|
||||
framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
|
||||
PADDLE_ENFORCE_EQ(ins[0] != nullptr, true,
|
||||
platform::errors::InvalidArgument(
|
||||
"The input of partial concat should not be null."));
|
||||
|
||||
auto input_dim = ins[0]->dims();
|
||||
PADDLE_ENFORCE_EQ(input_dim.size(), 2,
|
||||
platform::errors::InvalidArgument(
|
||||
"Only supports 2-D array with batch size in the 1st "
|
||||
"dimension and data in the 2nd."));
|
||||
auto in_size = input_dim[1];
|
||||
|
||||
// may be negative
|
||||
auto start_index = ctx.Attr<int>("start_index");
|
||||
start_index = ComputeStartIndex(start_index, in_size);
|
||||
|
||||
auto partial_len = ctx.Attr<int>("length");
|
||||
if (partial_len < 0) {
|
||||
partial_len = in_size - start_index;
|
||||
}
|
||||
|
||||
int batch = input_dim[0];
|
||||
int out_size = partial_len * ins.size();
|
||||
out->Resize({batch, out_size});
|
||||
auto place = ctx.GetPlace();
|
||||
T* out_data = out->mutable_data<T>(place);
|
||||
|
||||
for (size_t i = 0; i < ins.size(); ++i) {
|
||||
for (int j = 0; j < batch; ++j) {
|
||||
const T* in_data = ins[i]->data<T>();
|
||||
memcpy(out_data + out_size * j + partial_len * i,
|
||||
in_data + in_size * j + start_index, partial_len * sizeof(T));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class PartialConcatGradientOpKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
||||
auto ins = ctx.MultiInput<framework::LoDTensor>("X");
|
||||
auto outs =
|
||||
ctx.MultiOutput<framework::LoDTensor>(framework::GradVarName("X"));
|
||||
|
||||
PADDLE_ENFORCE_EQ(ins[0] != nullptr, true,
|
||||
platform::errors::InvalidArgument(
|
||||
"The input of partial concat should not be null."));
|
||||
// all parameters
|
||||
auto batch_size = ins[0]->dims()[0];
|
||||
auto in_size = ins[0]->dims()[1];
|
||||
// may be negative
|
||||
auto start_index = ctx.Attr<int>("start_index");
|
||||
start_index = ComputeStartIndex(start_index, in_size);
|
||||
auto partial_len = ctx.Attr<int>("length");
|
||||
if (partial_len < 0) partial_len = in_size - start_index;
|
||||
|
||||
auto in_num = ins.size();
|
||||
auto grad_batch_len = partial_len * in_num;
|
||||
auto all_length = grad_batch_len * batch_size;
|
||||
|
||||
// initialize
|
||||
auto& place = *ctx.template device_context<platform::CPUDeviceContext>()
|
||||
.eigen_device();
|
||||
for (size_t i = 0; i < outs.size(); ++i) {
|
||||
outs[i]->mutable_data<T>(ctx.GetPlace());
|
||||
auto dxt = framework::EigenVector<T>::Flatten(*outs[i]);
|
||||
dxt.device(place) = dxt.constant(static_cast<T>(0));
|
||||
}
|
||||
|
||||
auto* out_grad_t = out_grad->data<T>();
|
||||
for (size_t id = 0; id < all_length; id += partial_len) {
|
||||
int bs_id = id / grad_batch_len;
|
||||
int bs_index = id % grad_batch_len;
|
||||
int var_id = bs_index / partial_len;
|
||||
auto* out_t = outs[var_id]->data<T>();
|
||||
memcpy(out_t + bs_id * in_size + start_index, out_grad_t + id,
|
||||
partial_len * sizeof(T));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,104 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
import random
|
||||
import six
|
||||
|
||||
|
||||
def np_partial_concat(inputs, start, length):
|
||||
assert (len(inputs[0].shape) == 2)
|
||||
size = inputs[0].shape[1]
|
||||
assert (start >= -size and start < size)
|
||||
|
||||
if start < 0:
|
||||
start += size
|
||||
if length < 0:
|
||||
length = size - start
|
||||
assert (size >= start + length)
|
||||
|
||||
elems = []
|
||||
for elem in inputs:
|
||||
assert (elem.shape == inputs[0].shape)
|
||||
elems.append(elem[:, start:start + length])
|
||||
res = np.concatenate(elems, axis=1)
|
||||
return np.concatenate(elems, axis=1)
|
||||
|
||||
|
||||
class TestPartialConcatOp(OpTest):
|
||||
def setUp(self):
|
||||
self.op_type = "partial_concat"
|
||||
self.init_kernel_type()
|
||||
self.init_para()
|
||||
self.var_names = [
|
||||
'x' + str(num) for num in six.moves.range(self.var_num)
|
||||
]
|
||||
self.vars = [np.random.random((self.batch_size, self.column)).astype(self.dtype)\
|
||||
for num in six.moves.range(self.var_num) ]
|
||||
self.inputs = {'X': list(zip(self.var_names, self.vars))}
|
||||
self.attrs = {'start_index': self.start_index, 'length': self.length}
|
||||
y = np_partial_concat(self.vars[:], self.start_index, self.length)
|
||||
self.outputs = {'Out': y}
|
||||
|
||||
def init_kernel_type(self):
|
||||
self.dtype = np.float64
|
||||
|
||||
def init_para(self):
|
||||
self.batch_size = random.randint(10, 20)
|
||||
self.column = random.randint(101, 200)
|
||||
self.start_index = random.randint(0, self.column - 1)
|
||||
self.length = -1
|
||||
self.var_num = random.randint(1, 3)
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
def test_check_grad(self):
|
||||
for var_name in self.var_names:
|
||||
self.check_grad([var_name], 'Out')
|
||||
|
||||
|
||||
class TestPartialConcatOp2(TestPartialConcatOp):
|
||||
def init_para(self):
|
||||
self.batch_size = random.randint(1, 10)
|
||||
self.column = random.randint(101, 200)
|
||||
self.start_index = -5
|
||||
self.length = -1
|
||||
self.var_num = 3
|
||||
|
||||
|
||||
class TestPartialConcatOp3(TestPartialConcatOp):
|
||||
def init_para(self):
|
||||
self.batch_size = random.randint(1, 10)
|
||||
self.column = random.randint(101, 200)
|
||||
self.start_index = 10
|
||||
self.length = 20
|
||||
self.var_num = 2
|
||||
|
||||
|
||||
class TestPartialConcatOp4(TestPartialConcatOp):
|
||||
def init_para(self):
|
||||
self.batch_size = random.randint(1, 10)
|
||||
self.column = random.randint(101, 200)
|
||||
self.start_index = -1
|
||||
self.length = -1
|
||||
self.var_num = 1
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue