refine seq_concat

Branch: upload-readme
Author: chengduoZH, 6 years ago
Parent: 437debf40e
Commit: e7940141ce

@@ -441,7 +441,10 @@ static void InitInferShapeFuncs() {
   for (auto &kern_pair : OperatorWithKernel::AllOpKernels()) {
     auto op_type = kern_pair.first;
-    auto &op_info = info_map.at(op_type);
+    auto it = info_map.find(op_type);
+    PADDLE_ENFORCE(it != info_map.end(), "%s has not been registered",
+                   op_type);
+    auto &op_info = it->second;
     auto op = static_cast<OperatorWithKernel *>(op_info.Creator()(
         "", VariableNameMap{}, VariableNameMap{}, AttributeMap{}));
     if (op_info.infer_shape_) {  // infer_shape has been registered.
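Review note: `info_map.at()` would throw a bare `std::out_of_range` that names no key, while the `find()` + `PADDLE_ENFORCE` pattern surfaces the missing op type in the error message. A minimal standalone sketch of the difference (toy names, not Paddle code):

    #include <iostream>
    #include <map>
    #include <stdexcept>
    #include <string>

    int main() {
      std::map<std::string, int> info_map{{"concat", 1}};
      try {
        info_map.at("sequence_concat");  // throws, but names no key
      } catch (const std::out_of_range& e) {
        std::cout << "at(): " << e.what() << "\n";
      }
      auto it = info_map.find("sequence_concat");
      if (it == info_map.end()) {  // mirrors the PADDLE_ENFORCE branch above
        std::cout << "sequence_concat has not been registered\n";
      }
      return 0;
    }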

@@ -95,6 +95,7 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext *ctx) const override {
     ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
+    ctx->ShareLoD("X", framework::GradVarName("X"));
   }
 };
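Review note: the added `ShareLoD` call propagates the sequence (LoD) information from the forward input `X` to its gradient, so each gradient tensor carries the same sequence lengths as the input it corresponds to.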

@@ -109,8 +109,9 @@ class ConcatGradKernel : public framework::OpKernel<T> {
       auto& dev_ctx = ctx.template device_context<DeviceContext>();
       paddle::operators::math::ConcatGradFunctor<DeviceContext, T>
           concat_grad_functor;
-      concat_grad_functor(dev_ctx, *out_grad, ins, static_cast<int>(axis),
-                          &outputs);
+      concat_grad_functor(dev_ctx, *out_grad,
+                          ctx.MultiInput<framework::Tensor>("X"),
+                          static_cast<int>(axis), &outputs);
     }
   }
 };
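Review note: the grad functor now receives the forward inputs straight from the execution context via `ctx.MultiInput<framework::Tensor>("X")` instead of the locally assembled `ins` vector; this pairs with the functor signature change from `LoDTensor*` to plain `Tensor*` in the math/concat hunks below.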

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #pragma once
+#include <vector>
 #include "paddle/fluid/platform/enforce.h"

 namespace paddle {
@@ -24,10 +24,22 @@ namespace detail {
  * and passed by `args`
  */
 template <typename T, typename... ARGS>
-inline T &Ref(T *ptr, ARGS &&... args) {
+inline T& Ref(T* ptr, ARGS&&... args) {
   PADDLE_ENFORCE(ptr != nullptr, args...);
   return *ptr;
 }
+
+template <typename T, typename... ARGS>
+inline std::vector<std::reference_wrapper<T>> VectorRef(
+    const std::vector<T*>& vec, ARGS&&... args) {
+  std::vector<std::reference_wrapper<T>> result;
+  result.reserve(vec.size());
+  for (auto* ptr : vec) {
+    result.emplace_back(Ref(ptr, args...));
+  }
+  return result;
+}
 }  // namespace detail
 }  // namespace operators
 }  // namespace paddle
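Review note: `VectorRef` turns a vector of raw pointers into a `std::vector<std::reference_wrapper<T>>`, reusing `Ref` for a per-element null check. `std::reference_wrapper` lives in `<functional>`, which this hunk does not add, so it presumably arrives transitively via `enforce.h`. A standalone sketch of the same idea (toy types, not Paddle code):

    #include <functional>
    #include <iostream>
    #include <stdexcept>
    #include <vector>

    int main() {
      int a = 1, b = 2;
      std::vector<int*> ptrs{&a, &b};
      std::vector<std::reference_wrapper<int>> refs;
      refs.reserve(ptrs.size());
      for (auto* p : ptrs) {
        if (p == nullptr) throw std::runtime_error("null input");  // Ref()'s check
        refs.emplace_back(*p);
      }
      for (int& r : refs) std::cout << r << "\n";  // prints 1, then 2
      return 0;
    }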

@@ -27,7 +27,7 @@ template <typename T>
 class ConcatFunctor<platform::CPUDeviceContext, T> {
  public:
   void operator()(const platform::CPUDeviceContext& context,
-                  const std::vector<framework::Tensor>& input, const int axis,
+                  const std::vector<framework::Tensor>& input, int axis,
                   framework::Tensor* output) {
     // TODO(zcd): Add input data validity checking
     int num = input.size();
@@ -71,7 +71,7 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> {
  public:
   void operator()(const platform::CPUDeviceContext& context,
                   const framework::Tensor& input,
-                  const std::vector<const framework::LoDTensor*>& ref_inputs,
+                  const std::vector<const framework::Tensor*>& ref_inputs,
                   const int axis, std::vector<framework::Tensor*>* outputs) {
     // TODO(zcd): Add input data validity checking
     size_t num = outputs->size();
@@ -109,16 +109,11 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> {
     }
   }
 };
-template class ConcatFunctor<platform::CPUDeviceContext, int>;
-template class ConcatFunctor<platform::CPUDeviceContext, int64_t>;
-template class ConcatFunctor<platform::CPUDeviceContext, float>;
-template class ConcatFunctor<platform::CPUDeviceContext, double>;
-template class ConcatGradFunctor<platform::CPUDeviceContext, int>;
-template class ConcatGradFunctor<platform::CPUDeviceContext, int64_t>;
-template class ConcatGradFunctor<platform::CPUDeviceContext, float>;
-template class ConcatGradFunctor<platform::CPUDeviceContext, double>;
+#define DEFINE_FUNCTOR(type)                                          \
+  template class ConcatFunctor<platform::CPUDeviceContext, type>;     \
+  template class ConcatGradFunctor<platform::CPUDeviceContext, type>;
+
+FOR_ALL_TYPES(DEFINE_FUNCTOR);

 }  // namespace math
 }  // namespace operators
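Review note: `FOR_ALL_TYPES` is defined in concat.h (see the header hunk below); expanding it with `DEFINE_FUNCTOR` instantiates both functors for every listed element type, replacing the eight hand-written instantiations and adding bool, int8_t, int16_t, uint8_t, and float16 in one place.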

@@ -17,6 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/mixed_vector.h"
 #include "paddle/fluid/operators/math/concat.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/platform/float16.h"

 namespace paddle {
 namespace operators {
@@ -118,7 +119,7 @@ template <typename T>
 class ConcatFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
-                  const std::vector<framework::Tensor>& input, const int axis,
+                  const std::vector<framework::Tensor>& input, int axis,
                   framework::Tensor* output) {
     // TODO(zcd): Add input data validity checking
     int in_num = input.size();
@@ -192,8 +193,8 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
                   const framework::Tensor& input,
-                  const std::vector<const framework::LoDTensor*>& ref_inputs,
-                  const int axis, std::vector<framework::Tensor*>* outputs) {
+                  const std::vector<const framework::Tensor*>& ref_inputs,
+                  int axis, std::vector<framework::Tensor*>* outputs) {
     // TODO(zcd): Add input data validity checking
     int o_num = outputs->size();
     int out_row = 1;
@@ -261,15 +262,11 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
   }
 };
-template class ConcatFunctor<platform::CUDADeviceContext, int>;
-template class ConcatFunctor<platform::CUDADeviceContext, int64_t>;
-template class ConcatFunctor<platform::CUDADeviceContext, float>;
-template class ConcatFunctor<platform::CUDADeviceContext, double>;
-template class ConcatGradFunctor<platform::CUDADeviceContext, int>;
-template class ConcatGradFunctor<platform::CUDADeviceContext, int64_t>;
-template class ConcatGradFunctor<platform::CUDADeviceContext, float>;
-template class ConcatGradFunctor<platform::CUDADeviceContext, double>;
+#define DEFINE_FUNCTOR(type)                                        \
+  template class ConcatFunctor<platform::CUDADeviceContext, type>;  \
+  template class ConcatGradFunctor<platform::CUDADeviceContext, type>
+
+FOR_ALL_TYPES(DEFINE_FUNCTOR);

 }  // namespace math
 }  // namespace operators

@@ -37,7 +37,7 @@ template <typename DeviceContext, typename T>
 class ConcatFunctor {
  public:
   void operator()(const DeviceContext& context,
-                  const std::vector<framework::Tensor>& input, const int axis,
+                  const std::vector<framework::Tensor>& input, int axis,
                   framework::Tensor* output);
 };
@@ -57,10 +57,21 @@ template <typename DeviceContext, typename T>
 class ConcatGradFunctor {
  public:
   void operator()(const DeviceContext& context, const framework::Tensor& input,
-                  const std::vector<const framework::LoDTensor*>& ref_inputs,
-                  const int axis, std::vector<framework::Tensor*>* outputs);
+                  const std::vector<const framework::Tensor*>& ref_inputs,
+                  int axis, std::vector<framework::Tensor*>* outputs);
 };

 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
+
+#define FOR_ALL_TYPES(macro) \
+  macro(int);                \
+  macro(float);              \
+  macro(double);             \
+  macro(bool);               \
+  macro(int64_t);            \
+  macro(int16_t);            \
+  macro(uint8_t);            \
+  macro(int8_t);             \
+  macro(::paddle::platform::float16)
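Review note: this is the classic X-macro pattern: one central type list drives every per-type expansion, so adding a type later means touching only `FOR_ALL_TYPES`. A self-contained toy illustration (demo names, not Paddle code):

    #include <iostream>

    #define FOR_ALL_DEMO_TYPES(macro) \
      macro(int);                     \
      macro(float);                   \
      macro(double)

    // Each expansion prints the size of one type from the list.
    #define PRINT_SIZE(type) std::cout << #type << ": " << sizeof(type) << "\n"

    int main() {
      FOR_ALL_DEMO_TYPES(PRINT_SIZE);  // expands to three statements
      return 0;
    }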

@@ -1,136 +1,100 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/sequence_concat_op.h"
-
-namespace paddle {
-namespace operators {
-
-class SequenceConcatOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInputs("X"),
-                   "Inputs(X) of SequenceConcatOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   "Output(Out) of SequenceConcatOp should not be null.");
-    const size_t level = static_cast<size_t>(ctx->Attrs().Get<int>("level"));
-    const size_t axis = static_cast<size_t>(ctx->Attrs().Get<int>("axis"));
-    PADDLE_ENFORCE(level == 0UL || level == 1UL,
-                   "The sequence_concat operator only accepts sequence "
-                   "or a nested sequence as its input.");
-    auto ins_dims = ctx->GetInputsDim("X");
-    framework::DDim out_dims = ins_dims[0];
-    const size_t n = ins_dims.size();
-    for (size_t i = 1; i < n; ++i) {
-      out_dims[axis] += ins_dims[i][axis];
-    }
-    ctx->SetOutputDim("Out", out_dims);
-  }
-};
-
-class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddInput("X",
-             "(LodTensorArray) Input is a vector of LoDTensor, "
-             "each of which is a variable-length sequence or nested sequence.")
-        .AsDuplicable();
-    AddOutput("Out",
-              "(LoDTensor), Variable-length output of "
-              "sequence_concat Op.");
-    AddAttr<int>("axis",
-                 "(int, default 0) "
-                 "The axis along which the inputs will be joined. "
-                 "If axis is 0, the inputs will be joined with LoD index.")
-        .SetDefault(0);
-    AddAttr<int>("level",
-                 "(int, default 0) "
-                 "The level at which the inputs will be joined. "
-                 "If the level is 0, the inputs will be joined at the nested "
-                 "sequence level. "
-                 "If the level is 1, the inputs will be joined at the "
-                 "sequence level. "
-                 "The level should be less than the level number of inputs.")
-        .SetDefault(0);
-    AddComment(R"DOC(
-The sequence_concat operator concatenates multiple LoDTensors.
-It only supports sequence (LoD Tensor with level number is 1)
-or a nested sequence (LoD tensor with level number is 2) as its input.
-- Case1:
-  If the axis is other than 0(here, axis is 1 and level is 1),
-  each input should have the same LoD information and the LoD
-  information of the output keeps the same as the input.
-
-  LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4)
-  LoD(x1) = {{0,2,4}, {0,1,2,3,4}}; Dims(x1) = (4,4,4)
-  LoD(Out) = {{0,2,4}, {0,1,2,3,4}}; Dims(Out) = (4,7,4)
-
-- Case2:
-  If the axis is 0(here, leve is 0), the inputs are concatenated along
-  time steps, the LoD information of the output need to re-compute.
-  The LoD information of level-1 should be same.
-
-  LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4)
-  LoD(x1) = {{0,2,4}, {0,1,3,5,7}}; Dims(x1) = (7,3,4)
-  LoD(Out) = {{0,2,4}, {0,2,5,8,11}}; Dims(Out) = (11,3,4)
-
-- Case3:
-  If the axis is 0(here, level is 1).
-
-  LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4)
-  LoD(x1) = {{0,3,4}, {0,1,3,5,7}}; Dims(x1) = (7,3,4)
-  LoD(Out) = {{0,5,8}, {0,1,2,3,5,7,8,9,11}}; Dims(Out) = (11,3,4)
-
-- Case4:
-  If the LoD number is 1, axis is 0, level is 0
-
-  LoD(x0) = {{0,1,2,3,4}}; Dims(x0) = (4,3,4)
-  LoD(x1) = {{0,1,3,5,7}}; Dims(x1) = (7,3,4)
-  LoD(Out) = {{0,2,5,8,11}}; Dims(Out) = (11,3,4)
-
-NOTE: The levels of all the inputs should be the same.
-)DOC");
-  }
-};
-
-class SequenceConcatGradOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
-                   "The gradient of Out should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName("X")),
-                   "The gradient of X should not be null.");
-    ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-REGISTER_OPERATOR(sequence_concat, ops::SequenceConcatOp,
-                  ops::SequenceConcatOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<
-                      false> /* set false to disable empty grad */);
-REGISTER_OPERATOR(sequence_concat_grad, ops::SequenceConcatGradOp);
-REGISTER_OP_CPU_KERNEL(
-    sequence_concat,
-    ops::SequenceConcatOpKernel<paddle::platform::CPUDeviceContext, float>);
-REGISTER_OP_CPU_KERNEL(
-    sequence_concat_grad,
-    ops::SequenceConcatGradOpKernel<paddle::platform::CPUDeviceContext, float>);
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/sequence_concat_op.h"
+#include <vector>
+
+namespace paddle {
+namespace operators {
+
+class SeqConcatOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "The inputs of sequence concat op").AsDuplicable();
+    AddOutput("Out", "The output of sequence concat op");
+    AddComment(
+        "Sequence Concat Op\n"
+        "It will concat LoD tensors by its sequence information.\n"
+        "For example:\n"
+        "  LoD of X1 = [0, 3, 7]\n"
+        "  LoD of X2 = [0, 7, 9]\n"
+        "  Result LoD is [0, (3+7), (7+9)]\n"
+        "  i.e.[0, 10, 16]\n");
+  }
+};
+
+class SeqConcatShapeInferer : public framework::InferShapeBase {
+ public:
+  void operator()(framework::InferShapeContext *context) const override {
+    try {
+      PADDLE_ENFORCE(context->HasInputs("X"));
+      PADDLE_ENFORCE(context->HasOutput("Out"));
+
+      auto x_dims = context->GetInputsDim("X");
+      int64_t batch_size = 0;
+      int64_t feature_size = 0;
+      std::vector<int64_t> out_dims;
+      for (auto &x_dim : x_dims) {
+        if (out_dims.empty()) {
+          out_dims = framework::vectorize(x_dim);
+        }
+        batch_size += x_dim[0];
+        if (feature_size == 0) {
+          feature_size = framework::product(x_dim) / x_dim[0];
+        } else {
+          PADDLE_ENFORCE_EQ(
+              feature_size, framework::product(x_dim) / x_dim[0],
+              "Inputs of sequence concat must have same feature size");
+        }
+      }
+      if (batch_size < 0) {
+        batch_size = -1;  // Normalize batch size for compile time.
+      }
+      out_dims[0] = batch_size;
+      context->SetOutputDim("Out", framework::make_ddim(out_dims));
+      if (!context->IsRuntime()) {  // Runtime LoD infershape will be computed
+                                    // in Kernel.
+        context->ShareLoD("X", "Out");
+      }
+    } catch (...) {
+      PADDLE_THROW("Unknown error");
+    }
+  }
+};
+
+class SeqConcatGradShapeInferer : public framework::InferShapeBase {
+ public:
+  void operator()(framework::InferShapeContext *context) const override {
+    context->SetOutputsDim(framework::GradVarName("X"),
+                           context->GetInputsDim("X"));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace op = paddle::operators;
+
+REGISTER_OPERATOR(sequence_concat, paddle::framework::OperatorWithKernel,
+                  op::SeqConcatOpMaker, op::SeqConcatShapeInferer,
+                  paddle::framework::DefaultGradOpDescMaker<false>);
+template <typename T>
+using Kernel = op::SeqConcatKernel<paddle::platform::CPUDeviceContext, T>;
+REGISTER_OP_CPU_KERNEL(sequence_concat, Kernel<float>, Kernel<double>);
+REGISTER_OPERATOR(sequence_concat_grad, paddle::framework::OperatorWithKernel,
+                  op::SeqConcatGradShapeInferer);
+template <typename T>
+using GradKernel =
+    op::SeqConcatGradKernel<paddle::platform::CPUDeviceContext, T>;
+REGISTER_OP_CPU_KERNEL(sequence_concat_grad, GradKernel<float>,
+                       GradKernel<double>);
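Review note: at compile time an unknown batch dimension is -1, so the summed `batch_size` can go negative; clamping it back to -1 preserves the conventional "unknown" marker. The output LoD is only truly known at runtime (it depends on the inputs' sequence lengths), which is why `ShareLoD` runs solely in the `!IsRuntime()` branch as a compile-time placeholder and the kernel computes the real LoD.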

@@ -1,23 +1,26 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/sequence_concat_op.h"
-
-namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(
-    sequence_concat,
-    ops::SequenceConcatOpKernel<paddle::platform::CUDADeviceContext, float>);
-REGISTER_OP_CUDA_KERNEL(sequence_concat_grad,
-                        ops::SequenceConcatGradOpKernel<
-                            paddle::platform::CUDADeviceContext, float>);
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/sequence_concat_op.h"
+
+template <typename T>
+using Kernel =
+    paddle::operators::SeqConcatKernel<paddle::platform::CUDADeviceContext, T>;
+REGISTER_OP_CUDA_KERNEL(sequence_concat, Kernel<float>, Kernel<double>);
+template <typename T>
+using GradKernel =
+    paddle::operators::SeqConcatGradKernel<paddle::platform::CUDADeviceContext,
+                                           T>;
+REGISTER_OP_CUDA_KERNEL(sequence_concat_grad, GradKernel<float>,
+                        GradKernel<double>);

[One file's diff was suppressed because it is too large.]

@@ -0,0 +1,45 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class TestSequenceConcat(OpTest):
+    def setUp(self):
+        x1 = np.random.random(size=(10, 80))
+        lod1 = [7, 3]
+        x2 = np.random.random(size=(20, 80))
+        lod2 = [12, 8]
+
+        out = np.concatenate((x1[0:lod1[0]], x2[0:lod2[0]], x1[lod1[0]:],
+                              x2[lod2[0]:]))
+        out_lod = [19, 11]
+
+        self.op_type = "sequence_concat"
+        self.inputs = {'X': [("x1", (x1, [lod1])), ("x2", (x2, [lod2]))]}
+        self.outputs = {"Out": (out, [out_lod])}
+
+    def test_output(self):
+        self.check_output(1e-3)
+
+    def test_dx(self):
+        self.check_grad(inputs_to_check=['x1', 'x2'], output_names="Out")
+
+
+if __name__ == '__main__':
+    unittest.main()
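Review note on the expected output: sequence_concat interleaves sequences, not raw rows. With lod1 = [7, 3] and lod2 = [12, 8], the output is x1's first sequence (7 rows), then x2's first (12 rows), then x1's second (3 rows), then x2's second (8 rows), so out_lod = [7 + 12, 3 + 8] = [19, 11].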