fix code style

revert-4814-Add_sequence_project_op
qijun 7 years ago
parent 7ef568e893
commit 89758adb83

@ -1,8 +1,11 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

@ -18,7 +18,7 @@ namespace framework {
class SelectedRowsTester : public ::testing::Test {
public:
virtual void SetUp() override {
Vector<int64_t> rows{0, 4, 7};
std::vector<int64_t> rows{0, 4, 7};
int64_t height = 10;
int64_t row_numel = 100;
selected_rows_.reset(new SelectedRows(rows, height));

@ -1,8 +1,11 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

@ -214,10 +214,8 @@ template struct SelectedRowsAdd<platform::GPUPlace, float>;
namespace {
template <typename T>
__global__ void SelectedRowsAddTensorKernel(const T* selected_rows,
const int64_t* rows,
T* tensor_out,
int64_t row_numel,
int block_size) {
const int64_t* rows, T* tensor_out,
int64_t row_numel, int block_size) {
const int ty = blockIdx.y;
int tid = threadIdx.x;
@ -261,11 +259,11 @@ struct SelectedRowsAddTensor<platform::GPUPlace, T> {
int block_size = 256;
dim3 threads(block_size, 1);
dim3 grid(1, in1_height);
SelectedRowsAddTensorKernel<T><<<
grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context).stream()
>>>(in1_data, in1_rows.data(),
out_data, in1_row_numel, block_size);
SelectedRowsAddTensorKernel<
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(in1_data, in1_rows.data(), out_data,
in1_row_numel, block_size);
auto out_eigen = framework::EigenVector<T>::Flatten(*output);
auto in2_eigen = framework::EigenVector<T>::Flatten(input2);

Loading…
Cancel
Save