|
|
|
@ -49,10 +49,11 @@ static inline framework::DDim ComputeAndCheckShape(
|
|
|
|
|
// check all shape in run time
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
inputs_dims[0][j], inputs_dims[i][j],
|
|
|
|
|
"ShapeError: Dimension %d in inputs' shapes must be equal. "
|
|
|
|
|
"But recevied input[0]'s shape = "
|
|
|
|
|
"[%s], input[%d]'s shape = [%s].",
|
|
|
|
|
j, inputs_dims[0], i, inputs_dims[i]);
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The shape of input[%d] must be equal to input[0]. "
|
|
|
|
|
"But received input[0]'s shape = "
|
|
|
|
|
"[%s], input[%d]'s shape = [%s].",
|
|
|
|
|
i, inputs_dims[0], i, inputs_dims[i]));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -78,7 +79,9 @@ class ConcatKernel : public framework::OpKernel<T> {
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
auto ins = ctx.MultiInput<framework::LoDTensor>("X");
|
|
|
|
|
framework::LoDTensor* out = ctx.Output<framework::LoDTensor>("Out");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ins[0] != nullptr, true, "The input should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
ins[0], platform::errors::NotFound(
|
|
|
|
|
" The first input of concat should not be null."));
|
|
|
|
|
auto axis = ctx.Attr<int>("axis");
|
|
|
|
|
bool need_resize_out_dims = false;
|
|
|
|
|
if (ctx.HasInput("AxisTensor")) {
|
|
|
|
@ -178,7 +181,9 @@ class ConcatGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(ins[0] != nullptr, true, "The input should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
ins[0], platform::errors::NotFound(
|
|
|
|
|
"The first input of concat should not be null."));
|
|
|
|
|
|
|
|
|
|
auto axis = ctx.Attr<int>("axis");
|
|
|
|
|
if (ctx.HasInput("AxisTensor")) {
|
|
|
|
|