add unbind op (#23359)
* add unbind op unbind(tensor, dim=0): 说明:移除指定维后,返回一组数组,包含了沿着指定维切片后的各个切片。 tensor(Tensor) -- 输入Tensor dim(int) -- 删除的维度 示例: Input = [[1,2], [3,4], [5,6]] axis = 0 Output[0] = [1,2] Output[1] = [3,4] Output[2] = [5,6]revert-23830-2.0-beta
parent
fd9b7bdb3d
commit
eb035f24d1
@ -0,0 +1,88 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/unbind_op.h"
|
||||
#include <string>
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
using framework::Tensor;
|
||||
|
||||
class UnbindOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext *ctx) const override {
|
||||
PADDLE_ENFORCE_EQ(
|
||||
ctx->HasInput("X"), true,
|
||||
platform::errors::NotFound("Input(X) of UnbindOp is not found."));
|
||||
PADDLE_ENFORCE_GE(
|
||||
ctx->Outputs("Out").size(), 1UL,
|
||||
platform::errors::NotFound("Outputs(Out) of UnbindOp is not found."));
|
||||
auto in_dims = ctx->GetInputDim("X");
|
||||
auto outs_names = ctx->Outputs("Out");
|
||||
int axis = ctx->Attrs().Get<int>("axis");
|
||||
const size_t outs_number = outs_names.size();
|
||||
auto out_dims = UnbindOutsDims(in_dims, axis);
|
||||
std::vector<framework::DDim> outs_dims(outs_number, out_dims);
|
||||
ctx->SetOutputsDim("Out", outs_dims);
|
||||
for (size_t i = 0; i < outs_number; ++i) {
|
||||
ctx->ShareLoD("X", "Out", 0, i);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class UnbindOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
void Make() override {
|
||||
AddInput("X", "(Tensor) Input tensor of the split operator.");
|
||||
AddOutput("Out", "(Tensor) Output tensors of the unbind operator.")
|
||||
.AsDuplicable();
|
||||
AddComment(R"DOC(
|
||||
Unbind operator
|
||||
|
||||
Remove a tensor dimension.
|
||||
|
||||
Example:
|
||||
Input = [[1,2],
|
||||
[3,4],
|
||||
[5,6]]
|
||||
axis = 0
|
||||
Output[0] = [1,2]
|
||||
Output[1] = [3,4]
|
||||
Output[2] = [5,6]
|
||||
|
||||
)DOC");
|
||||
AddAttr<int>("axis",
|
||||
"(int, default 0) "
|
||||
"dimension to remove.")
|
||||
.SetDefault(0);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
|
||||
REGISTER_OPERATOR(unbind, ops::UnbindOp, ops::UnbindOpMaker,
|
||||
ops::UnbindGradMaker<paddle::framework::OpDesc>,
|
||||
ops::UnbindGradMaker<paddle::imperative::OpBase>);
|
||||
namespace plat = paddle::platform;
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
unbind, ops::UnbindOpKernel<plat::CPUDeviceContext, double>,
|
||||
ops::UnbindOpKernel<plat::CPUDeviceContext, float>,
|
||||
ops::UnbindOpKernel<plat::CPUDeviceContext, int64_t>,
|
||||
ops::UnbindOpKernel<plat::CPUDeviceContext, int>,
|
||||
ops::UnbindOpKernel<plat::CPUDeviceContext, plat::float16>);
|
@ -0,0 +1,23 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/unbind_op.h"
|
||||
namespace ops = paddle::operators;
|
||||
namespace plat = paddle::platform;
|
||||
REGISTER_OP_CUDA_KERNEL(
|
||||
unbind, ops::UnbindOpKernel<plat::CUDADeviceContext, double>,
|
||||
ops::UnbindOpKernel<plat::CUDADeviceContext, float>,
|
||||
ops::UnbindOpKernel<plat::CUDADeviceContext, int64_t>,
|
||||
ops::UnbindOpKernel<plat::CUDADeviceContext, int>,
|
||||
ops::UnbindOpKernel<plat::CUDADeviceContext, plat::float16>);
|
@ -0,0 +1,77 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <chrono> // NOLINT
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/math/concat_and_split.h"
|
||||
#include "paddle/fluid/operators/strided_memcpy.h"
|
||||
#include "paddle/fluid/operators/utils.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
static inline framework::DDim UnbindOutsDims(const framework::DDim in_dims,
|
||||
int axis) {
|
||||
std::vector<int> out_dims;
|
||||
axis = axis < 0 ? in_dims.size() + axis : axis;
|
||||
for (int i = 0; i < in_dims.size(); i++) {
|
||||
if (i != axis) out_dims.push_back(in_dims[i]);
|
||||
}
|
||||
return framework::make_ddim(out_dims);
|
||||
}
|
||||
template <typename DeviceContext, typename T>
|
||||
class UnbindOpKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
auto* in = ctx.Input<framework::Tensor>("X");
|
||||
auto outs = ctx.MultiOutput<framework::Tensor>("Out");
|
||||
int axis = ctx.Attr<int>("axis");
|
||||
|
||||
auto in_dims = in->dims();
|
||||
|
||||
auto place = ctx.GetPlace();
|
||||
|
||||
axis = axis < 0 ? in_dims.size() + axis : axis;
|
||||
std::vector<const framework::Tensor*> shape_refer;
|
||||
for (size_t j = 0; j < outs.size(); ++j) {
|
||||
outs[j]->mutable_data<T>(ctx.GetPlace());
|
||||
shape_refer.emplace_back(outs[j]);
|
||||
}
|
||||
|
||||
auto& dev_ctx = ctx.template device_context<DeviceContext>();
|
||||
math::SplitFunctor<DeviceContext, T> functor;
|
||||
functor(dev_ctx, *in, shape_refer, axis, &outs);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class UnbindGradMaker : public framework::SingleGradOpMaker<T> {
|
||||
public:
|
||||
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
|
||||
|
||||
protected:
|
||||
void Apply(GradOpPtr<T> op) const override {
|
||||
op->SetType("stack");
|
||||
op->SetInput("X", this->OutputGrad("Out"));
|
||||
op->SetOutput("Y", this->InputGrad("X"));
|
||||
op->SetAttrMap(this->Attrs());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,122 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid import compiler, Program, program_guard, core
|
||||
|
||||
|
||||
class TestUnbindOp(OpTest):
|
||||
def initParameters(self):
|
||||
pass
|
||||
|
||||
def outReshape(self):
|
||||
pass
|
||||
|
||||
def setAxis(self):
|
||||
pass
|
||||
|
||||
def setUp(self):
|
||||
self._set_op_type()
|
||||
self.dtype = self.get_dtype()
|
||||
self.axis = 0
|
||||
self.num = 3
|
||||
self.initParameters()
|
||||
#x = np.random.random((3, 2, 2)).astype(self.dtype)
|
||||
x = np.arange(12).reshape(3, 2, 2).astype(self.dtype)
|
||||
self.out = np.split(x, self.num, self.axis)
|
||||
self.outReshape()
|
||||
self.inputs = {'X': x}
|
||||
self.attrs = {'axis': self.axis}
|
||||
self.setAxis()
|
||||
self.outputs = {'Out': [('out%d' % i, self.out[i]) \
|
||||
for i in range(len(self.out))]}
|
||||
|
||||
def get_dtype(self):
|
||||
return "float64"
|
||||
|
||||
def _set_op_type(self):
|
||||
self.op_type = "unbind"
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
def test_check_grad(self):
|
||||
self.check_grad(['X'], ['out0', 'out1', 'out2'])
|
||||
|
||||
|
||||
class TestUnbindOp1(TestUnbindOp):
|
||||
def initParameters(self):
|
||||
self.axis = 1
|
||||
self.num = 2
|
||||
|
||||
def test_check_grad(self):
|
||||
self.check_grad(['X'], ['out0', 'out1'])
|
||||
|
||||
def outReshape(self):
|
||||
self.out[0] = self.out[0].reshape((3, 2))
|
||||
self.out[1] = self.out[1].reshape((3, 2))
|
||||
|
||||
|
||||
class TestUnbindOp2(TestUnbindOp):
|
||||
def initParameters(self):
|
||||
self.axis = 2
|
||||
self.num = 2
|
||||
|
||||
def test_check_grad(self):
|
||||
self.check_grad(['X'], ['out0', 'out1'])
|
||||
|
||||
def outReshape(self):
|
||||
self.out[0] = self.out[0].reshape((3, 2))
|
||||
self.out[1] = self.out[1].reshape((3, 2))
|
||||
|
||||
|
||||
class TestUnbindOp3(TestUnbindOp):
|
||||
def initParameters(self):
|
||||
self.axis = 2
|
||||
self.num = 2
|
||||
|
||||
def setAxis(self):
|
||||
self.attrs = {'axis': -1}
|
||||
|
||||
def test_check_grad(self):
|
||||
self.check_grad(['X'], ['out0', 'out1'])
|
||||
|
||||
def outReshape(self):
|
||||
self.out[0] = self.out[0].reshape((3, 2))
|
||||
self.out[1] = self.out[1].reshape((3, 2))
|
||||
|
||||
|
||||
class TestUnbindOp4(TestUnbindOp):
|
||||
def initParameters(self):
|
||||
self.axis = 1
|
||||
self.num = 2
|
||||
|
||||
def setAxis(self):
|
||||
self.attrs = {'axis': -2}
|
||||
|
||||
def test_check_grad(self):
|
||||
self.check_grad(['X'], ['out0', 'out1'])
|
||||
|
||||
def outReshape(self):
|
||||
self.out[0] = self.out[0].reshape((3, 2))
|
||||
self.out[1] = self.out[1].reshape((3, 2))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue