|
|
@ -23,13 +23,22 @@ limitations under the License. */
|
|
|
|
namespace paddle {
|
|
|
|
namespace paddle {
|
|
|
|
namespace operators {
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
static inline int64_t ComputeAxis(int64_t axis, int64_t rank) {
|
|
|
|
|
|
|
|
if (axis < 0) {
|
|
|
|
|
|
|
|
axis = axis + rank;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return axis > 0 ? axis : 0;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
class ConcatKernel : public framework::OpKernel<T> {
|
|
|
|
class ConcatKernel : public framework::OpKernel<T> {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
auto ins = ctx.MultiInput<framework::Tensor>("X");
|
|
|
|
auto ins = ctx.MultiInput<framework::Tensor>("X");
|
|
|
|
framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
|
|
|
|
framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
|
|
|
|
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
|
|
|
|
PADDLE_ENFORCE(ins[0], "The input should not be null.");
|
|
|
|
|
|
|
|
auto axis = ComputeAxis(static_cast<int64_t>(ctx.Attr<int>("axis")),
|
|
|
|
|
|
|
|
static_cast<int64_t>(ins[0]->dims().size()));
|
|
|
|
auto place = ctx.GetPlace();
|
|
|
|
auto place = ctx.GetPlace();
|
|
|
|
out->mutable_data<T>(place);
|
|
|
|
out->mutable_data<T>(place);
|
|
|
|
|
|
|
|
|
|
|
@ -83,8 +92,9 @@ class ConcatGradKernel : public framework::OpKernel<T> {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(ins[0], "The input should not be null.");
|
|
|
|
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
|
|
|
|
auto axis = ComputeAxis(static_cast<int64_t>(ctx.Attr<int>("axis")),
|
|
|
|
|
|
|
|
static_cast<int64_t>(ins[0]->dims().size()));
|
|
|
|
|
|
|
|
|
|
|
|
// get output tensor that the name is not kEmptyVarName
|
|
|
|
// get output tensor that the name is not kEmptyVarName
|
|
|
|
std::vector<framework::Tensor*> outputs;
|
|
|
|
std::vector<framework::Tensor*> outputs;
|
|
|
|