|
|
|
@ -20,7 +20,9 @@ limitations under the License. */
|
|
|
|
|
TEST(selected_rows_functor, gpu_add) {
|
|
|
|
|
paddle::platform::CUDAPlace gpu_place(0);
|
|
|
|
|
paddle::platform::CPUPlace cpu_place;
|
|
|
|
|
paddle::platform::CUDADeviceContext ctx(gpu_place);
|
|
|
|
|
paddle::platform::CUDADeviceContext& ctx =
|
|
|
|
|
*reinterpret_cast<paddle::platform::CUDADeviceContext*>(
|
|
|
|
|
paddle::platform::DeviceContextPool::Instance().Get(gpu_place));
|
|
|
|
|
paddle::operators::math::SetConstant<paddle::platform::CUDADeviceContext,
|
|
|
|
|
float>
|
|
|
|
|
functor;
|
|
|
|
@ -132,7 +134,9 @@ TEST(selected_rows_functor, gpu_add) {
|
|
|
|
|
TEST(selected_rows_functor, gpu_add_to) {
|
|
|
|
|
paddle::platform::CUDAPlace gpu_place(0);
|
|
|
|
|
paddle::platform::CPUPlace cpu_place;
|
|
|
|
|
paddle::platform::CUDADeviceContext ctx(gpu_place);
|
|
|
|
|
paddle::platform::CUDADeviceContext& ctx =
|
|
|
|
|
*reinterpret_cast<paddle::platform::CUDADeviceContext*>(
|
|
|
|
|
paddle::platform::DeviceContextPool::Instance().Get(gpu_place));
|
|
|
|
|
paddle::operators::math::SetConstant<paddle::platform::CUDADeviceContext,
|
|
|
|
|
float>
|
|
|
|
|
functor;
|
|
|
|
|