@@ -176,7 +176,6 @@ void TestMainImpl(std::string func_name, std::string code_str,
   bool is_float16 = std::type_index(typeid(T)) ==
                     std::type_index(typeid(paddle::platform::float16));
-
   paddle::framework::InitDevices(false, {0});
   paddle::platform::CUDAPlace place = paddle::platform::CUDAPlace(0);
   paddle::platform::CUDADeviceCode device_code(place, func_name, code_str);
   device_code.Compile(is_float16);
@@ -266,7 +265,7 @@ void TestElementwiseMain(
   }

   int n = cpu_tensors[0].numel();
-  if (dtype == "float16") {
+  if (dtype == "__half") {
     TestMainImpl<paddle::platform::float16>(func_name, code_str, cpu_tensors, n,
                                             input_ids, output_ids);
   } else {
@@ -275,7 +274,7 @@ void TestElementwiseMain(
   }

   // Check the results
-  float eps = (dtype == "float16") ? 1E-2 : 1E-5;
+  float eps = (dtype == "__half") ? 1E-2 : 1E-5;
   for (int i = 0; i < n; i++) {
     fusion_group::CheckOutput(expressions, cpu_tensors, input_ids, output_ids,
                               i, eps);
@@ -312,7 +311,7 @@ void TestMain(fusion_group::SubGraph* subgraph, std::vector<int> input_ids,
 }

 TEST(code_generator, elementwise) {
-  for (std::string dtype : {"float", "float16"}) {
+  for (std::string dtype : {"float", "__half"}) {
     // t2 = t0 * t1
     // t4 = t2 + t3
     // t6 = t4 - t5
@@ -342,7 +341,7 @@ TEST(code_generator, elementwise) {
 }

 TEST(code_generator, elementwise_grad) {
-  for (std::string dtype : {"float", "float16"}) {
+  for (std::string dtype : {"float", "__half"}) {
     // The var order: t0, t1, t2, t3, t0', t1', t2', t3'
     // t2 = t0 * t1
     // t3 = relu(t2)
@@ -407,7 +406,7 @@ std::unique_ptr<paddle::framework::ir::Graph> BuildGraph(bool backward,

   std::unique_ptr<paddle::framework::ir::Graph> graph(
       new paddle::framework::ir::Graph(layers.main_program()));
-  auto proto_dtype = (dtype == "float16")
+  auto proto_dtype = (dtype == "__half")
                          ? paddle::framework::proto::VarType::FP16
                          : paddle::framework::proto::VarType::FP32;
   for (auto* n : graph->Nodes()) {
@@ -463,10 +462,10 @@ std::unordered_set<paddle::framework::ir::Node*> DistilGradNodes(
 }

 TEST(code_generator, subgraph) {
-  for (std::string dtype : {"float", "float16"}) {
+  for (std::string dtype : {"float", "__half"}) {
     std::unique_ptr<paddle::framework::ir::Graph> graph =
         BuildGraph(false, dtype);
-    fusion_group::SubGraph subgraph(0, "elementwise_kernel_1", true,
+    fusion_group::SubGraph subgraph(0, "elementwise_kernel_1", false,
                                     graph->Nodes());

     // Expressions generated by code_generator (they may be different):
@@ -482,10 +481,10 @@ TEST(code_generator, subgraph) {
 }

 TEST(code_generator, subgraph_grad) {
-  for (std::string dtype : {"float", "float16"}) {
+  for (std::string dtype : {"float", "__half"}) {
     std::unique_ptr<paddle::framework::ir::Graph> graph =
         BuildGraph(true, dtype);
-    fusion_group::SubGraph subgraph(0, "elementwise_grad_kernel_1", true,
+    fusion_group::SubGraph subgraph(0, "elementwise_grad_kernel_1", false,
                                     DistilGradNodes(graph));

     // Expressions generated by code_generator (they may be different):