Fix more CPPLint issues in fluid/operators/math (#10276)

* Fix CPPLint issues in lstm_cpu_kernel.h

* Fix CPPLint issues in math/math_function_test

* Fix CPPLint issues in math/math_function_test

* Fix CPPLint issues in math/concat.cc

* Fix CPPLint issues in math/concat.cc

* Fix CPPLint issues in math/concat.cc

* Fix CPPLint issues in math/gru_cpu_kernel

* Fix CPPLint issues in math/selected_rows_functor_test.cu

* Fix compile error

* Fix compile error
Abhinav Arora 8 years ago committed by GitHub
parent fb7ca48c06
commit 738585476d

@@ -87,7 +87,7 @@ class ConcatGradKernel : public framework::OpKernel<T> {
 auto& dev_ctx = ctx.template device_context<DeviceContext>();
 paddle::operators::math::ConcatGradFunctor<DeviceContext, T>
 concat_grad_functor;
-concat_grad_functor(dev_ctx, *in, static_cast<int>(axis), outputs);
+concat_grad_functor(dev_ctx, *in, static_cast<int>(axis), &outputs);
 }
 }
 };
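The signature change above is what CPPLint's runtime/references check asks for: the Google C++ style guide wants mutable output arguments passed by pointer rather than by non-const reference, so the mutation is visible at the call site. A minimal sketch of the convention (the functions and vector here are illustrative, not from the Paddle tree):

    #include <vector>

    // Non-const reference: flagged by cpplint (runtime/references);
    // the call site gives no hint that `out` is modified.
    void FillBad(std::vector<int>& out) { out.push_back(1); }

    // Pointer: the caller must write &out, making mutation explicit.
    void FillGood(std::vector<int>* out) { out->push_back(1); }

    int main() {
      std::vector<int> out;
      FillBad(out);    // reads like a read-only use
      FillGood(&out);  // clearly an output argument
      return 0;
    }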

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/math/concat.h"
+#include <vector>
 namespace paddle {
 namespace operators {
@@ -70,20 +71,20 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> {
 public:
 void operator()(const platform::CPUDeviceContext& context,
 const framework::Tensor& input, const int axis,
-std::vector<framework::Tensor>& outputs) {
+std::vector<framework::Tensor>* outputs) {
 // TODO(zcd): Add input data validity checking
-int num = outputs.size();
+int num = outputs->size();
 int input_rows = 1;
-auto dim_0 = outputs[0].dims();
+auto dim_0 = outputs->at(0).dims();
 for (int i = 0; i < axis; ++i) {
 input_rows *= dim_0[i];
 }
 int input_cols = 0;
-std::vector<int64_t> output_cols(outputs.size());
+std::vector<int64_t> output_cols(outputs->size());
 for (int i = 0; i < num; ++i) {
-int t_cols = outputs[i].numel() / input_rows;
+int t_cols = outputs->at(i).numel() / input_rows;
 input_cols += t_cols;
 output_cols[i] = t_cols;
 }
@@ -95,7 +96,7 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> {
 int col_idx = 0;
 for (int j = 0; j < num; ++j) {
 int col_len = output_cols[j];
-T* dst_ptr = outputs[j].data<T>() + k * col_len;
+T* dst_ptr = outputs->at(j).data<T>() + k * col_len;
 memory::Copy(cpu_place, dst_ptr, cpu_place, src_ptr + col_idx,
 sizeof(T) * col_len);
 col_idx += col_len;
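Note that with the pointer parameter the element accesses become outputs->at(i) rather than outputs[i]; std::vector::at() is also bounds-checked and throws std::out_of_range, where indexing through the dereferenced pointer would not be. A small standalone sketch of the difference (not Paddle code):

    #include <iostream>
    #include <stdexcept>
    #include <vector>

    int main() {
      std::vector<int> v{1, 2, 3};
      std::vector<int>* p = &v;

      std::cout << (*p)[1] << "\n";   // unchecked access through the pointer
      std::cout << p->at(1) << "\n";  // checked access, same element

      try {
        p->at(7);  // out of range: throws instead of undefined behavior
      } catch (const std::out_of_range& e) {
        std::cout << "caught: " << e.what() << "\n";
      }
      return 0;
    }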

@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include <algorithm>
+#include <vector>
 #include "paddle/fluid/framework/mixed_vector.h"
 #include "paddle/fluid/operators/math/concat.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
@@ -202,16 +204,16 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
 public:
 void operator()(const platform::CUDADeviceContext& context,
 const framework::Tensor& input, const int axis,
-std::vector<framework::Tensor>& outputs) {
+std::vector<framework::Tensor>* outputs) {
 // TODO(zcd): Add input data validity checking
-int o_num = outputs.size();
+int o_num = outputs->size();
 int out_row = 1;
-auto dim_0 = outputs[0].dims();
+auto dim_0 = outputs->at(0).dims();
 for (int i = 0; i < axis; ++i) {
 out_row *= dim_0[i];
 }
-int out_col = outputs[0].numel() / out_row;
+int out_col = outputs->at(0).numel() / out_row;
 int in_col = 0, in_row = out_row;
 bool sameShape = true;
@@ -221,13 +223,13 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
 outputs_cols[0] = 0;
 for (int i = 0; i < o_num; ++i) {
-int t_col = outputs[i].numel() / out_row;
+int t_col = outputs->at(i).numel() / out_row;
 if (sameShape) {
 if (t_col != out_col) sameShape = false;
 }
 in_col += t_col;
 outputs_cols[i + 1] = in_col;
-outputs_ptr[i] = outputs[i].data<T>();
+outputs_ptr[i] = outputs->at(i).data<T>();
 }
 T** dev_out_gpu_data =

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #pragma once
+#include <vector>
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/framework/tensor.h"
@@ -56,7 +57,7 @@ template <typename DeviceContext, typename T>
 class ConcatGradFunctor {
 public:
 void operator()(const DeviceContext& context, const framework::Tensor& input,
-const int axis, std::vector<framework::Tensor>& outputs);
+const int axis, std::vector<framework::Tensor>* outputs);
 };
 } // namespace math
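With the declaration above, every call site changes from passing the vector to passing its address; the ConcatGradKernel hunk at the top of this diff shows the real call. A minimal hypothetical caller, with a toy functor standing in for ConcatGradFunctor (only the parameter-passing convention is the point):

    #include <vector>

    // Toy stand-in for the real functor; illustrative only.
    struct ToyGradFunctor {
      void operator()(int axis, std::vector<float>* outputs) {
        for (float& o : *outputs) o += static_cast<float>(axis);
      }
    };

    int main() {
      std::vector<float> outputs(4, 0.0f);
      ToyGradFunctor functor;
      functor(/*axis=*/1, &outputs);  // address-of at the call site
      return 0;
    }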

@@ -89,14 +89,14 @@ void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output,
 __m256 r_value_reset_gate;
 __m256 r_value_reset_output;
 __m256 r_prev_out = _mm256_set1_ps(0.0f);
-__m256 *update_gate = (__m256 *)gate_value;
-__m256 *reset_gate = (__m256 *)(gate_value + frame_size);
+__m256 *update_gate = reinterpret_cast<__m256 *>(gate_value);
+__m256 *reset_gate = reinterpret_cast<__m256 *>(gate_value + frame_size);
 for (int i = 0; i < frame_size / 8; i++) {
 r_value_update_gate = update_gate[i];
 r_value_reset_gate = reset_gate[i];
 if (prev_output_value) {
-r_prev_out = ((__m256 *)prev_output_value)[i];
+r_prev_out = (reinterpret_cast<__m256 *>(prev_output_value))[i];
 }
 op_reset_output(r_value_update_gate, r_value_reset_gate, r_prev_out,
@@ -104,7 +104,7 @@ void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output,
 update_gate[i] = r_value_update_gate;
 reset_gate[i] = r_value_reset_gate;
-((__m256 *)reset_output_value)[i] = r_value_reset_output;
+(reinterpret_cast<__m256 *>(reset_output_value))[i] = r_value_reset_output;
 }
 #endif
 }
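CPPLint's readability/casting check rejects C-style casts such as (__m256 *)gate_value because a C-style cast can silently perform any of const_cast, static_cast, or reinterpret_cast; spelling out reinterpret_cast keeps the type pun explicit and greppable. A small sketch of the pattern, under the usual caveat that the float buffer must be 32-byte aligned for aligned __m256 accesses (illustrative, not the Paddle kernel itself):

    // Compile with AVX enabled, e.g. g++ -mavx example.cc
    #include <immintrin.h>

    void scale8(float* data, int frame_size) {
      // Reinterpret the float buffer as packed 8-wide AVX vectors,
      // as the GRU kernel above does; `data` must be 32-byte aligned.
      __m256* v = reinterpret_cast<__m256*>(data);
      const __m256 two = _mm256_set1_ps(2.0f);
      for (int i = 0; i < frame_size / 8; ++i) {
        v[i] = _mm256_mul_ps(v[i], two);
      }
    }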
@@ -119,21 +119,21 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
 __m256 r_value_frame_state;
 __m256 r_prev_out = _mm256_set1_ps(0.0f);
 __m256 r_output;
-__m256 *update_gate = (__m256 *)gate_value;
-__m256 *frame_state = (__m256 *)(gate_value + frame_size * 2);
+__m256 *update_gate = reinterpret_cast<__m256 *>(gate_value);
+__m256 *frame_state = reinterpret_cast<__m256 *>(gate_value + frame_size * 2);
 for (int i = 0; i < frame_size / 8; i++) {
 r_value_update_gate = update_gate[i];
 r_value_frame_state = frame_state[i];
 if (prev_output_value) {
-r_prev_out = ((__m256 *)prev_output_value)[i];
+r_prev_out = (reinterpret_cast<__m256 *>(prev_output_value))[i];
 }
 op_final_output(r_value_update_gate, r_value_frame_state, r_prev_out,
 r_output, active_node);
 frame_state[i] = r_value_frame_state;
-((__m256 *)output_value)[i] = r_output;
+(reinterpret_cast<__m256 *>(output_value))[i] = r_output;
 }
 #endif
 }
@@ -284,20 +284,22 @@ void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
 __m256 r_out_grad;
 __m256 r_prev_out_value = _mm256_set1_ps(0.0f);
 __m256 r_prev_out_grad = _mm256_set1_ps(0.0f);
-__m256 *update_gate_value = (__m256 *)gate_value;
-__m256 *update_gate_grad = (__m256 *)gate_grad;
-__m256 *frame_state_value = (__m256 *)(gate_value + frame_size * 2);
-__m256 *frame_state_grad = (__m256 *)(gate_grad + frame_size * 2);
+__m256 *update_gate_value = reinterpret_cast<__m256 *>(gate_value);
+__m256 *update_gate_grad = reinterpret_cast<__m256 *>(gate_grad);
+__m256 *frame_state_value =
+reinterpret_cast<__m256 *>(gate_value + frame_size * 2);
+__m256 *frame_state_grad =
+reinterpret_cast<__m256 *>(gate_grad + frame_size * 2);
 for (int i = 0; i < frame_size / 8; i++) {
 r_update_gate_value = update_gate_value[i];
 r_frame_state_value = frame_state_value[i];
-r_out_grad = ((__m256 *)output_grad)[i];
+r_out_grad = (reinterpret_cast<__m256 *>(output_grad))[i];
 if (prev_out_value) {
-r_prev_out_value = ((__m256 *)prev_out_value)[i];
+r_prev_out_value = (reinterpret_cast<__m256 *>(prev_out_value))[i];
 }
 if (prev_out_grad) {
-r_prev_out_grad = ((__m256 *)prev_out_grad)[i];
+r_prev_out_grad = (reinterpret_cast<__m256 *>(prev_out_grad))[i];
 }
 op_state_grad(r_update_gate_value, r_update_gate_grad, r_frame_state_value,
@@ -307,7 +309,7 @@ void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
 update_gate_grad[i] = r_update_gate_grad;
 frame_state_grad[i] = r_frame_state_grad;
 if (prev_out_grad) {
-((__m256 *)prev_out_grad)[i] = r_prev_out_grad;
+(reinterpret_cast<__m256 *>(prev_out_grad))[i] = r_prev_out_grad;
 }
 }
 #endif
@@ -327,10 +329,11 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
 __m256 r_reset_output_grad = _mm256_set1_ps(0.0f);
 __m256 r_prev_out_value = _mm256_set1_ps(0.0f);
 __m256 r_prev_out_grad = _mm256_set1_ps(0.0f);
-__m256 *update_gate_value = (__m256 *)gate_value;
-__m256 *update_gate_grad = (__m256 *)gate_grad;
-__m256 *reset_gate_value = (__m256 *)(gate_value + frame_size);
-__m256 *reset_gate_grad = (__m256 *)(gate_grad + frame_size);
+__m256 *update_gate_value = reinterpret_cast<__m256 *>(gate_value);
+__m256 *update_gate_grad = reinterpret_cast<__m256 *>(gate_grad);
+__m256 *reset_gate_value =
+reinterpret_cast<__m256 *>(gate_value + frame_size);
+__m256 *reset_gate_grad = reinterpret_cast<__m256 *>(gate_grad + frame_size);
 for (int i = 0; i < frame_size / 8; i++) {
 r_update_gate_value = update_gate_value[i];
@@ -338,13 +341,13 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
 r_reset_gate_value = reset_gate_value[i];
 if (prev_out_value && prev_out_grad) {
-r_reset_output_grad = ((__m256 *)reset_output_grad)[i];
+r_reset_output_grad = (reinterpret_cast<__m256 *>(reset_output_grad))[i];
 }
 if (prev_out_value) {
-r_prev_out_value = ((__m256 *)prev_out_value)[i];
+r_prev_out_value = (reinterpret_cast<__m256 *>(prev_out_value))[i];
 }
 if (prev_out_grad) {
-r_prev_out_grad = ((__m256 *)prev_out_grad)[i];
+r_prev_out_grad = (reinterpret_cast<__m256 *>(prev_out_grad))[i];
 }
 op_reset_grad(r_update_gate_value, r_update_gate_grad, r_reset_gate_value,
@@ -354,7 +357,7 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
 update_gate_grad[i] = r_update_gate_grad;
 reset_gate_grad[i] = r_reset_gate_grad;
 if (prev_out_grad) {
-((__m256 *)prev_out_grad)[i] = r_prev_out_grad;
+(reinterpret_cast<__m256 *>(prev_out_grad))[i] = r_prev_out_grad;
 }
 }
 #endif
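Throughout these kernels the loop bound frame_size / 8 comes from the width of __m256: 256 bits hold eight 32-bit floats, so each iteration consumes one 8-float lane of the gate buffers (presumably frame sizes that are not a multiple of eight fall back to the scalar kernels elsewhere in the file). A small illustration of the stride, assuming AVX is available (toy buffer, not the GRU state):

    // Compile with AVX enabled, e.g. g++ -mavx example.cc
    #include <immintrin.h>
    #include <cstdio>

    int main() {
      // frame_size = 16 floats -> 16 / 8 = 2 __m256 lanes.
      alignas(32) float gate_value[16];
      for (int i = 0; i < 16; ++i) gate_value[i] = static_cast<float>(i);

      __m256* lanes = reinterpret_cast<__m256*>(gate_value);
      for (int i = 0; i < 16 / 8; ++i) {
        alignas(32) float lane[8];
        _mm256_store_ps(lane, lanes[i]);
        std::printf("lane %d starts at %.0f\n", i, lane[0]);  // 0, then 8
      }
      return 0;
    }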

@@ -164,10 +164,12 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
 __m256 r_state_atv;
 __m256 r_out;
-__m256 *value_in = (__m256 *)value.gate_value;
-__m256 *value_ig = (__m256 *)(value.gate_value + frame_size);
-__m256 *value_fg = (__m256 *)(value.gate_value + frame_size * 2);
-__m256 *value_og = (__m256 *)(value.gate_value + frame_size * 3);
+__m256 *value_in = reinterpret_cast<__m256 *>(value.gate_value);
+__m256 *value_ig = reinterpret_cast<__m256 *>(value.gate_value + frame_size);
+__m256 *value_fg =
+reinterpret_cast<__m256 *>(value.gate_value + frame_size * 2);
+__m256 *value_og =
+reinterpret_cast<__m256 *>(value.gate_value + frame_size * 3);
 for (int i = 0; i < frame_size / 8; i++) {
 r_value_in = value_in[i];
@@ -175,13 +177,13 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
 r_value_fg = value_fg[i];
 r_value_og = value_og[i];
 if (value.check_ig) {
-r_checkI = ((__m256 *)value.check_ig)[i];
-r_checkF = ((__m256 *)value.check_fg)[i];
-r_checkO = ((__m256 *)value.check_og)[i];
+r_checkI = (reinterpret_cast<__m256 *>(value.check_ig))[i];
+r_checkF = (reinterpret_cast<__m256 *>(value.check_fg))[i];
+r_checkO = (reinterpret_cast<__m256 *>(value.check_og))[i];
 }
 if (value.prev_state_value) {
-r_prev_state = ((__m256 *)value.prev_state_value)[i];
+r_prev_state = (reinterpret_cast<__m256 *>(value.prev_state_value))[i];
 }
 op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state,
@@ -192,9 +194,9 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
 value_ig[i] = r_value_ig;
 value_fg[i] = r_value_fg;
 value_og[i] = r_value_og;
-((__m256 *)value.state_value)[i] = r_state;
-((__m256 *)value.state_active_value)[i] = r_state_atv;
-((__m256 *)value.output_value)[i] = r_out;
+(reinterpret_cast<__m256 *>(value.state_value))[i] = r_state;
+(reinterpret_cast<__m256 *>(value.state_active_value))[i] = r_state_atv;
+(reinterpret_cast<__m256 *>(value.output_value))[i] = r_out;
 }
 #endif
 }
@@ -227,14 +229,16 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
 __m256 r_checkFGrad;
 __m256 r_checkOGrad;
-__m256 *value_in = (__m256 *)value.gate_value;
-__m256 *value_ig = (__m256 *)(value.gate_value + frame_size);
-__m256 *value_fg = (__m256 *)(value.gate_value + frame_size * 2);
-__m256 *value_og = (__m256 *)(value.gate_value + frame_size * 3);
-__m256 *grad_in = (__m256 *)grad.gate_grad;
-__m256 *grad_ig = (__m256 *)(grad.gate_grad + frame_size);
-__m256 *grad_fg = (__m256 *)(grad.gate_grad + frame_size * 2);
-__m256 *grad_og = (__m256 *)(grad.gate_grad + frame_size * 3);
+__m256 *value_in = reinterpret_cast<__m256 *>(value.gate_value);
+__m256 *value_ig = reinterpret_cast<__m256 *>(value.gate_value + frame_size);
+__m256 *value_fg =
+reinterpret_cast<__m256 *>(value.gate_value + frame_size * 2);
+__m256 *value_og =
+reinterpret_cast<__m256 *>(value.gate_value + frame_size * 3);
+__m256 *grad_in = reinterpret_cast<__m256 *>(grad.gate_grad);
+__m256 *grad_ig = reinterpret_cast<__m256 *>(grad.gate_grad + frame_size);
+__m256 *grad_fg = reinterpret_cast<__m256 *>(grad.gate_grad + frame_size * 2);
+__m256 *grad_og = reinterpret_cast<__m256 *>(grad.gate_grad + frame_size * 3);
 for (int i = 0; i < frame_size / 8; i++) {
 r_value_in = value_in[i];
@@ -242,16 +246,16 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
 r_value_fg = value_fg[i];
 r_value_og = value_og[i];
 if (value.check_ig) {
-r_checkI = ((__m256 *)value.check_ig)[i];
-r_checkF = ((__m256 *)value.check_fg)[i];
-r_checkO = ((__m256 *)value.check_og)[i];
+r_checkI = (reinterpret_cast<__m256 *>(value.check_ig))[i];
+r_checkF = (reinterpret_cast<__m256 *>(value.check_fg))[i];
+r_checkO = (reinterpret_cast<__m256 *>(value.check_og))[i];
 }
-r_state = ((__m256 *)value.state_value)[i];
-r_state_atv = ((__m256 *)value.state_active_value)[i];
-r_output_grad = ((__m256 *)grad.output_grad)[i];
-r_state_grad = ((__m256 *)grad.state_grad)[i];
+r_state = (reinterpret_cast<__m256 *>(value.state_value))[i];
+r_state_atv = (reinterpret_cast<__m256 *>(value.state_active_value))[i];
+r_output_grad = (reinterpret_cast<__m256 *>(grad.output_grad))[i];
+r_state_grad = (reinterpret_cast<__m256 *>(grad.state_grad))[i];
 if (value.prev_state_value) {
-r_prev_state = ((__m256 *)value.prev_state_value)[i];
+r_prev_state = (reinterpret_cast<__m256 *>(value.prev_state_value))[i];
 }
 op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig,
@@ -264,15 +268,18 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
 grad_ig[i] = r_grad_ig;
 grad_fg[i] = r_grad_fg;
 grad_og[i] = r_grad_og;
-((__m256 *)grad.state_grad)[i] = r_state_grad;
+(reinterpret_cast<__m256 *>(grad.state_grad))[i] = r_state_grad;
 if (grad.prev_state_grad)
-((__m256 *)grad.prev_state_grad)[i] = r_prev_state_grad;
+(reinterpret_cast<__m256 *>(grad.prev_state_grad))[i] = r_prev_state_grad;
 if (value.prev_state_value) {
-if (grad.check_ig_grad) ((__m256 *)grad.check_ig_grad)[i] += r_checkIGrad;
-if (grad.check_fg_grad) ((__m256 *)grad.check_fg_grad)[i] += r_checkFGrad;
+if (grad.check_ig_grad)
+(reinterpret_cast<__m256 *>(grad.check_ig_grad))[i] += r_checkIGrad;
+if (grad.check_fg_grad)
+(reinterpret_cast<__m256 *>(grad.check_fg_grad))[i] += r_checkFGrad;
 }
-if (grad.check_og_grad) ((__m256 *)grad.check_og_grad)[i] += r_checkOGrad;
+if (grad.check_og_grad)
+(reinterpret_cast<__m256 *>(grad.check_og_grad))[i] += r_checkOGrad;
 }
 #endif
 }
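The peephole-gradient lines above accumulate with += through the casted pointer; operator+= on __m256 compiles under GCC/Clang vector extensions as elementwise addition, so eight gradients are accumulated per iteration (with MSVC one would call _mm256_add_ps explicitly). A hedged sketch of that accumulation pattern, outside the LSTM kernel:

    // Compile with AVX enabled, e.g. g++ -mavx example.cc
    #include <immintrin.h>

    // Elementwise accumulation through reinterpret_cast'ed pointers.
    void accumulate(float* check_grad, const float* delta, int frame_size) {
      __m256* dst = reinterpret_cast<__m256*>(check_grad);
      const __m256* src = reinterpret_cast<const __m256*>(delta);
      for (int i = 0; i < frame_size / 8; ++i) {
        dst[i] += src[i];  // eight floats added at once (vector extension)
      }
    }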

(File diff suppressed because it is too large.)

@@ -12,43 +12,52 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include <vector>
 #include "gtest/gtest.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/selected_rows_functor.h"
 TEST(selected_rows_functor, gpu_add) {
-using namespace paddle::framework;
-using namespace paddle::platform;
-using namespace paddle::operators::math;
-CUDAPlace gpu_place(0);
-CPUPlace cpu_place;
-CUDADeviceContext ctx(gpu_place);
-SetConstant<CUDADeviceContext, float> functor;
+paddle::platform::CUDAPlace gpu_place(0);
+paddle::platform::CPUPlace cpu_place;
+paddle::platform::CUDADeviceContext ctx(gpu_place);
+paddle::operators::math::SetConstant<paddle::platform::CUDADeviceContext,
+float>
+functor;
 int64_t height = 10;
 int64_t row_numel = 10;
 std::vector<int64_t> rows1{0, 4, 7};
-std::unique_ptr<SelectedRows> selected_rows1{new SelectedRows(rows1, height)};
+std::unique_ptr<paddle::framework::SelectedRows> selected_rows1{
+new paddle::framework::SelectedRows(rows1, height)};
 auto* in1_value = selected_rows1->mutable_value();
 in1_value->mutable_data<float>(
-make_ddim({static_cast<int64_t>(rows1.size()), row_numel}), gpu_place);
+paddle::framework::make_ddim(
+{static_cast<int64_t>(rows1.size()), row_numel}),
+gpu_place);
 functor(ctx, in1_value, 1.0);
 std::vector<int64_t> rows2{0, 5, 7, 9};
-std::unique_ptr<SelectedRows> selected_rows2{new SelectedRows(rows2, height)};
+std::unique_ptr<paddle::framework::SelectedRows> selected_rows2{
+new paddle::framework::SelectedRows(rows2, height)};
 auto* in2_value = selected_rows2->mutable_value();
 in2_value->mutable_data<float>(
-make_ddim({static_cast<int64_t>(rows2.size()), row_numel}), gpu_place);
+paddle::framework::make_ddim(
+{static_cast<int64_t>(rows2.size()), row_numel}),
+gpu_place);
 functor(ctx, in2_value, 2.0);
-std::unique_ptr<SelectedRows> output{new SelectedRows()};
+std::unique_ptr<paddle::framework::SelectedRows> output{
+new paddle::framework::SelectedRows()};
 auto* out_value = output->mutable_value();
-// simplely concat two SelectedRows
-out_value->mutable_data<float>(make_ddim({7, 10}), gpu_place);
+// simply concat two SelectedRows
+out_value->mutable_data<float>(paddle::framework::make_ddim({7, 10}),
+gpu_place);
-SelectedRowsAdd<CUDADeviceContext, float> add_functor;
+paddle::operators::math::SelectedRowsAdd<paddle::platform::CUDADeviceContext,
+float>
+add_functor;
 add_functor(ctx, *selected_rows1, *selected_rows2, output.get());
 auto out_height = output->height();
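The test rewrite above replaces file-scope `using namespace` directives with fully qualified names, which is what CPPLint's build/namespaces check asks for; a function-local namespace alias would be a lighter-weight alternative the style guide also permits. A small illustrative sketch with toy namespaces (not Paddle's):

    #include <iostream>

    namespace toy { namespace platform { struct CPUPlace { int id = 0; }; } }

    // Flagged by cpplint (build/namespaces):
    //   using namespace toy::platform;

    void QualifiedStyle() {
      toy::platform::CPUPlace place;  // fully qualified, as in this diff
      std::cout << place.id << "\n";
    }

    void AliasStyle() {
      namespace plat = toy::platform;  // function-local alias: also accepted
      plat::CPUPlace place;
      std::cout << place.id << "\n";
    }

    int main() {
      QualifiedStyle();
      AliasStyle();
      return 0;
    }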
@@ -66,8 +75,8 @@ TEST(selected_rows_functor, gpu_add) {
 EXPECT_EQ(out_rows[5], 7);
 EXPECT_EQ(out_rows[6], 9);
-Tensor out_cpu;
-TensorCopy(*out_value, cpu_place, ctx, &out_cpu);
+paddle::framework::Tensor out_cpu;
+paddle::framework::TensorCopy(*out_value, cpu_place, ctx, &out_cpu);
 ctx.Wait();
 auto* out_cpu_data = out_cpu.data<float>();
@@ -83,18 +92,24 @@ TEST(selected_rows_functor, gpu_add) {
 EXPECT_EQ(out_cpu_data[5 * row_numel + 7], 2.0);
 EXPECT_EQ(out_cpu_data[6 * row_numel + 9], 2.0);
-std::unique_ptr<Tensor> tensor1{new Tensor()};
-tensor1->mutable_data<float>(make_ddim({height, row_numel}), gpu_place);
+std::unique_ptr<paddle::framework::Tensor> tensor1{
+new paddle::framework::Tensor()};
+tensor1->mutable_data<float>(
+paddle::framework::make_ddim({height, row_numel}), gpu_place);
 functor(ctx, tensor1.get(), 3.0);
-std::unique_ptr<Tensor> tensor2{new Tensor()};
-tensor2->mutable_data<float>(make_ddim({height, row_numel}), gpu_place);
+std::unique_ptr<paddle::framework::Tensor> tensor2{
+new paddle::framework::Tensor()};
+tensor2->mutable_data<float>(
+paddle::framework::make_ddim({height, row_numel}), gpu_place);
-SelectedRowsAddTensor<CUDADeviceContext, float> add_tensor_functor;
+paddle::operators::math::SelectedRowsAddTensor<
+paddle::platform::CUDADeviceContext, float>
+add_tensor_functor;
 add_tensor_functor(ctx, *output, *tensor1, tensor2.get());
-Tensor tensor2_cpu;
-TensorCopy(*tensor2, cpu_place, ctx, &tensor2_cpu);
+paddle::framework::Tensor tensor2_cpu;
+paddle::framework::TensorCopy(*tensor2, cpu_place, ctx, &tensor2_cpu);
 ctx.Wait();
 auto* tensor2_cpu_data = tensor2_cpu.data<float>();
@@ -115,39 +130,47 @@ TEST(selected_rows_functor, gpu_add) {
 }
 TEST(selected_rows_functor, gpu_add_to) {
-using namespace paddle::framework;
-using namespace paddle::platform;
-using namespace paddle::operators::math;
-CUDAPlace gpu_place(0);
-CPUPlace cpu_place;
-CUDADeviceContext ctx(gpu_place);
-SetConstant<CUDADeviceContext, float> functor;
+paddle::platform::CUDAPlace gpu_place(0);
+paddle::platform::CPUPlace cpu_place;
+paddle::platform::CUDADeviceContext ctx(gpu_place);
+paddle::operators::math::SetConstant<paddle::platform::CUDADeviceContext,
+float>
+functor;
 int64_t height = 10;
 int64_t row_numel = 10;
 std::vector<int64_t> rows1{0, 4, 7};
-std::unique_ptr<SelectedRows> selected_rows1{new SelectedRows(rows1, height)};
+std::unique_ptr<paddle::framework::SelectedRows> selected_rows1{
+new paddle::framework::SelectedRows(rows1, height)};
 auto* in1_value = selected_rows1->mutable_value();
 in1_value->mutable_data<float>(
-make_ddim({static_cast<int64_t>(rows1.size()), row_numel}), gpu_place);
+paddle::framework::make_ddim(
+{static_cast<int64_t>(rows1.size()), row_numel}),
+gpu_place);
 functor(ctx, in1_value, 1.0);
 std::vector<int64_t> rows2{0, 5, 7, 9};
-std::unique_ptr<SelectedRows> selected_rows2{new SelectedRows(rows2, height)};
+std::unique_ptr<paddle::framework::SelectedRows> selected_rows2{
+new paddle::framework::SelectedRows(rows2, height)};
 auto* in2_value = selected_rows2->mutable_value();
 in2_value->mutable_data<float>(
-make_ddim({static_cast<int64_t>(rows2.size()), row_numel}), gpu_place);
+paddle::framework::make_ddim(
+{static_cast<int64_t>(rows2.size()), row_numel}),
+gpu_place);
 functor(ctx, in2_value, 2.0);
-std::unique_ptr<SelectedRows> output{new SelectedRows()};
+std::unique_ptr<paddle::framework::SelectedRows> output{
+new paddle::framework::SelectedRows()};
 output->set_height(height);
 auto* out_value = output->mutable_value();
-// simplely concat two SelectedRows
-out_value->mutable_data<float>(make_ddim({7, 10}), gpu_place);
+// simply concat two SelectedRows
+out_value->mutable_data<float>(paddle::framework::make_ddim({7, 10}),
+gpu_place);
-SelectedRowsAddTo<CUDADeviceContext, float> add_to_functor;
+paddle::operators::math::SelectedRowsAddTo<
+paddle::platform::CUDADeviceContext, float>
+add_to_functor;
 add_to_functor(ctx, *selected_rows1, 0, output.get());
 add_to_functor(ctx, *selected_rows2, in1_value->numel(), output.get());
@@ -166,8 +189,8 @@ TEST(selected_rows_functor, gpu_add_to) {
 EXPECT_EQ(out_rows[5], 7);
 EXPECT_EQ(out_rows[6], 9);
-Tensor out_cpu;
-TensorCopy(*out_value, cpu_place, ctx, &out_cpu);
+paddle::framework::Tensor out_cpu;
+paddle::framework::TensorCopy(*out_value, cpu_place, ctx, &out_cpu);
 ctx.Wait();
 auto* out_cpu_data = out_cpu.data<float>();
@@ -183,15 +206,19 @@ TEST(selected_rows_functor, gpu_add_to) {
 EXPECT_EQ(out_cpu_data[5 * row_numel + 7], 2.0);
 EXPECT_EQ(out_cpu_data[6 * row_numel + 9], 2.0);
-std::unique_ptr<Tensor> tensor1{new Tensor()};
-tensor1->mutable_data<float>(make_ddim({height, row_numel}), gpu_place);
+std::unique_ptr<paddle::framework::Tensor> tensor1{
+new paddle::framework::Tensor()};
+tensor1->mutable_data<float>(
+paddle::framework::make_ddim({height, row_numel}), gpu_place);
 functor(ctx, tensor1.get(), 3.0);
-SelectedRowsAddToTensor<CUDADeviceContext, float> add_to_tensor_functor;
+paddle::operators::math::SelectedRowsAddToTensor<
+paddle::platform::CUDADeviceContext, float>
+add_to_tensor_functor;
 add_to_tensor_functor(ctx, *output, tensor1.get());
-Tensor tensor1_cpu;
-TensorCopy(*tensor1, cpu_place, ctx, &tensor1_cpu);
+paddle::framework::Tensor tensor1_cpu;
+paddle::framework::TensorCopy(*tensor1, cpu_place, ctx, &tensor1_cpu);
 ctx.Wait();
 auto* tensor1_cpu_data = tensor1_cpu.data<float>();
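Both tests end with the same pattern: copy the device tensor back to a host tensor with TensorCopy, then call ctx.Wait() before reading, because the copy is enqueued on the context's CUDA stream and is not necessarily finished when TensorCopy returns. A hedged sketch of the same idea using the raw CUDA runtime API instead of Paddle's wrappers:

    // Host-side sketch of the copy-then-wait pattern (link with cudart).
    #include <cuda_runtime.h>
    #include <vector>

    int main() {
      const size_t n = 70;  // matches the 7 x 10 output tensor above
      float* dev_ptr = nullptr;
      cudaStream_t stream;
      cudaStreamCreate(&stream);
      cudaMalloc(&dev_ptr, n * sizeof(float));

      std::vector<float> host(n);
      // Asynchronous copy: enqueued on the stream and possibly still
      // running on return -- the analogue of TensorCopy(..., ctx, ...).
      cudaMemcpyAsync(host.data(), dev_ptr, n * sizeof(float),
                      cudaMemcpyDeviceToHost, stream);
      // The analogue of ctx.Wait(): block until the copy finishes;
      // only then is host.data() safe to read.
      cudaStreamSynchronize(stream);

      cudaFree(dev_ptr);
      cudaStreamDestroy(stream);
      return 0;
    }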
