@ -12,13 +12,16 @@
See the License for the specific language governing permissions and
limitations under the License . */
# include "paddle/framework/op_registry.h"
# include <glog/logging.h>
# include <gtest/gtest.h>
# include "paddle/framework/op_registry.h"
namespace pd = paddle : : framework ;
namespace paddle {
namespace framework {
class CosineOp : public OperatorBase {
public :
using OperatorBase : : OperatorBase ;
@ -252,7 +255,6 @@ TEST(OperatorRegistrar, CPU) {
op - > Run ( scope , cpu_place ) ;
}
# ifdef PADDLE_WITH_CUDA
TEST ( OperatorRegistrar , CUDA ) {
paddle : : framework : : proto : : OpDesc op_desc ;
paddle : : platform : : CUDAPlace cuda_place ( 0 ) ;
@ -263,4 +265,131 @@ TEST(OperatorRegistrar, CUDA) {
op - > Run ( scope , cuda_place ) ;
}
# endif
static int op_test_value = 0 ;
using paddle : : platform : : DeviceContext ;
using paddle : : platform : : CPUDeviceContext ;
using paddle : : platform : : CUDADeviceContext ;
namespace paddle {
namespace framework {
class OpWithMultiKernelTest : public OperatorWithKernel {
public :
using OperatorWithKernel : : OperatorWithKernel ;
protected :
void InferShape ( InferShapeContext * ctx ) const override { }
framework : : OpKernelType GetActualKernelType (
const framework : : ExecutionContext & ctx ) const override {
return framework : : OpKernelType ( proto : : DataType : : FP32 , ctx . device_context ( ) ) ;
}
framework : : OpKernelType GetExpectedKernelType (
const framework : : OpKernelType & kernel ) const override {
return framework : : OpKernelType ( kernel . data_type_ , platform : : CUDAPlace ( 0 ) ,
kernel . data_layout_ ,
framework : : LibraryType : : kCUDNN ) ;
}
} ;
template < typename DeviceContext , typename T >
class OpMultiKernelTest : public paddle : : framework : : OpKernel < T > {
public :
void Compute ( const paddle : : framework : : ExecutionContext & ctx ) const ;
} ;
template < typename T >
class OpMultiKernelTest < CPUDeviceContext , T >
: public paddle : : framework : : OpKernel < T > {
public :
void Compute ( const paddle : : framework : : ExecutionContext & ctx ) const {
+ + op_test_value ;
}
} ;
template < typename T >
class OpMultiKernelTest < CUDADeviceContext , T >
: public paddle : : framework : : OpKernel < T > {
public :
void Compute ( const paddle : : framework : : ExecutionContext & ctx ) const {
- - op_test_value ;
}
} ;
template < typename DeviceContext , typename T >
class OpMultiKernelTest2 : public paddle : : framework : : OpKernel < T > {
public :
void Compute ( const paddle : : framework : : ExecutionContext & ctx ) const ;
} ;
template < typename T >
class OpMultiKernelTest2 < CPUDeviceContext , T >
: public paddle : : framework : : OpKernel < T > {
public :
void Compute ( const paddle : : framework : : ExecutionContext & ctx ) const {
op_test_value + = 10 ;
}
} ;
template < typename T >
class OpMultiKernelTest2 < CUDADeviceContext , T >
: public paddle : : framework : : OpKernel < T > {
public :
void Compute ( const paddle : : framework : : ExecutionContext & ctx ) const {
op_test_value - = 10 ;
}
} ;
} // namespace framework
} // namespace paddle
REGISTER_OP_WITHOUT_GRADIENT ( op_with_multi_kernel ,
paddle : : framework : : OpWithMultiKernelTest ,
paddle : : framework : : OpKernelTestMaker ) ;
REGISTER_OP_KERNEL (
op_with_multi_kernel , CPU , paddle : : platform : : CPUPlace ,
paddle : : framework : : OpMultiKernelTest < CPUDeviceContext , float > ) ;
REGISTER_OP_KERNEL (
op_with_multi_kernel , MKLDNN , paddle : : platform : : CPUPlace ,
paddle : : framework : : OpMultiKernelTest2 < CPUDeviceContext , float > ) ;
REGISTER_OP_KERNEL (
op_with_multi_kernel , CUDA , paddle : : platform : : CUDAPlace ,
paddle : : framework : : OpMultiKernelTest < CUDADeviceContext , float > ) ;
REGISTER_OP_KERNEL (
op_with_multi_kernel , CUDNN , paddle : : platform : : CUDAPlace ,
paddle : : framework : : OpMultiKernelTest2 < CUDADeviceContext , float > ) ;
TEST ( OperatorRegistrar , OpWithMultiKernel ) {
paddle : : framework : : proto : : OpDesc op_desc ;
paddle : : platform : : CUDAPlace cuda_place ( 0 ) ;
paddle : : platform : : CPUPlace cpu_place ;
paddle : : framework : : Scope scope ;
op_desc . set_type ( " op_with_multi_kernel " ) ;
auto op = paddle : : framework : : OpRegistry : : CreateOp ( op_desc ) ;
// use all available kernels
paddle : : framework : : UseALL ( ) ;
op - > Run ( scope , cuda_place ) ;
EXPECT_EQ ( op_test_value , - 10 ) ;
// remove cuda kernels
paddle : : framework : : UseCPU ( ) ;
op - > Run ( scope , cpu_place ) ;
EXPECT_EQ ( op_test_value , - 9 ) ;
// add cuda kernels
paddle : : framework : : UseCUDA ( ) ;
op - > Run ( scope , cuda_place ) ;
EXPECT_EQ ( op_test_value , - 10 ) ;
// use cudnn kernel
paddle : : framework : : UseCUDNN ( ) ;
op - > Run ( scope , cuda_place ) ;
EXPECT_EQ ( op_test_value , - 20 ) ;
}