|
|
|
@ -13,41 +13,50 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/math/selected_rows_functor.h"
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "gtest/gtest.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/math_function.h"
|
|
|
|
|
|
|
|
|
|
TEST(selected_rows_functor, cpu_add) {
|
|
|
|
|
using namespace paddle::framework;
|
|
|
|
|
using namespace paddle::platform;
|
|
|
|
|
using namespace paddle::operators::math;
|
|
|
|
|
|
|
|
|
|
CPUPlace cpu_place;
|
|
|
|
|
CPUDeviceContext ctx(cpu_place);
|
|
|
|
|
SetConstant<CPUDeviceContext, float> functor;
|
|
|
|
|
paddle::platform::CPUPlace cpu_place;
|
|
|
|
|
paddle::platform::CPUDeviceContext ctx(cpu_place);
|
|
|
|
|
paddle::operators::math::SetConstant<paddle::platform::CPUDeviceContext,
|
|
|
|
|
float>
|
|
|
|
|
functor;
|
|
|
|
|
int64_t height = 10;
|
|
|
|
|
int64_t row_numel = 10;
|
|
|
|
|
|
|
|
|
|
std::vector<int64_t> rows1{0, 4, 7};
|
|
|
|
|
std::unique_ptr<SelectedRows> selected_rows1{new SelectedRows(rows1, height)};
|
|
|
|
|
std::unique_ptr<paddle::framework::SelectedRows> selected_rows1{
|
|
|
|
|
new paddle::framework::SelectedRows(rows1, height)};
|
|
|
|
|
auto* in1_value = selected_rows1->mutable_value();
|
|
|
|
|
in1_value->mutable_data<float>(
|
|
|
|
|
make_ddim({static_cast<int64_t>(rows1.size()), row_numel}), cpu_place);
|
|
|
|
|
paddle::framework::make_ddim(
|
|
|
|
|
{static_cast<int64_t>(rows1.size()), row_numel}),
|
|
|
|
|
cpu_place);
|
|
|
|
|
functor(ctx, in1_value, 1.0);
|
|
|
|
|
|
|
|
|
|
std::vector<int64_t> rows2{0, 5, 7, 9};
|
|
|
|
|
std::unique_ptr<SelectedRows> selected_rows2{new SelectedRows(rows2, height)};
|
|
|
|
|
std::unique_ptr<paddle::framework::SelectedRows> selected_rows2{
|
|
|
|
|
new paddle::framework::SelectedRows(rows2, height)};
|
|
|
|
|
auto* in2_value = selected_rows2->mutable_value();
|
|
|
|
|
in2_value->mutable_data<float>(
|
|
|
|
|
make_ddim({static_cast<int64_t>(rows2.size()), row_numel}), cpu_place);
|
|
|
|
|
paddle::framework::make_ddim(
|
|
|
|
|
{static_cast<int64_t>(rows2.size()), row_numel}),
|
|
|
|
|
cpu_place);
|
|
|
|
|
functor(ctx, in2_value, 2.0);
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<SelectedRows> output{new SelectedRows()};
|
|
|
|
|
std::unique_ptr<paddle::framework::SelectedRows> output{
|
|
|
|
|
new paddle::framework::SelectedRows()};
|
|
|
|
|
auto* out_value = output->mutable_value();
|
|
|
|
|
|
|
|
|
|
// simplely concat two SelectedRows
|
|
|
|
|
out_value->mutable_data<float>(make_ddim({7, 10}), cpu_place);
|
|
|
|
|
out_value->mutable_data<float>(paddle::framework::make_ddim({7, 10}),
|
|
|
|
|
cpu_place);
|
|
|
|
|
|
|
|
|
|
SelectedRowsAdd<CPUDeviceContext, float> add_functor;
|
|
|
|
|
paddle::operators::math::SelectedRowsAdd<paddle::platform::CPUDeviceContext,
|
|
|
|
|
float>
|
|
|
|
|
add_functor;
|
|
|
|
|
add_functor(ctx, *selected_rows1, *selected_rows2, output.get());
|
|
|
|
|
|
|
|
|
|
auto out_height = output->height();
|
|
|
|
@ -78,14 +87,20 @@ TEST(selected_rows_functor, cpu_add) {
|
|
|
|
|
EXPECT_EQ(out_data[5 * row_numel + 7], 2.0);
|
|
|
|
|
EXPECT_EQ(out_data[6 * row_numel + 9], 2.0);
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<Tensor> tensor1{new Tensor()};
|
|
|
|
|
tensor1->mutable_data<float>(make_ddim({height, row_numel}), cpu_place);
|
|
|
|
|
std::unique_ptr<paddle::framework::Tensor> tensor1{
|
|
|
|
|
new paddle::framework::Tensor()};
|
|
|
|
|
tensor1->mutable_data<float>(
|
|
|
|
|
paddle::framework::make_ddim({height, row_numel}), cpu_place);
|
|
|
|
|
functor(ctx, tensor1.get(), 3.0);
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<Tensor> tensor2{new Tensor()};
|
|
|
|
|
tensor2->mutable_data<float>(make_ddim({height, row_numel}), cpu_place);
|
|
|
|
|
std::unique_ptr<paddle::framework::Tensor> tensor2{
|
|
|
|
|
new paddle::framework::Tensor()};
|
|
|
|
|
tensor2->mutable_data<float>(
|
|
|
|
|
paddle::framework::make_ddim({height, row_numel}), cpu_place);
|
|
|
|
|
|
|
|
|
|
SelectedRowsAddTensor<CPUDeviceContext, float> add_tensor_functor;
|
|
|
|
|
paddle::operators::math::SelectedRowsAddTensor<
|
|
|
|
|
paddle::platform::CPUDeviceContext, float>
|
|
|
|
|
add_tensor_functor;
|
|
|
|
|
add_tensor_functor(ctx, *output, *tensor1, tensor2.get());
|
|
|
|
|
|
|
|
|
|
auto* tensor2_data = tensor2->data<float>();
|
|
|
|
@ -106,38 +121,46 @@ TEST(selected_rows_functor, cpu_add) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(selected_rows_functor, cpu_add_to) {
|
|
|
|
|
using namespace paddle::framework;
|
|
|
|
|
using namespace paddle::platform;
|
|
|
|
|
using namespace paddle::operators::math;
|
|
|
|
|
|
|
|
|
|
CPUPlace cpu_place;
|
|
|
|
|
CPUDeviceContext ctx(cpu_place);
|
|
|
|
|
SetConstant<CPUDeviceContext, float> functor;
|
|
|
|
|
paddle::platform::CPUPlace cpu_place;
|
|
|
|
|
paddle::platform::CPUDeviceContext ctx(cpu_place);
|
|
|
|
|
paddle::operators::math::SetConstant<paddle::platform::CPUDeviceContext,
|
|
|
|
|
float>
|
|
|
|
|
functor;
|
|
|
|
|
int64_t height = 10;
|
|
|
|
|
int64_t row_numel = 10;
|
|
|
|
|
|
|
|
|
|
std::vector<int64_t> rows1{0, 4, 7};
|
|
|
|
|
std::unique_ptr<SelectedRows> selected_rows1{new SelectedRows(rows1, height)};
|
|
|
|
|
std::unique_ptr<paddle::framework::SelectedRows> selected_rows1{
|
|
|
|
|
new paddle::framework::SelectedRows(rows1, height)};
|
|
|
|
|
auto* in1_value = selected_rows1->mutable_value();
|
|
|
|
|
in1_value->mutable_data<float>(
|
|
|
|
|
make_ddim({static_cast<int64_t>(rows1.size()), row_numel}), cpu_place);
|
|
|
|
|
paddle::framework::make_ddim(
|
|
|
|
|
{static_cast<int64_t>(rows1.size()), row_numel}),
|
|
|
|
|
cpu_place);
|
|
|
|
|
functor(ctx, in1_value, 1.0);
|
|
|
|
|
|
|
|
|
|
std::vector<int64_t> rows2{0, 5, 7, 9};
|
|
|
|
|
std::unique_ptr<SelectedRows> selected_rows2{new SelectedRows(rows2, height)};
|
|
|
|
|
std::unique_ptr<paddle::framework::SelectedRows> selected_rows2{
|
|
|
|
|
new paddle::framework::SelectedRows(rows2, height)};
|
|
|
|
|
auto* in2_value = selected_rows2->mutable_value();
|
|
|
|
|
in2_value->mutable_data<float>(
|
|
|
|
|
make_ddim({static_cast<int64_t>(rows2.size()), row_numel}), cpu_place);
|
|
|
|
|
paddle::framework::make_ddim(
|
|
|
|
|
{static_cast<int64_t>(rows2.size()), row_numel}),
|
|
|
|
|
cpu_place);
|
|
|
|
|
functor(ctx, in2_value, 2.0);
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<SelectedRows> output{new SelectedRows()};
|
|
|
|
|
std::unique_ptr<paddle::framework::SelectedRows> output{
|
|
|
|
|
new paddle::framework::SelectedRows()};
|
|
|
|
|
output->set_height(height);
|
|
|
|
|
auto* out_value = output->mutable_value();
|
|
|
|
|
|
|
|
|
|
// simplely concat two SelectedRows
|
|
|
|
|
out_value->mutable_data<float>(make_ddim({7, 10}), cpu_place);
|
|
|
|
|
out_value->mutable_data<float>(paddle::framework::make_ddim({7, 10}),
|
|
|
|
|
cpu_place);
|
|
|
|
|
|
|
|
|
|
SelectedRowsAddTo<CPUDeviceContext, float> add_to_functor;
|
|
|
|
|
paddle::operators::math::SelectedRowsAddTo<paddle::platform::CPUDeviceContext,
|
|
|
|
|
float>
|
|
|
|
|
add_to_functor;
|
|
|
|
|
add_to_functor(ctx, *selected_rows1, 0, output.get());
|
|
|
|
|
add_to_functor(ctx, *selected_rows2, in1_value->numel(), output.get());
|
|
|
|
|
|
|
|
|
@ -169,11 +192,15 @@ TEST(selected_rows_functor, cpu_add_to) {
|
|
|
|
|
EXPECT_EQ(out_data[5 * row_numel + 7], 2.0);
|
|
|
|
|
EXPECT_EQ(out_data[6 * row_numel + 9], 2.0);
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<Tensor> tensor1{new Tensor()};
|
|
|
|
|
tensor1->mutable_data<float>(make_ddim({height, row_numel}), cpu_place);
|
|
|
|
|
std::unique_ptr<paddle::framework::Tensor> tensor1{
|
|
|
|
|
new paddle::framework::Tensor()};
|
|
|
|
|
tensor1->mutable_data<float>(
|
|
|
|
|
paddle::framework::make_ddim({height, row_numel}), cpu_place);
|
|
|
|
|
functor(ctx, tensor1.get(), 3.0);
|
|
|
|
|
|
|
|
|
|
SelectedRowsAddToTensor<CPUDeviceContext, float> add_to_tensor_functor;
|
|
|
|
|
paddle::operators::math::SelectedRowsAddToTensor<
|
|
|
|
|
paddle::platform::CPUDeviceContext, float>
|
|
|
|
|
add_to_tensor_functor;
|
|
|
|
|
add_to_tensor_functor(ctx, *output, tensor1.get());
|
|
|
|
|
|
|
|
|
|
auto* tensor1_data = tensor1->data<float>();
|
|
|
|
|