parent 690b04123f
commit ab5dc9fe18
paddle/operators/math/CMakeLists.txt
@@ -1,18 +1,22 @@
 if(WITH_GPU)
     nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu DEPS cblas device_context operator)
-    nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function tensor selected_rows)
+    nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function tensor)
+    nv_library(selected_rows_functor SRCS selected_rows_functor.cc selected_rows_functor.cu DEPS selected_rows math_function)
+    nv_test(selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu DEPS selected_rows_functor)
     nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator)
     nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator)
     nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context)
     nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context)
 else()
     cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator)
+    cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function)
     cc_library(softmax SRCS softmax.cc DEPS operator)
     cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator)
     cc_library(pooling SRCS pooling.cc DEPS device_context)
     cc_library(vol2col SRCS vol2col.cc DEPS device_context)
 endif()

-cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor selected_rows)
+cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
+cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor)
 cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor)
 cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col tensor)
paddle/operators/math/selected_rows_functor.cc (new file)
@@ -0,0 +1,114 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/selected_rows_functor.h"
#include "paddle/operators/math/math_function.h"

namespace paddle {
namespace operators {
namespace math {

template <typename T>
struct SelectedRowsAdd<platform::CPUPlace, T> {
  void operator()(const platform::DeviceContext& context,
                  const framework::SelectedRows& input1,
                  const framework::SelectedRows& input2,
                  framework::SelectedRows* output) {
    auto in1_height = input1.height();
    PADDLE_ENFORCE_EQ(in1_height, input2.height());
    output->set_height(in1_height);

    auto& in1_rows = input1.rows();
    auto& in2_rows = input2.rows();
    std::vector<int64_t> out_rows;
    out_rows.reserve(in1_rows.size() + in2_rows.size());

    // concat rows
    out_rows.insert(out_rows.end(), in1_rows.begin(), in1_rows.end());
    out_rows.insert(out_rows.end(), in2_rows.begin(), in2_rows.end());
    output->set_rows(out_rows);

    auto* out_value = output->mutable_value();
    auto& in1_value = input1.value();
    auto& in2_value = input2.value();

    auto in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(in1_row_numel, in2_value.numel() / in2_rows.size());
    PADDLE_ENFORCE_EQ(in1_row_numel, out_value->numel() / out_rows.size());

    auto in1_place = input1.place();
    PADDLE_ENFORCE(platform::is_cpu_place(in1_place));
    auto in2_place = input2.place();
    PADDLE_ENFORCE(platform::is_cpu_place(in2_place));
    auto out_place = context.GetPlace();
    PADDLE_ENFORCE(platform::is_cpu_place(out_place));

    auto* out_data = out_value->data<T>();
    auto* in1_data = in1_value.data<T>();
    memory::Copy(boost::get<platform::CPUPlace>(out_place), out_data,
                 boost::get<platform::CPUPlace>(in1_place), in1_data,
                 in1_value.numel() * sizeof(T));

    auto* in2_data = in2_value.data<T>();
    memory::Copy(boost::get<platform::CPUPlace>(out_place),
                 out_data + in1_value.numel(),
                 boost::get<platform::CPUPlace>(in2_place), in2_data,
                 in2_value.numel() * sizeof(T));
  }
};

template struct SelectedRowsAdd<platform::CPUPlace, float>;

template <typename T>
struct SelectedRowsAddTensor<platform::CPUPlace, T> {
  void operator()(const platform::DeviceContext& context,
                  const framework::SelectedRows& input1,
                  const framework::Tensor& input2, framework::Tensor* output) {
    auto in1_height = input1.height();
    auto in2_dims = input2.dims();
    auto out_dims = output->dims();
    PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);
    PADDLE_ENFORCE_EQ(in1_height, out_dims[0]);

    auto& in1_value = input1.value();
    auto& in1_rows = input1.rows();

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(in1_row_numel, input2.numel() / in1_height);
    PADDLE_ENFORCE_EQ(in1_row_numel, output->numel() / in1_height);

    SetConstant<platform::CPUPlace, T> functor;
    functor(context, output, 0.0);

    auto* in1_data = in1_value.data<T>();
    auto* out_data = output->data<T>();

    for (size_t i = 0; i < in1_rows.size(); i++) {
      for (int64_t j = 0; j < in1_row_numel; j++) {
        out_data[in1_rows[i] * in1_row_numel + j] +=
            in1_data[i * in1_row_numel + j];
      }
    }

    auto out_eigen = framework::EigenVector<T>::Flatten(*output);
    auto in2_eigen = framework::EigenVector<T>::Flatten(input2);
    out_eigen.device(*context.GetEigenDevice<platform::CPUPlace>()) =
        out_eigen + in2_eigen;
  }
};

template struct SelectedRowsAddTensor<platform::CPUPlace, float>;

}  // namespace math
}  // namespace operators
}  // namespace paddle
paddle/operators/math/selected_rows_functor.cu (new file)
@@ -0,0 +1,142 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/selected_rows_functor.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {

template <typename T>
struct SelectedRowsAdd<platform::GPUPlace, T> {
  void operator()(const platform::DeviceContext& context,
                  const framework::SelectedRows& input1,
                  const framework::SelectedRows& input2,
                  framework::SelectedRows* output) {
    auto in1_height = input1.height();
    PADDLE_ENFORCE_EQ(in1_height, input2.height());
    output->set_height(in1_height);

    auto& in1_rows = input1.rows();
    auto& in2_rows = input2.rows();
    std::vector<int64_t> out_rows;
    out_rows.reserve(in1_rows.size() + in2_rows.size());

    // concat rows
    out_rows.insert(out_rows.end(), in1_rows.begin(), in1_rows.end());
    out_rows.insert(out_rows.end(), in2_rows.begin(), in2_rows.end());
    output->set_rows(out_rows);

    auto* out_value = output->mutable_value();
    auto& in1_value = input1.value();
    auto& in2_value = input2.value();

    auto in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(in1_row_numel, in2_value.numel() / in2_rows.size());
    PADDLE_ENFORCE_EQ(in1_row_numel, out_value->numel() / out_rows.size());

    auto* out_data = out_value->data<T>();
    auto* in1_data = in1_value.data<T>();

    auto in1_place = input1.place();
    PADDLE_ENFORCE(platform::is_gpu_place(in1_place));
    auto in2_place = input2.place();
    PADDLE_ENFORCE(platform::is_gpu_place(in2_place));
    auto out_place = context.GetPlace();
    PADDLE_ENFORCE(platform::is_gpu_place(out_place));

    memory::Copy(
        boost::get<platform::GPUPlace>(out_place), out_data,
        boost::get<platform::GPUPlace>(in1_place), in1_data,
        in1_value.numel() * sizeof(T),
        reinterpret_cast<const platform::CUDADeviceContext&>(context).stream());

    auto* in2_data = in2_value.data<T>();
    memory::Copy(
        boost::get<platform::GPUPlace>(out_place), out_data + in1_value.numel(),
        boost::get<platform::GPUPlace>(in2_place), in2_data,
        in2_value.numel() * sizeof(T),
        reinterpret_cast<const platform::CUDADeviceContext&>(context).stream());
  }
};

template struct SelectedRowsAdd<platform::GPUPlace, float>;

namespace {
template <typename T>
__global__ void SelectedRowsAddTensorKernel(const T* selected_rows,
                                            const int64_t* rows, T* tensor_out,
                                            int64_t row_numel, int block_size) {
  const int ty = blockIdx.y;
  int tid = threadIdx.x;

  selected_rows += ty * row_numel;
  tensor_out += rows[ty] * row_numel;

  for (int index = tid; index < row_numel; index += block_size) {
    // Since the row indices of a SelectedRows can be duplicated, we cannot
    // use tensor_out[index] += selected_rows[index]; instead, we have to use
    // an atomic add to avoid concurrent-write errors.
    paddle::platform::CudaAtomicAdd(&tensor_out[index], selected_rows[index]);
  }
}
}  // namespace

template <typename T>
struct SelectedRowsAddTensor<platform::GPUPlace, T> {
  void operator()(const platform::DeviceContext& context,
                  const framework::SelectedRows& input1,
                  const framework::Tensor& input2, framework::Tensor* output) {
    auto in1_height = input1.height();
    auto in2_dims = input2.dims();
    auto out_dims = output->dims();
    PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);
    PADDLE_ENFORCE_EQ(in1_height, out_dims[0]);

    auto& in1_value = input1.value();
    auto& in1_rows = input1.rows();

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(in1_row_numel, input2.numel() / in1_height);
    PADDLE_ENFORCE_EQ(in1_row_numel, output->numel() / in1_height);

    auto* in1_data = in1_value.data<T>();
    auto* in2_data = input2.data<T>();
    auto* out_data = output->data<T>();

    SetConstant<platform::GPUPlace, T> functor;
    functor(context, output, 0.0);

    int block_size = 256;
    dim3 threads(block_size, 1);
    dim3 grid(1, in1_height);
    SelectedRowsAddTensorKernel<
        T><<<grid, threads, 0,
             reinterpret_cast<const platform::CUDADeviceContext&>(context)
                 .stream()>>>(in1_data, in1_rows.data(), out_data,
                              in1_row_numel, block_size);

    auto out_eigen = framework::EigenVector<T>::Flatten(*output);
    auto in2_eigen = framework::EigenVector<T>::Flatten(input2);
    out_eigen.device(*context.GetEigenDevice<platform::GPUPlace>()) =
        out_eigen + in2_eigen;
  }
};

template struct SelectedRowsAddTensor<platform::GPUPlace, float>;

}  // namespace math
}  // namespace operators
}  // namespace paddle
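A note on the kernel above: the concatenated row list can contain the same index more than once (for example rows 0 and 7 in the tests further down), so several CUDA blocks may accumulate into the same output row concurrently, which is why the kernel uses CudaAtomicAdd rather than a plain +=. The short standalone C++ sketch below reproduces the same situation with host threads standing in for CUDA blocks; it is only an illustration under those assumptions (integer counters instead of float rows), not Paddle code.

#include <atomic>
#include <cstdint>
#include <iostream>
#include <thread>
#include <vector>

int main() {
  // Concatenated row indices with duplicates, like the SelectedRows above.
  std::vector<int64_t> rows = {0, 4, 7, 0, 5, 7, 9};
  // One atomic accumulator per dense output row (ints for simplicity;
  // the real kernel atomically adds floats).
  std::vector<std::atomic<int>> out(10);
  for (auto& v : out) v.store(0);

  std::vector<std::thread> workers;
  for (size_t i = 0; i < rows.size(); ++i) {
    // One worker per sparse row, analogous to one CUDA block per row.
    workers.emplace_back([&out, &rows, i] {
      out[rows[i]].fetch_add(1);  // atomic accumulate; a plain += would race
    });
  }
  for (auto& w : workers) w.join();

  std::cout << out[0] << " " << out[7] << std::endl;  // both print 2
}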
paddle/operators/math/selected_rows_functor.h (new file)
@@ -0,0 +1,41 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/selected_rows.h"
#include "paddle/platform/device_context.h"

namespace paddle {
namespace operators {
namespace math {

// SelectedRows + SelectedRows simply concatenates the values and rows.
// The real computation happens when the result is added to a LoDTensor.
template <typename Place, typename T>
struct SelectedRowsAdd {
  void operator()(const platform::DeviceContext& context,
                  const framework::SelectedRows& input1,
                  const framework::SelectedRows& input2,
                  framework::SelectedRows* output);
};

template <typename Place, typename T>
struct SelectedRowsAddTensor {
  void operator()(const platform::DeviceContext& context,
                  const framework::SelectedRows& input1,
                  const framework::Tensor& input2, framework::Tensor* output);
};

}  // namespace math
}  // namespace operators
}  // namespace paddle
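The header comment above summarizes the design: adding two SelectedRows only concatenates their row lists and values, while the actual accumulation happens when a SelectedRows is added into a dense tensor. The following standalone C++ sketch illustrates those semantics with plain std::vector stand-ins; ToySelectedRows and the Add/AddTensor helpers are hypothetical illustrations under that assumption, not the Paddle API. With the same inputs as the tests below, it prints 6 for row 0 (1.0 + 2.0 + 3.0).

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for framework::SelectedRows: a list of row indices
// plus a contiguous [rows.size() x width] block of row data.
struct ToySelectedRows {
  std::vector<int64_t> rows;
  std::vector<float> value;
  int64_t height = 0;
  int64_t width = 0;
};

// SelectedRows + SelectedRows: concatenate rows and values, no arithmetic.
ToySelectedRows Add(const ToySelectedRows& a, const ToySelectedRows& b) {
  assert(a.height == b.height && a.width == b.width);
  ToySelectedRows out{a.rows, a.value, a.height, a.width};
  out.rows.insert(out.rows.end(), b.rows.begin(), b.rows.end());
  out.value.insert(out.value.end(), b.value.begin(), b.value.end());
  return out;
}

// SelectedRows + dense tensor: scatter-add every sparse row into the dense
// result by its row index; duplicated indices accumulate into the same row.
std::vector<float> AddTensor(const ToySelectedRows& a,
                             const std::vector<float>& dense) {
  std::vector<float> out = dense;  // height * width elements
  for (size_t i = 0; i < a.rows.size(); ++i) {
    for (int64_t j = 0; j < a.width; ++j) {
      out[a.rows[i] * a.width + j] += a.value[i * a.width + j];
    }
  }
  return out;
}

int main() {
  ToySelectedRows x{{0, 4, 7}, std::vector<float>(3 * 10, 1.f), 10, 10};
  ToySelectedRows y{{0, 5, 7, 9}, std::vector<float>(4 * 10, 2.f), 10, 10};
  ToySelectedRows z = Add(x, y);  // 7 rows; duplicates (0 and 7) are kept
  std::vector<float> t = AddTensor(z, std::vector<float>(10 * 10, 3.f));
  std::cout << t[0] << std::endl;  // row 0, col 0: 1.0 + 2.0 + 3.0 = 6.0
}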
paddle/operators/math/selected_rows_functor_test.cc (new file)
@@ -0,0 +1,106 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/selected_rows_functor.h"
#include "gtest/gtest.h"
#include "paddle/operators/math/math_function.h"

TEST(selected_rows_functor, cpu_add) {
  using namespace paddle::framework;
  using namespace paddle::platform;
  using namespace paddle::operators::math;

  CPUPlace cpu_place;
  CPUDeviceContext ctx(cpu_place);
  SetConstant<CPUPlace, float> functor;
  int64_t height = 10;
  int64_t row_numel = 10;

  std::vector<int64_t> rows1{0, 4, 7};
  std::unique_ptr<SelectedRows> selected_rows1{new SelectedRows(rows1, height)};
  auto* in1_value = selected_rows1->mutable_value();
  in1_value->mutable_data<float>(
      make_ddim({static_cast<int64_t>(rows1.size()), row_numel}), cpu_place);
  functor(ctx, in1_value, 1.0);

  std::vector<int64_t> rows2{0, 5, 7, 9};
  std::unique_ptr<SelectedRows> selected_rows2{new SelectedRows(rows2, height)};
  auto* in2_value = selected_rows2->mutable_value();
  in2_value->mutable_data<float>(
      make_ddim({static_cast<int64_t>(rows2.size()), row_numel}), cpu_place);
  functor(ctx, in2_value, 2.0);

  std::unique_ptr<SelectedRows> output{new SelectedRows()};
  auto* out_value = output->mutable_value();

  // simply concat the two SelectedRows
  out_value->mutable_data<float>(make_ddim({7, 10}), cpu_place);

  SelectedRowsAdd<CPUPlace, float> add_functor;
  add_functor(ctx, *selected_rows1, *selected_rows2, output.get());

  auto out_height = output->height();
  EXPECT_EQ(out_height, height);

  auto& out_rows = output->rows();

  // input1 rows
  EXPECT_EQ(out_rows[0], 0);
  EXPECT_EQ(out_rows[1], 4);
  EXPECT_EQ(out_rows[2], 7);
  // input2 rows
  EXPECT_EQ(out_rows[3], 0);
  EXPECT_EQ(out_rows[4], 5);
  EXPECT_EQ(out_rows[5], 7);
  EXPECT_EQ(out_rows[6], 9);

  auto* out_data = output->value().data<float>();
  // input1 value
  EXPECT_EQ(out_data[0 * row_numel + 0], 1.0);
  EXPECT_EQ(out_data[0 * row_numel + 8], 1.0);
  EXPECT_EQ(out_data[1 * row_numel + 1], 1.0);
  EXPECT_EQ(out_data[2 * row_numel + 6], 1.0);
  // input2 value
  EXPECT_EQ(out_data[3 * row_numel + 3], 2.0);
  EXPECT_EQ(out_data[3 * row_numel + 8], 2.0);
  EXPECT_EQ(out_data[4 * row_numel + 4], 2.0);
  EXPECT_EQ(out_data[5 * row_numel + 7], 2.0);
  EXPECT_EQ(out_data[6 * row_numel + 9], 2.0);

  std::unique_ptr<Tensor> tensor1{new Tensor()};
  tensor1->mutable_data<float>(make_ddim({height, row_numel}), cpu_place);
  functor(ctx, tensor1.get(), 3.0);

  std::unique_ptr<Tensor> tensor2{new Tensor()};
  tensor2->mutable_data<float>(make_ddim({height, row_numel}), cpu_place);

  SelectedRowsAddTensor<CPUPlace, float> add_tensor_functor;
  add_tensor_functor(ctx, *output, *tensor1, tensor2.get());

  auto* tensor2_data = tensor2->data<float>();
  // row0: 1.0 + 2.0 + 3.0
  EXPECT_EQ(tensor2_data[0 * row_numel + 0], 6.0);
  // row1: 3.0
  EXPECT_EQ(tensor2_data[1 * row_numel + 1], 3.0);
  // row4: 1.0 + 3.0
  EXPECT_EQ(tensor2_data[4 * row_numel + 6], 4.0);
  // row5: 2.0 + 3.0
  EXPECT_EQ(tensor2_data[5 * row_numel + 7], 5.0);
  // row6: 3.0
  EXPECT_EQ(tensor2_data[6 * row_numel + 1], 3.0);
  // row7: 1.0 + 2.0 + 3.0
  EXPECT_EQ(tensor2_data[7 * row_numel + 3], 6.0);
  // row9: 2.0 + 3.0
  EXPECT_EQ(tensor2_data[9 * row_numel + 6], 5.0);
}
paddle/operators/math/selected_rows_functor_test.cu (new file)
@@ -0,0 +1,115 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "gtest/gtest.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/selected_rows_functor.h"

TEST(selected_rows_functor, gpu_add) {
  using namespace paddle::framework;
  using namespace paddle::platform;
  using namespace paddle::operators::math;

  GPUPlace gpu_place(0);
  CPUPlace cpu_place;
  CUDADeviceContext ctx(gpu_place);
  SetConstant<GPUPlace, float> functor;
  int64_t height = 10;
  int64_t row_numel = 10;

  std::vector<int64_t> rows1{0, 4, 7};
  std::unique_ptr<SelectedRows> selected_rows1{new SelectedRows(rows1, height)};
  auto* in1_value = selected_rows1->mutable_value();
  in1_value->mutable_data<float>(
      make_ddim({static_cast<int64_t>(rows1.size()), row_numel}), gpu_place);
  functor(ctx, in1_value, 1.0);

  std::vector<int64_t> rows2{0, 5, 7, 9};
  std::unique_ptr<SelectedRows> selected_rows2{new SelectedRows(rows2, height)};
  auto* in2_value = selected_rows2->mutable_value();
  in2_value->mutable_data<float>(
      make_ddim({static_cast<int64_t>(rows2.size()), row_numel}), gpu_place);
  functor(ctx, in2_value, 2.0);

  std::unique_ptr<SelectedRows> output{new SelectedRows()};
  auto* out_value = output->mutable_value();

  // simply concat the two SelectedRows
  out_value->mutable_data<float>(make_ddim({7, 10}), gpu_place);

  SelectedRowsAdd<GPUPlace, float> add_functor;
  add_functor(ctx, *selected_rows1, *selected_rows2, output.get());

  auto out_height = output->height();
  EXPECT_EQ(out_height, height);

  auto& out_rows = output->rows();

  // input1 rows
  EXPECT_EQ(out_rows[0], 0);
  EXPECT_EQ(out_rows[1], 4);
  EXPECT_EQ(out_rows[2], 7);
  // input2 rows
  EXPECT_EQ(out_rows[3], 0);
  EXPECT_EQ(out_rows[4], 5);
  EXPECT_EQ(out_rows[5], 7);
  EXPECT_EQ(out_rows[6], 9);

  Tensor out_cpu;
  out_cpu.CopyFrom<float>(*out_value, cpu_place, ctx);
  ctx.Wait();

  auto* out_cpu_data = out_cpu.data<float>();
  // input1 value
  EXPECT_EQ(out_cpu_data[0 * row_numel + 0], 1.0);
  EXPECT_EQ(out_cpu_data[0 * row_numel + 8], 1.0);
  EXPECT_EQ(out_cpu_data[1 * row_numel + 1], 1.0);
  EXPECT_EQ(out_cpu_data[2 * row_numel + 6], 1.0);
  // input2 value
  EXPECT_EQ(out_cpu_data[3 * row_numel + 3], 2.0);
  EXPECT_EQ(out_cpu_data[3 * row_numel + 8], 2.0);
  EXPECT_EQ(out_cpu_data[4 * row_numel + 4], 2.0);
  EXPECT_EQ(out_cpu_data[5 * row_numel + 7], 2.0);
  EXPECT_EQ(out_cpu_data[6 * row_numel + 9], 2.0);

  std::unique_ptr<Tensor> tensor1{new Tensor()};
  tensor1->mutable_data<float>(make_ddim({height, row_numel}), gpu_place);
  functor(ctx, tensor1.get(), 3.0);

  std::unique_ptr<Tensor> tensor2{new Tensor()};
  tensor2->mutable_data<float>(make_ddim({height, row_numel}), gpu_place);

  SelectedRowsAddTensor<GPUPlace, float> add_tensor_functor;
  add_tensor_functor(ctx, *output, *tensor1, tensor2.get());

  Tensor tensor2_cpu;
  tensor2_cpu.CopyFrom<float>(*tensor2, cpu_place, ctx);
  ctx.Wait();

  auto* tensor2_cpu_data = tensor2_cpu.data<float>();
  // row0: 1.0 + 2.0 + 3.0
  EXPECT_EQ(tensor2_cpu_data[0 * row_numel + 0], 6.0);
  // row1: 3.0
  EXPECT_EQ(tensor2_cpu_data[1 * row_numel + 1], 3.0);
  // row4: 1.0 + 3.0
  EXPECT_EQ(tensor2_cpu_data[4 * row_numel + 6], 4.0);
  // row5: 2.0 + 3.0
  EXPECT_EQ(tensor2_cpu_data[5 * row_numel + 7], 5.0);
  // row6: 3.0
  EXPECT_EQ(tensor2_cpu_data[6 * row_numel + 1], 3.0);
  // row7: 1.0 + 2.0 + 3.0
  EXPECT_EQ(tensor2_cpu_data[7 * row_numel + 3], 6.0);
  // row9: 2.0 + 3.0
  EXPECT_EQ(tensor2_cpu_data[9 * row_numel + 6], 5.0);
}