|
|
|
@ -21,42 +21,38 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/platform/place.h"
|
|
|
|
|
|
|
|
|
|
TEST(scatter, ScatterUpdate) {
|
|
|
|
|
// using namespace paddle::framework;
|
|
|
|
|
// using namespace paddle::platform;
|
|
|
|
|
// using namespace paddle::operators;
|
|
|
|
|
|
|
|
|
|
paddle::framework::Tensor* src = new paddle::framework::Tensor();
|
|
|
|
|
paddle::framework::Tensor* index = new paddle::framework::Tensor();
|
|
|
|
|
paddle::framework::Tensor* output = new paddle::framework::Tensor();
|
|
|
|
|
|
|
|
|
|
float* p_src = nullptr;
|
|
|
|
|
int* p_index = nullptr;
|
|
|
|
|
p_src = src->mutable_data<float>(paddle::framework::make_ddim({1, 4}),
|
|
|
|
|
paddle::platform::CPUPlace());
|
|
|
|
|
p_index = index->mutable_data<int>(paddle::framework::make_ddim({1}),
|
|
|
|
|
paddle::platform::CPUPlace());
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < 4; ++i) p_src[i] = static_cast<float>(i);
|
|
|
|
|
paddle::framework::Tensor src;
|
|
|
|
|
paddle::framework::Tensor index;
|
|
|
|
|
paddle::framework::Tensor output;
|
|
|
|
|
|
|
|
|
|
auto* p_src = src.mutable_data<float>(paddle::framework::make_ddim({1, 4}),
|
|
|
|
|
paddle::platform::CPUPlace());
|
|
|
|
|
auto* p_index = index.mutable_data<int>(paddle::framework::make_ddim({1}),
|
|
|
|
|
paddle::platform::CPUPlace());
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < 4; ++i) {
|
|
|
|
|
p_src[i] = static_cast<float>(i);
|
|
|
|
|
}
|
|
|
|
|
p_index[0] = 1;
|
|
|
|
|
|
|
|
|
|
float* p_output = output->mutable_data<float>(
|
|
|
|
|
auto* p_output = output.mutable_data<float>(
|
|
|
|
|
paddle::framework::make_ddim({4, 4}), paddle::platform::CPUPlace());
|
|
|
|
|
|
|
|
|
|
for (int64_t i = 0; i < output.numel(); ++i) {
|
|
|
|
|
p_output[i] = 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto* cpu_place = new paddle::platform::CPUPlace();
|
|
|
|
|
paddle::platform::CPUDeviceContext ctx(*cpu_place);
|
|
|
|
|
paddle::operators::ScatterAssign<float>(ctx, *src, *index, output);
|
|
|
|
|
paddle::operators::ScatterAssign<float>(ctx, src, index, &output);
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], 0.0f);
|
|
|
|
|
for (size_t i = 0; i < 4; ++i) EXPECT_EQ(output->data<float>()[i], 0.0f);
|
|
|
|
|
for (size_t i = 0; i < 4; ++i) EXPECT_EQ(output.data<float>()[i], 0.0f);
|
|
|
|
|
for (size_t i = 4; i < 8; ++i) {
|
|
|
|
|
EXPECT_EQ(p_output[i], static_cast<float>(i - 4));
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 4; i < 8; ++i)
|
|
|
|
|
EXPECT_EQ(output->data<float>()[i], static_cast<float>(i - 4));
|
|
|
|
|
EXPECT_EQ(output.data<float>()[i], static_cast<float>(i - 4));
|
|
|
|
|
for (size_t i = 8; i < 16; ++i) EXPECT_EQ(p_output[i], 0.0f);
|
|
|
|
|
for (size_t i = 8; i < 16; ++i) EXPECT_EQ(output->data<float>()[i], 0.0f);
|
|
|
|
|
|
|
|
|
|
delete src;
|
|
|
|
|
delete index;
|
|
|
|
|
delete output;
|
|
|
|
|
for (size_t i = 8; i < 16; ++i) EXPECT_EQ(output.data<float>()[i], 0.0f);
|
|
|
|
|
}
|
|
|
|
|