commit
30849d1f20
@ -1,2 +1,6 @@
|
||||
include(operators)
|
||||
register_operators()
|
||||
register_operators(EXCLUDES fusion_transpose_flatten_concat_op)
|
||||
if (WITH_GPU)
|
||||
op_library(fusion_transpose_flatten_concat_op)
|
||||
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fusion_transpose_flatten_concat);\n")
|
||||
endif()
|
||||
|
@ -0,0 +1,114 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using framework::Tensor;
|
||||
|
||||
class TransposeFlattenConcatFusionOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext *ctx) const override {
|
||||
PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL,
|
||||
"Inputs(X) of ConcatOp should be empty.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
||||
"Output(Out) of ConcatOp should not be null.");
|
||||
|
||||
auto ins = ctx->GetInputsDim("X");
|
||||
const size_t n = ins.size();
|
||||
PADDLE_ENFORCE_GT(n, 0, "Input tensors count should > 0.");
|
||||
|
||||
std::vector<int> trans_axis =
|
||||
ctx->Attrs().Get<std::vector<int>>("trans_axis");
|
||||
int flatten_axis = ctx->Attrs().Get<int>("flatten_axis");
|
||||
int concat_axis = ctx->Attrs().Get<int>("concat_axis");
|
||||
|
||||
size_t x_rank = ins[0].size();
|
||||
size_t trans_axis_size = trans_axis.size();
|
||||
PADDLE_ENFORCE_EQ(x_rank, trans_axis_size,
|
||||
"The input tensor's rank(%d) "
|
||||
"should be equal to the permutation axis's size(%d)",
|
||||
x_rank, trans_axis_size);
|
||||
|
||||
auto dims0 =
|
||||
GetFlattenShape(flatten_axis, GetPermuteShape(trans_axis, ins[0]));
|
||||
std::vector<int> out_dims(dims0);
|
||||
for (size_t i = 1; i < n; i++) {
|
||||
auto dimsi =
|
||||
GetFlattenShape(flatten_axis, GetPermuteShape(trans_axis, ins[i]));
|
||||
for (int j = 0; j < static_cast<int>(dims0.size()); j++) {
|
||||
if (j == concat_axis) {
|
||||
out_dims[concat_axis] += dimsi[j];
|
||||
} else {
|
||||
PADDLE_ENFORCE_EQ(out_dims[j], dimsi[j],
|
||||
"After flatting, the %d-th dim should be save "
|
||||
"except the specify axis.",
|
||||
j);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (out_dims[concat_axis] < 0) {
|
||||
out_dims[concat_axis] = -1;
|
||||
}
|
||||
ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
|
||||
}
|
||||
};
|
||||
|
||||
class TransposeFlattenConcatFusionOpMaker
|
||||
: public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
void Make() override {
|
||||
AddInput(
|
||||
"X",
|
||||
"(Tensor) The input tensor, tensors with rank up to 6 are supported.")
|
||||
.AsDuplicable();
|
||||
AddOutput("Out", "(Tensor)The output tensor.");
|
||||
AddAttr<std::vector<int>>(
|
||||
"trans_axis",
|
||||
"(vector<int>) A list of values, and the size of the list should be "
|
||||
"the same with the input tensor rank. This operator permutes the input "
|
||||
"tensor's axes according to the values given.");
|
||||
AddAttr<int>("flatten_axis",
|
||||
"(int)"
|
||||
"Indicate up to which input dimensions (exclusive) should be"
|
||||
"flattened to the outer dimension of the output. The value"
|
||||
"for axis must be in the range [0, R], where R is the rank of"
|
||||
"the input tensor. When axis = 0, the shape of the output"
|
||||
"tensor is (1, (d_0 X d_1 ... d_n), where the shape of the"
|
||||
"input tensor is (d_0, d_1, ... d_n).");
|
||||
AddAttr<int>("concat_axis",
|
||||
"The axis along which the input tensors will be concatenated. "
|
||||
"It should be 0 or 1, since the tensor is 2D after flatting.");
|
||||
AddComment(R"DOC(
|
||||
|
||||
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OPERATOR(fusion_transpose_flatten_concat,
|
||||
ops::TransposeFlattenConcatFusionOp,
|
||||
ops::TransposeFlattenConcatFusionOpMaker,
|
||||
paddle::framework::EmptyGradOpMaker);
|
@ -0,0 +1,115 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.h"
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/platform/cudnn_helper.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename T>
|
||||
using CudnnDataType = platform::CudnnDataType<T>;
|
||||
|
||||
template <typename T>
|
||||
class TransposeFlattenConcatFusionKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
auto ins = ctx.MultiInput<framework::Tensor>("X");
|
||||
auto* out = ctx.Output<framework::Tensor>("Out");
|
||||
out->mutable_data<T>(ctx.GetPlace());
|
||||
auto odims = out->dims();
|
||||
|
||||
std::vector<int> trans_axis = ctx.Attr<std::vector<int>>("trans_axis");
|
||||
int flatten_axis = ctx.Attr<int>("flatten_axis");
|
||||
int concat_axis = ctx.Attr<int>("concat_axis");
|
||||
|
||||
int rank = ins[0]->dims().size();
|
||||
// use at least 4D in cudnnTransformTensor
|
||||
int max_dim = rank < 4 ? 4 : rank;
|
||||
std::vector<int> stride_x(max_dim, 0);
|
||||
std::vector<int> stride_y(max_dim, 0);
|
||||
std::vector<int> dims_y(max_dim, 0);
|
||||
|
||||
cudnnTensorDescriptor_t in_desc;
|
||||
cudnnTensorDescriptor_t out_desc;
|
||||
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&in_desc));
|
||||
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&out_desc));
|
||||
cudnnDataType_t cudnn_dtype = CudnnDataType<T>::type;
|
||||
|
||||
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
|
||||
auto handle = dev_ctx.cudnn_handle();
|
||||
|
||||
T* odata = out->data<T>();
|
||||
for (size_t k = 0; k < ins.size(); ++k) {
|
||||
auto perm_shape = GetPermuteShape(trans_axis, ins[k]->dims());
|
||||
int osize = 1;
|
||||
auto idims = ins[k]->dims();
|
||||
for (int i = 0; i < rank; i++) {
|
||||
stride_x[i] = 1;
|
||||
for (int j = trans_axis[i] + 1; j < rank; j++) {
|
||||
stride_x[i] *= idims[j];
|
||||
}
|
||||
dims_y[i] = perm_shape[i];
|
||||
osize *= perm_shape[i];
|
||||
}
|
||||
stride_y[rank - 1] = 1;
|
||||
for (int i = rank - 2; i >= 0; i--) {
|
||||
if (((i + 1) == flatten_axis) && (concat_axis == 1)) {
|
||||
stride_y[i] = odims[1];
|
||||
} else {
|
||||
stride_y[i] = stride_y[i + 1] * perm_shape[i + 1];
|
||||
}
|
||||
}
|
||||
|
||||
// Since concat is aftern flatten, the output is 2D tensor.
|
||||
// If concat_axis is 0, each input's permutated tensor is continuous.
|
||||
// If concat_axis is 1, the stride of 0-th dim of each input's
|
||||
// permutated tensor is odims()[1].
|
||||
|
||||
for (int i = rank; i < max_dim; i++) {
|
||||
stride_x[i] = 1;
|
||||
stride_y[i] = 1;
|
||||
dims_y[i] = 1;
|
||||
}
|
||||
|
||||
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
|
||||
in_desc, cudnn_dtype, max_dim, dims_y.data(), stride_x.data()));
|
||||
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
|
||||
out_desc, cudnn_dtype, max_dim, dims_y.data(), stride_y.data()));
|
||||
|
||||
CUDNN_ENFORCE(platform::dynload::cudnnTransformTensor(
|
||||
handle, CudnnDataType<T>::kOne(), in_desc,
|
||||
static_cast<const void*>(ins[k]->data<T>()),
|
||||
CudnnDataType<T>::kZero(), out_desc, static_cast<void*>(odata)));
|
||||
if (concat_axis == 0) {
|
||||
odata += osize;
|
||||
} else {
|
||||
auto flat_shape = GetFlattenShape(flatten_axis, perm_shape);
|
||||
odata += flat_shape[1];
|
||||
}
|
||||
}
|
||||
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(in_desc));
|
||||
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(out_desc));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_CUDA_KERNEL(fusion_transpose_flatten_concat,
|
||||
ops::TransposeFlattenConcatFusionKernel<float>,
|
||||
ops::TransposeFlattenConcatFusionKernel<double>);
|
@ -0,0 +1,50 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/ddim.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
inline std::vector<int32_t> GetPermuteShape(const std::vector<int>& axis,
|
||||
const framework::DDim& in_dims) {
|
||||
std::vector<int32_t> out_dims(in_dims.size());
|
||||
for (size_t i = 0; i < axis.size(); i++) {
|
||||
out_dims[i] = in_dims[axis[i]];
|
||||
}
|
||||
return out_dims;
|
||||
}
|
||||
|
||||
inline std::vector<int32_t> GetFlattenShape(const int axis,
|
||||
const std::vector<int>& in_dims) {
|
||||
int64_t outer = 1, inner = 1;
|
||||
for (int i = 0; i < static_cast<int>(in_dims.size()); ++i) {
|
||||
if (i < axis) {
|
||||
outer *= in_dims[i];
|
||||
} else {
|
||||
inner *= in_dims[i];
|
||||
}
|
||||
}
|
||||
std::vector<int32_t> out_shape(2);
|
||||
out_shape[0] = outer;
|
||||
out_shape[1] = inner;
|
||||
return out_shape;
|
||||
}
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,105 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
import paddle.fluid.core as core
|
||||
|
||||
|
||||
class TestFusionTransposeFlattenConcationOp(OpTest):
|
||||
def setUp(self):
|
||||
self.init_test_case()
|
||||
self.op_type = "fusion_transpose_flatten_concat"
|
||||
|
||||
ins = []
|
||||
flats = []
|
||||
for i in range(len(self.shapes)):
|
||||
in_shape = self.shapes[i]
|
||||
a = np.random.random(in_shape).astype("float32")
|
||||
ins.append(("x%d" % i, a))
|
||||
|
||||
b = a.transpose(self.trans_axis)
|
||||
flat_shape = (np.prod(b.shape[:self.flatten_axis]),
|
||||
np.prod(b.shape[self.flatten_axis:]))
|
||||
c = b.reshape(flat_shape)
|
||||
flats.append(c)
|
||||
out = np.concatenate(flats, axis=self.concat_axis)
|
||||
|
||||
self.inputs = {'X': ins}
|
||||
self.attrs = {
|
||||
'trans_axis': list(self.trans_axis),
|
||||
'flatten_axis': self.flatten_axis,
|
||||
'concat_axis': self.concat_axis
|
||||
}
|
||||
self.outputs = {'Out': out}
|
||||
|
||||
def test_check_output(self):
|
||||
if core.is_compiled_with_cuda():
|
||||
place = core.CUDAPlace(0)
|
||||
self.check_output_with_place(place, 1e-6)
|
||||
else:
|
||||
pass
|
||||
|
||||
def init_test_case(self):
|
||||
self.shapes = [(3, 4, 17, 17), (3, 8, 7, 7), (3, 12, 5, 5)]
|
||||
self.trans_axis = (0, 2, 3, 1)
|
||||
self.flatten_axis = 1
|
||||
self.concat_axis = 1
|
||||
|
||||
|
||||
class TestCase1(TestFusionTransposeFlattenConcationOp):
|
||||
def init_test_case(self):
|
||||
self.shapes = [(3, 4, 18, 17), (3, 8, 18, 7), (6, 12, 9, 5)]
|
||||
self.trans_axis = (0, 2, 3, 1)
|
||||
self.flatten_axis = 2
|
||||
self.concat_axis = 1
|
||||
|
||||
|
||||
class TestCase2(TestFusionTransposeFlattenConcationOp):
|
||||
def init_test_case(self):
|
||||
self.shapes = [(3, 8, 20, 17), (3, 8, 19, 17), (3, 8, 40, 17)]
|
||||
self.trans_axis = (0, 2, 3, 1)
|
||||
self.flatten_axis = 2
|
||||
self.concat_axis = 0
|
||||
|
||||
|
||||
class TestCase3(TestFusionTransposeFlattenConcationOp):
|
||||
def init_test_case(self):
|
||||
self.shapes = [(3, 8, 20, 17), (3, 8, 19, 17), (3, 8, 40, 17)]
|
||||
self.trans_axis = (0, 3, 2, 1)
|
||||
self.flatten_axis = 1
|
||||
self.concat_axis = 1
|
||||
|
||||
|
||||
class TestCase4(TestFusionTransposeFlattenConcationOp):
|
||||
def init_test_case(self):
|
||||
self.shapes = [(3, 8, 9, 17), (8, 3, 9, 17), (4, 6, 9, 17)]
|
||||
self.trans_axis = (0, 2, 1, 3)
|
||||
self.flatten_axis = 3
|
||||
self.concat_axis = 1
|
||||
|
||||
|
||||
class TestCase5(TestFusionTransposeFlattenConcationOp):
|
||||
def init_test_case(self):
|
||||
self.shapes = [(3, 8, 9, 17, 2), (3, 8, 2, 17, 9), (3, 17, 9, 8, 2)]
|
||||
self.trans_axis = (0, 2, 1, 4, 3)
|
||||
self.flatten_axis = 1
|
||||
self.concat_axis = 1
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue