|
|
|
@ -59,27 +59,27 @@ static void CPUTakeAlongD1(const platform::DeviceContext& ctx,
|
|
|
|
const auto idx_dims = index.dims();
|
|
|
|
const auto idx_dims = index.dims();
|
|
|
|
PADDLE_ENFORCE_EQ(idx_dims.size(), 2,
|
|
|
|
PADDLE_ENFORCE_EQ(idx_dims.size(), 2,
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
"index of CPUTakeAlongD1 should be 2D.",
|
|
|
|
"index of CPUTakeAlongD1 should be 2D. "
|
|
|
|
"But received shape = [%s] and dimension is %d.",
|
|
|
|
"But received shape = [%s] and dimension is %d.",
|
|
|
|
idx_dims, idx_dims.size()));
|
|
|
|
idx_dims, idx_dims.size()));
|
|
|
|
PADDLE_ENFORCE_EQ(array_dims.size(), 2,
|
|
|
|
PADDLE_ENFORCE_EQ(array_dims.size(), 2,
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
"array of CPUTakeAlongD1 should be 2D.",
|
|
|
|
"array of CPUTakeAlongD1 should be 2D. "
|
|
|
|
"But received shape = [%s] and dimension is %d.",
|
|
|
|
"But received shape = [%s] and dimension is %d.",
|
|
|
|
array_dims, array_dims.size()));
|
|
|
|
array_dims, array_dims.size()));
|
|
|
|
PADDLE_ENFORCE_EQ(idx_dims[0], array_dims[0],
|
|
|
|
PADDLE_ENFORCE_EQ(idx_dims[0], array_dims[0],
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
"The first dimension of index and array of ",
|
|
|
|
"The first dimension of index and array of "
|
|
|
|
"CPUTakeAlongD1 should be equal.",
|
|
|
|
"CPUTakeAlongD1 should be equal. "
|
|
|
|
"But received index shape = [%s], array shape = [%s],",
|
|
|
|
"But received index shape = [%s], array shape = [%s], "
|
|
|
|
"and the first dimensions are %d and %d.", idx_dims,
|
|
|
|
"and the first dimensions are %d and %d.",
|
|
|
|
array_dims, idx_dims[0], array_dims[0]));
|
|
|
|
idx_dims, array_dims, idx_dims[0], array_dims[0]));
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
idx_dims, value->dims(),
|
|
|
|
idx_dims, value->dims(),
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
"index and array of CPUTakeAlongD1 should have ", "the same shape.",
|
|
|
|
"index and array of CPUTakeAlongD1 should have the same shape. "
|
|
|
|
"But received index shape = [%s], array shape = [%s].", idx_dims,
|
|
|
|
"But received index shape = [%s], array shape = [%s].",
|
|
|
|
value->dims()));
|
|
|
|
idx_dims, value->dims()));
|
|
|
|
|
|
|
|
|
|
|
|
// UNDERSTAND: no allocations here
|
|
|
|
// UNDERSTAND: no allocations here
|
|
|
|
const T* p_array = array.data<T>();
|
|
|
|
const T* p_array = array.data<T>();
|
|
|
|
@ -119,27 +119,27 @@ static void CPUPutAlongD1(const platform::DeviceContext& ctx,
|
|
|
|
auto idx_dims = index.dims();
|
|
|
|
auto idx_dims = index.dims();
|
|
|
|
PADDLE_ENFORCE_EQ(idx_dims.size(), 2,
|
|
|
|
PADDLE_ENFORCE_EQ(idx_dims.size(), 2,
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
"index of CPUPutAlongD1 should be 2D.",
|
|
|
|
"index of CPUPutAlongD1 should be 2D. "
|
|
|
|
"But received shape = [%s] and dimension is %d.",
|
|
|
|
"But received shape = [%s] and dimension is %d.",
|
|
|
|
idx_dims, idx_dims.size()));
|
|
|
|
idx_dims, idx_dims.size()));
|
|
|
|
PADDLE_ENFORCE_EQ(array_dims.size(), 2,
|
|
|
|
PADDLE_ENFORCE_EQ(array_dims.size(), 2,
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
"array of CPUPutAlongD1 should be 2D.",
|
|
|
|
"array of CPUPutAlongD1 should be 2D. "
|
|
|
|
"But received shape = [%s] and dimension is %d.",
|
|
|
|
"But received shape = [%s] and dimension is %d.",
|
|
|
|
array_dims, array_dims.size()));
|
|
|
|
array_dims, array_dims.size()));
|
|
|
|
PADDLE_ENFORCE_EQ(idx_dims[0], array_dims[0],
|
|
|
|
PADDLE_ENFORCE_EQ(idx_dims[0], array_dims[0],
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
"The first dimension of index and array of ",
|
|
|
|
"The first dimension of index and array of "
|
|
|
|
"CPUPutAlongD1 should be equal.",
|
|
|
|
"CPUPutAlongD1 should be equal. "
|
|
|
|
"But received index shape = [%s], array shape = [%s],",
|
|
|
|
"But received index shape = [%s], array shape = [%s], "
|
|
|
|
"and the first dimensions are %d and %d.", idx_dims,
|
|
|
|
"and the first dimensions are %d and %d.",
|
|
|
|
array_dims, idx_dims[0], array_dims[0]));
|
|
|
|
idx_dims, array_dims, idx_dims[0], array_dims[0]));
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
idx_dims, value.dims(),
|
|
|
|
idx_dims, value.dims(),
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
"index and array of CPUPutAlongD1 should have ", "the same shape.",
|
|
|
|
"index and array of CPUPutAlongD1 should have the same shape. "
|
|
|
|
"But received index shape = [%s], array shape = [%s].", idx_dims,
|
|
|
|
"But received index shape = [%s], array shape = [%s].",
|
|
|
|
value.dims()));
|
|
|
|
idx_dims, value.dims()));
|
|
|
|
|
|
|
|
|
|
|
|
// UNDERSTAND: no allocations here
|
|
|
|
// UNDERSTAND: no allocations here
|
|
|
|
T* p_array = array->data<T>();
|
|
|
|
T* p_array = array->data<T>();
|
|
|
|
|