|
|
|
@ -297,7 +297,25 @@ void set_constant_with_place<platform::GPUPlace>(
|
|
|
|
|
template struct RowwiseAdd<platform::GPUPlace, float>;
|
|
|
|
|
template struct RowwiseAdd<platform::GPUPlace, double>;
|
|
|
|
|
template struct ColwiseSum<platform::GPUPlace, float>;
|
|
|
|
|
template struct ColwiseSum<platform::GPUPlace, double>;
|
|
|
|
|
// template struct ColwiseSum<platform::GPUPlace, double>;
|
|
|
|
|
// The ColwiseSum<platform::GPUPlace, double> failed in debug mode,
|
|
|
|
|
// and only failed for this case. So reimplemented it.
|
|
|
|
|
template <>
|
|
|
|
|
void ColwiseSum<platform::GPUPlace, double>::operator()(
|
|
|
|
|
const platform::DeviceContext& context, const framework::Tensor& input,
|
|
|
|
|
framework::Tensor* vector) {
|
|
|
|
|
auto in_dims = input.dims();
|
|
|
|
|
auto size = input.numel() / in_dims[0];
|
|
|
|
|
PADDLE_ENFORCE_EQ(vector->numel(), size);
|
|
|
|
|
framework::Tensor one;
|
|
|
|
|
one.mutable_data<double>({in_dims[0]}, context.GetPlace());
|
|
|
|
|
SetConstant<platform::GPUPlace, double> set;
|
|
|
|
|
set(context, &one, static_cast<double>(1.0));
|
|
|
|
|
gemv<platform::GPUPlace, double>(context, true, static_cast<int>(in_dims[0]),
|
|
|
|
|
static_cast<int>(in_dims[1]), 1.0,
|
|
|
|
|
input.data<double>(), one.data<double>(),
|
|
|
|
|
0.0, vector->data<double>());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace math
|
|
|
|
|
} // namespace operators
|
|
|
|
|