@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <set>

#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/selected_rows_functor.h"
#include "paddle/platform/cuda_helper.h"
@@ -222,6 +224,157 @@ template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, float>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, double>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int64_t>;

namespace scatter {

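// MergeAddKernel accumulates every input row that carries the same row index
// into a single output row. blockIdx.y selects the input row; thread 0 of the
// block finds the matching slot in out_rows, and all threads of the block
// then atomically add the row's row_numel elements into that output row.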
template <typename T, int block_size>
__global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
                               T* out, const int64_t* out_rows,
                               size_t out_rows_size, int64_t row_numel) {
  const int ty = blockIdx.y;
  int tid = threadIdx.x;
  __shared__ size_t out_idx;

  if (tid == 0) {
    // Thread 0 locates this input row's slot in out_rows and shares it with
    // the rest of the block through out_idx.
    for (size_t i = 0; i < out_rows_size; i++) {
      if (input_rows[ty] == out_rows[i]) {
        out_idx = i;
      }
    }
  }

  __syncthreads();

  input += ty * row_numel;
  out += out_idx * row_numel;
  for (int index = tid; index < row_numel; index += block_size) {
    // Atomic add: several input rows may merge into the same output row.
    paddle::platform::CudaAtomicAdd(out + index, input[index]);
  }
}

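// Design note on MergeAddKernel: the row lookup is a serial linear scan by
// thread 0, so a full merge costs O(#input_rows * #out_rows) comparisons.
// That is presumably acceptable for the modest row counts typical of sparse
// gradients; a sorted or hashed lookup would be worthwhile for large row sets.
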
template <typename T>
struct MergeAdd<platform::CUDADeviceContext, T> {
  framework::SelectedRows operator()(const platform::CUDADeviceContext& context,
                                     const framework::SelectedRows& input) {
    framework::SelectedRows out;
    auto input_rows = input.rows();
    // Deduplicate and sort the input row indices on the host.
    std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
    std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());

    auto input_width = input.value().dims()[1];

    out.set_rows(merge_rows);
    out.set_height(input.height());
    out.mutable_value()->mutable_data<T>(
        framework::make_ddim(
            {static_cast<int64_t>(merge_rows.size()), input_width}),
        context.GetPlace());

    // Zero the merged value tensor before the kernel accumulates into it.
    math::SetConstant<platform::CUDADeviceContext, T> constant_functor;
    constant_functor(context, out.mutable_value(), static_cast<T>(0));

    auto* out_data = out.mutable_value()->data<T>();
    auto* input_data = input.value().data<T>();

    const int block_size = 256;
    dim3 threads(block_size, 1);
    dim3 grid1(1, input_rows.size());

    MergeAddKernel<T, block_size><<<grid1, threads, 0, context.stream()>>>(
        input_data, input.rows().data(), out_data, out.rows().data(),
        out.rows().size(), input_width);
    return out;
  }
};
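
// Usage sketch (illustrative only; `dev_ctx` and `grad` are assumed to be an
// existing CUDADeviceContext and a SelectedRows that may hold duplicate rows):
//   MergeAdd<platform::CUDADeviceContext, float> merge_func;
//   framework::SelectedRows merged = merge_func(dev_ctx, grad);
// merged.rows() is then unique and sorted, and each merged row holds the
// element-wise sum of every input row that carried the same index.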

template struct MergeAdd<platform::CUDADeviceContext, float>;
template struct MergeAdd<platform::CUDADeviceContext, double>;
template struct MergeAdd<platform::CUDADeviceContext, int>;
template struct MergeAdd<platform::CUDADeviceContext, int64_t>;

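// UpdateToTensorKernel folds each selected row into the matching row of a
// dense tensor, applying the element-wise update chosen by ScatterOps
// (assign, add, subtract, reverse-subtract, multiply, divide, reverse-divide).
// blockIdx.y picks the row; rows[ty] gives its destination in tensor_out.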
template <typename T, int block_size>
__global__ void UpdateToTensorKernel(const T* selected_rows,
                                     const int64_t* rows, const ScatterOps op,
                                     T* tensor_out, int64_t row_numel) {
  const int ty = blockIdx.y;
  int tid = threadIdx.x;

  selected_rows += ty * row_numel;
  tensor_out += rows[ty] * row_numel;
  // FIXME(typhoonzero): use a macro to fix the messy code below.
  switch (op) {
    case ScatterOps::ASSIGN:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] = selected_rows[index];
      }
      break;
    case ScatterOps::ADD:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] += selected_rows[index];
      }
      break;
    case ScatterOps::SUB:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] -= selected_rows[index];
      }
      break;
    case ScatterOps::SUBBY:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] = selected_rows[index] - tensor_out[index];
      }
      break;
    case ScatterOps::MUL:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] *= selected_rows[index];
      }
      break;
    case ScatterOps::DIV:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] /= selected_rows[index];
      }
      break;
    case ScatterOps::DIVBY:
      for (int index = tid; index < row_numel; index += block_size) {
        tensor_out[index] = selected_rows[index] / tensor_out[index];
      }
      break;
  }
}

template <typename T>
struct UpdateToTensor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const ScatterOps& op, const framework::SelectedRows& input1,
                  framework::Tensor* input2) {
    // NOTE: Use SelectedRowsAddToTensor for better performance when the
    // update is a plain add, since it needs no additional MergeAdd call.
    MergeAdd<platform::CUDADeviceContext, T> merge_func;
    auto merged_in1 = merge_func(context, input1);

    auto in1_height = merged_in1.height();
    auto in2_dims = input2->dims();
    PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);

    auto& in1_value = merged_in1.value();
    auto& in1_rows = merged_in1.rows();

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height);

    auto* in1_data = in1_value.template data<T>();
    auto* in2_data = input2->data<T>();

    dim3 threads(platform::PADDLE_CUDA_NUM_THREADS, 1);
    dim3 grid(1, in1_rows.size());
    UpdateToTensorKernel<T, platform::PADDLE_CUDA_NUM_THREADS><<<
        grid, threads, 0, context.stream()>>>(in1_data, in1_rows.data(), op,
                                              in2_data, in1_row_numel);
  }
};
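
// Usage sketch (illustrative only; `dev_ctx`, `grad`, and `param` are assumed
// to be an existing CUDADeviceContext, a SelectedRows, and a dense Tensor
// with param.dims()[0] == grad.height()):
//   UpdateToTensor<platform::CUDADeviceContext, float> update_func;
//   update_func(dev_ctx, ScatterOps::ADD, grad, &param);
// This first merges duplicate rows of grad, then adds each merged row into
// the corresponding row of param.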
} // namespace scatter
} // namespace math
} // namespace operators
} // namespace paddle