|
|
|
@ -79,8 +79,16 @@ void ConcatCase1(DeviceContext* context) {
|
|
|
|
|
concat_functor(*context, input, 0, &out);
|
|
|
|
|
|
|
|
|
|
// check the dim of input_a, input_b
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_a.dims(), dim_a);
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_b.dims(), dim_b);
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_a.dims(), dim_a,
|
|
|
|
|
paddle::platform::errors::InvalidArgument(
|
|
|
|
|
"The dims of Input tensor should be the same as the "
|
|
|
|
|
"declared dims. Tensor dims: [%s], declared dims: [%s]",
|
|
|
|
|
input_a.dims(), dim_a));
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_b.dims(), dim_b,
|
|
|
|
|
paddle::platform::errors::InvalidArgument(
|
|
|
|
|
"The dims of Input tensor should be the same as the "
|
|
|
|
|
"declared dims. Tensor dims: [%s], declared dims: [%s]",
|
|
|
|
|
input_b.dims(), dim_b));
|
|
|
|
|
|
|
|
|
|
int* out_ptr = nullptr;
|
|
|
|
|
if (paddle::platform::is_gpu_place(Place())) {
|
|
|
|
@ -95,10 +103,14 @@ void ConcatCase1(DeviceContext* context) {
|
|
|
|
|
int idx_a = 0, idx_b = 0;
|
|
|
|
|
for (int j = 0; j < 5 * 3 * 4; ++j) {
|
|
|
|
|
if (j >= cols) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(out_ptr[j], b_ptr[idx_b]);
|
|
|
|
|
PADDLE_ENFORCE_EQ(out_ptr[j], b_ptr[idx_b],
|
|
|
|
|
paddle::platform::errors::InvalidArgument(
|
|
|
|
|
"Concat test failed, the result should be equal."));
|
|
|
|
|
++idx_b;
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_EQ(out_ptr[j], a_ptr[idx_a]);
|
|
|
|
|
PADDLE_ENFORCE_EQ(out_ptr[j], a_ptr[idx_a],
|
|
|
|
|
paddle::platform::errors::InvalidArgument(
|
|
|
|
|
"Concat test failed, the result should be equal."));
|
|
|
|
|
++idx_a;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -166,8 +178,16 @@ void ConcatCase2(DeviceContext* context) {
|
|
|
|
|
concat_functor(*context, input, 1, &out);
|
|
|
|
|
|
|
|
|
|
// check the dim of input_a, input_b
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_a.dims(), dim_a);
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_b.dims(), dim_b);
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_a.dims(), dim_a,
|
|
|
|
|
paddle::platform::errors::InvalidArgument(
|
|
|
|
|
"The dims of Input tensor should be the same as the "
|
|
|
|
|
"declared dims. Tensor dims: [%s], declared dims: [%s]",
|
|
|
|
|
input_a.dims(), dim_a));
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_b.dims(), dim_b,
|
|
|
|
|
paddle::platform::errors::InvalidArgument(
|
|
|
|
|
"The dims of Input tensor should be the same as the "
|
|
|
|
|
"declared dims. Tensor dims: [%s], declared dims: [%s]",
|
|
|
|
|
input_b.dims(), dim_b));
|
|
|
|
|
|
|
|
|
|
int* out_ptr = nullptr;
|
|
|
|
|
if (paddle::platform::is_gpu_place(Place())) {
|
|
|
|
@ -183,10 +203,16 @@ void ConcatCase2(DeviceContext* context) {
|
|
|
|
|
for (int i = 0; i < 2; ++i) {
|
|
|
|
|
for (int j = 0; j < 28; ++j) {
|
|
|
|
|
if (j >= cols) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(out_ptr[i * 28 + j], b_ptr[idx_b]);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
out_ptr[i * 28 + j], b_ptr[idx_b],
|
|
|
|
|
paddle::platform::errors::InvalidArgument(
|
|
|
|
|
"Concat test failed, the result should be equal."));
|
|
|
|
|
++idx_b;
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_EQ(out_ptr[i * 28 + j], a_ptr[idx_a]);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
out_ptr[i * 28 + j], a_ptr[idx_a],
|
|
|
|
|
paddle::platform::errors::InvalidArgument(
|
|
|
|
|
"Concat test failed, the result should be equal."));
|
|
|
|
|
++idx_a;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -255,8 +281,16 @@ void ConcatCase3(DeviceContext* context) {
|
|
|
|
|
concat_functor(*context, input, 2, &out);
|
|
|
|
|
|
|
|
|
|
// check the dim of input_a, input_b
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_a.dims(), dim_a);
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_b.dims(), dim_b);
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_a.dims(), dim_a,
|
|
|
|
|
paddle::platform::errors::InvalidArgument(
|
|
|
|
|
"The dims of Input tensor should be the same as the "
|
|
|
|
|
"declared dims. Tensor dims: [%s], declared dims: [%s]",
|
|
|
|
|
input_a.dims(), dim_a));
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_b.dims(), dim_b,
|
|
|
|
|
paddle::platform::errors::InvalidArgument(
|
|
|
|
|
"The dims of Input tensor should be the same as the "
|
|
|
|
|
"declared dims. Tensor dims: [%s], declared dims: [%s]",
|
|
|
|
|
input_b.dims(), dim_b));
|
|
|
|
|
|
|
|
|
|
int* out_ptr = nullptr;
|
|
|
|
|
if (paddle::platform::is_gpu_place(Place())) {
|
|
|
|
@ -273,10 +307,16 @@ void ConcatCase3(DeviceContext* context) {
|
|
|
|
|
for (int i = 0; i < 6; ++i) {
|
|
|
|
|
for (int j = 0; j < 9; ++j) {
|
|
|
|
|
if (j >= cols) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(out_ptr[i * 9 + j], b_ptr[idx_b]);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
out_ptr[i * 9 + j], b_ptr[idx_b],
|
|
|
|
|
paddle::platform::errors::InvalidArgument(
|
|
|
|
|
"Concat test failed, the result should be equal."));
|
|
|
|
|
++idx_b;
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_EQ(out_ptr[i * 9 + j], a_ptr[idx_a]);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
out_ptr[i * 9 + j], a_ptr[idx_a],
|
|
|
|
|
paddle::platform::errors::InvalidArgument(
|
|
|
|
|
"Concat test failed, the result should be equal."));
|
|
|
|
|
++idx_a;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -347,8 +387,16 @@ void ConcatCase4(DeviceContext* context) {
|
|
|
|
|
context->Wait();
|
|
|
|
|
|
|
|
|
|
// check the dim of input_a, input_b
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_a.dims(), dim_a);
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_b.dims(), dim_b);
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_a.dims(), dim_a,
|
|
|
|
|
paddle::platform::errors::InvalidArgument(
|
|
|
|
|
"The dims of Input tensor should be the same as the "
|
|
|
|
|
"declared dims. Tensor dims: [%s], declared dims: [%s]",
|
|
|
|
|
input_a.dims(), dim_a));
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_b.dims(), dim_b,
|
|
|
|
|
paddle::platform::errors::InvalidArgument(
|
|
|
|
|
"The dims of Input tensor should be the same as the "
|
|
|
|
|
"declared dims. Tensor dims: [%s], declared dims: [%s]",
|
|
|
|
|
input_b.dims(), dim_b));
|
|
|
|
|
|
|
|
|
|
int* out_ptr = nullptr;
|
|
|
|
|
if (paddle::platform::is_gpu_place(Place())) {
|
|
|
|
@ -365,10 +413,16 @@ void ConcatCase4(DeviceContext* context) {
|
|
|
|
|
for (int i = 0; i < 2; ++i) {
|
|
|
|
|
for (int j = 0; j < 24; ++j) {
|
|
|
|
|
if (j >= cols) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(out_ptr[i * 24 + j], b_ptr[idx_b]);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
out_ptr[i * 24 + j], b_ptr[idx_b],
|
|
|
|
|
paddle::platform::errors::InvalidArgument(
|
|
|
|
|
"Concat test failed, the result should be equal."));
|
|
|
|
|
++idx_b;
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_EQ(out_ptr[i * 24 + j], a_ptr[idx_a]);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
out_ptr[i * 24 + j], a_ptr[idx_a],
|
|
|
|
|
paddle::platform::errors::InvalidArgument(
|
|
|
|
|
"Concat test failed, the result should be equal."));
|
|
|
|
|
++idx_a;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|