Add bfloat16 data type (#25402)

revert-26856-strategy_example2
joanna.wozna.intel 5 years ago committed by GitHub
parent 3ba7b9b567
commit 95e1434bb2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -116,6 +116,8 @@ void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) {
return platform::to_void_cast(tensor.data<unsigned char>());
case mkldnn::memory::data_type::s32:
return platform::to_void_cast(tensor.data<int32_t>());
case mkldnn::memory::data_type::bf16:
return platform::to_void_cast(tensor.data<paddle::platform::bfloat16>());
default:
PADDLE_THROW(
platform::errors::InvalidArgument("Wrong mkldnn type provided."));

@ -61,7 +61,8 @@ inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) {
{DataTypeTrait<float>::DataType(), MKLDNNDataType::f32},
{DataTypeTrait<int8_t>::DataType(), MKLDNNDataType::s8},
{DataTypeTrait<uint8_t>::DataType(), MKLDNNDataType::u8},
{DataTypeTrait<int32_t>::DataType(), MKLDNNDataType::s32}};
{DataTypeTrait<int32_t>::DataType(), MKLDNNDataType::s32},
{DataTypeTrait<platform::bfloat16>::DataType(), MKLDNNDataType::bf16}};
auto iter = dict.find(static_cast<int>(type));
if (iter != dict.end()) return iter->second;
return MKLDNNDataType::undef;
@ -74,6 +75,9 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_type,
const Tensor& in, Tensor* out);
void* GetDataFromTensor(const Tensor& tensor, MKLDNNDataType type);
#endif
std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to);

@ -43,3 +43,17 @@ TEST(DataTransform, DataLayoutFunction) {
EXPECT_TRUE(in.layout() == paddle::framework::DataLayout::kNHWC);
EXPECT_TRUE(in.dims() == paddle::framework::make_ddim({2, 3, 1, 2}));
}
#ifdef PADDLE_WITH_MKLDNN
// Checks that GetDataFromTensor, when asked for the DNNL bf16 data type,
// returns the same address as the tensor's own bfloat16 data pointer.
TEST(DataTransform, GetDataFromTensorDNNL) {
  namespace fw = paddle::framework;
  using paddle::platform::bfloat16;

  auto cpu_place = paddle::platform::CPUPlace();
  fw::Tensor tensor = fw::Tensor();
  // Allocate a 2x3x1x2 bfloat16 buffer on the CPU.
  tensor.mutable_data<bfloat16>(fw::make_ddim({2, 3, 1, 2}), cpu_place);

  void* fetched =
      fw::GetDataFromTensor(tensor, dnnl::memory::data_type::bf16);
  void* direct = paddle::platform::to_void_cast(tensor.data<bfloat16>());
  EXPECT_EQ(fetched, direct);
}
#endif

@ -18,6 +18,7 @@
#include <unordered_map>
using float16 = paddle::platform::float16;
using bfloat16 = paddle::platform::bfloat16;
namespace paddle {
namespace framework {

@ -17,6 +17,8 @@ limitations under the License. */
#include <typeindex>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
@ -36,15 +38,16 @@ struct DataTypeTrait<void> {
#define _ForEachDataTypeHelper_(callback, cpp_type, proto_type) \
callback(cpp_type, ::paddle::framework::proto::VarType::proto_type);
#define _ForEachDataType_(callback) \
_ForEachDataTypeHelper_(callback, float, FP32); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16); \
_ForEachDataTypeHelper_(callback, double, FP64); \
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64); \
_ForEachDataTypeHelper_(callback, bool, BOOL); \
_ForEachDataTypeHelper_(callback, uint8_t, UINT8); \
_ForEachDataTypeHelper_(callback, int16_t, INT16); \
#define _ForEachDataType_(callback) \
_ForEachDataTypeHelper_(callback, float, FP32); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::bfloat16, BF16); \
_ForEachDataTypeHelper_(callback, double, FP64); \
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64); \
_ForEachDataTypeHelper_(callback, bool, BOOL); \
_ForEachDataTypeHelper_(callback, uint8_t, UINT8); \
_ForEachDataTypeHelper_(callback, int16_t, INT16); \
_ForEachDataTypeHelper_(callback, int8_t, INT8)
#define _ForEachDataTypeSmall_(callback) \

@ -38,3 +38,25 @@ TEST(DataType, float16) {
std::string type = "::paddle::platform::float16";
EXPECT_STREQ(f::DataTypeToString(dtype).c_str(), type.c_str());
}
// Covers the BF16 proto type: the tensor type reported after allocation,
// the element size, and the debug string.
TEST(DataType, bfloat16) {
  namespace f = paddle::framework;
  using paddle::platform::bfloat16;

  const f::proto::VarType::Type bf16_type = f::proto::VarType::BF16;

  // A tensor allocated with the BF16 proto type reports the type that
  // ToDataType derives from the bfloat16 C++ type.
  f::Tensor bf16_tensor;
  paddle::platform::CPUPlace cpu_place;
  bf16_tensor.mutable_data(cpu_place, bf16_type);
  EXPECT_EQ(bf16_tensor.type(), f::ToDataType(typeid(bfloat16)));

  // One bfloat16 element is expected to take two bytes.
  EXPECT_EQ(f::SizeOfType(bf16_type), 2u);

  // The debug string is the fully qualified C++ type name.
  const std::string expected_name = "::paddle::platform::bfloat16";
  EXPECT_STREQ(f::DataTypeToString(bf16_type).c_str(),
               expected_name.c_str());
}

@ -77,6 +77,10 @@ void TransDataType(const OpKernelType& kernel_type_for_var,
framework::VisitDataType(dst_type,
CastDataType<platform::float16>(in, out, ctx));
break;
case proto::VarType::BF16:
framework::VisitDataType(dst_type,
CastDataType<platform::bfloat16>(in, out, ctx));
break;
case proto::VarType::FP32:
framework::VisitDataType(dst_type, CastDataType<float>(in, out, ctx));
break;

@ -24,6 +24,11 @@ TEST(DataTypeTransform, CPUTransform) {
paddle::framework::DataLayout::kAnyLayout,
paddle::framework::LibraryType::kPlain);
auto kernel_bf16 = paddle::framework::OpKernelType(
paddle::framework::proto::VarType::BF16, place,
paddle::framework::DataLayout::kAnyLayout,
paddle::framework::LibraryType::kPlain);
auto kernel_fp32 = paddle::framework::OpKernelType(
paddle::framework::proto::VarType::FP32, place,
paddle::framework::DataLayout::kAnyLayout,
@ -189,4 +194,120 @@ TEST(DataTypeTransform, CPUTransform) {
static_cast<paddle::platform::float16>(in_data_bool[i]).x);
}
}
// data type transform from/to bfloat16
// Exercises TransDataType in both directions for BF16: first from a bfloat16
// input to float, double, int, int64 and bool, then from each of those types
// back to bfloat16. `place` and the kernel_* objects come from the enclosing
// test body.
{
paddle::framework::Tensor in;
paddle::framework::Tensor out;
paddle::platform::bfloat16* ptr =
in.mutable_data<paddle::platform::bfloat16>(
paddle::framework::make_ddim({2, 3}), place);
int data_number = 2 * 3;
// Fill the 2x3 input with 0..5; each int is converted to bfloat16 on store.
for (int i = 0; i < data_number; ++i) {
ptr[i] = i;
}
// transform from bfloat16 to other data types
// Every output element is compared with a static_cast of the matching
// bfloat16 input element to the destination type.
paddle::framework::TransDataType(kernel_bf16, kernel_fp32, in, &out);
float* out_data_float = out.data<float>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(out_data_float[i], static_cast<float>(ptr[i]));
}
paddle::framework::TransDataType(kernel_bf16, kernel_fp64, in, &out);
double* out_data_double = out.data<double>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(out_data_double[i], static_cast<double>(ptr[i]));
}
paddle::framework::TransDataType(kernel_bf16, kernel_int32, in, &out);
int* out_data_int = out.data<int>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(out_data_int[i], static_cast<int>(ptr[i]));
}
paddle::framework::TransDataType(kernel_bf16, kernel_int64, in, &out);
int64_t* out_data_int64 = out.data<int64_t>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(out_data_int64[i], static_cast<int64_t>(ptr[i]));
}
paddle::framework::TransDataType(kernel_bf16, kernel_bool, in, &out);
bool* out_data_bool = out.data<bool>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(out_data_bool[i], static_cast<bool>(ptr[i]));
}
// From here on `in` is re-requested with a different element type for each
// case, and `ptr` is repointed at the bfloat16 output. The checks compare
// the `.x` member of the result with the `.x` of a directly cast value.
// transform float to bfloat16
float* in_data_float =
in.mutable_data<float>(paddle::framework::make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
in_data_float[i] = i;
}
paddle::framework::TransDataType(kernel_fp32, kernel_bf16, in, &out);
ptr = out.data<paddle::platform::bfloat16>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(ptr[i].x,
static_cast<paddle::platform::bfloat16>(in_data_float[i]).x);
}
// transform double to bfloat16
double* in_data_double =
in.mutable_data<double>(paddle::framework::make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
in_data_double[i] = i;
}
paddle::framework::TransDataType(kernel_fp64, kernel_bf16, in, &out);
ptr = out.data<paddle::platform::bfloat16>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(ptr[i].x,
static_cast<paddle::platform::bfloat16>(in_data_double[i]).x);
}
// transform int to bfloat16
int* in_data_int =
in.mutable_data<int>(paddle::framework::make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
in_data_int[i] = i;
}
paddle::framework::TransDataType(kernel_int32, kernel_bf16, in, &out);
ptr = out.data<paddle::platform::bfloat16>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(ptr[i].x,
static_cast<paddle::platform::bfloat16>(in_data_int[i]).x);
}
// transform int64 to bfloat16
int64_t* in_data_int64 =
in.mutable_data<int64_t>(paddle::framework::make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
in_data_int64[i] = i;
}
paddle::framework::TransDataType(kernel_int64, kernel_bf16, in, &out);
ptr = out.data<paddle::platform::bfloat16>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(ptr[i].x,
static_cast<paddle::platform::bfloat16>(in_data_int64[i]).x);
}
// transform bool to bfloat16
bool* in_data_bool =
in.mutable_data<bool>(paddle::framework::make_ddim({2, 3}), place);
// Storing an int into a bool: index 0 becomes false, every other index true.
for (int i = 0; i < data_number; ++i) {
in_data_bool[i] = i;
}
paddle::framework::TransDataType(kernel_bool, kernel_bf16, in, &out);
ptr = out.data<paddle::platform::bfloat16>();
for (int i = 0; i < data_number; ++i) {
EXPECT_EQ(ptr[i].x,
static_cast<paddle::platform::bfloat16>(in_data_bool[i]).x);
}
}
}

@ -167,6 +167,8 @@ static void PrintNanInf(const T* value, const size_t numel, int print_num,
// more detail see: 180 page of
// https://www.openmp.org/wp-content/uploads/OpenMP4.0.0.pdf
#pragma omp declare reduction(+ : paddle::platform::float16 : omp_out += omp_in)
#pragma omp declare reduction(+ : paddle::platform::bfloat16 : omp_out += \
omp_in)
#endif
template <typename T>

@ -23,6 +23,7 @@ template <typename T>
static ::DLDataType GetDLDataTypeCode() {
::DLDataType dtype;
if (std::is_same<T, platform::float16>::value ||
std::is_same<T, platform::bfloat16>::value ||
std::is_floating_point<T>::value) {
dtype.code = kDLFloat;
} else if (std::is_unsigned<T>::value) {

@ -90,32 +90,6 @@ void MemoryOptimizePass::CollectLifeCycle(
}
}
// TODO(Superjomn) Make this a general help method.
// Maps a proto variable data type to the size in bytes of one element.
// Any type without an entry below is reported through PADDLE_THROW.
int DataTypeToSpace(framework::proto::VarType_Type type) {
  namespace pb = framework::proto;
  switch (type) {
    // Boolean and 8-bit integer element types.
    case pb::VarType_Type_BOOL:
      return sizeof(bool);
    case pb::VarType_Type_UINT8:
      return sizeof(unsigned char);
    case pb::VarType_Type_INT8:
      return sizeof(int8_t);

    // INT16 and FP16 are both sized as int16_t.
    case pb::VarType_Type_INT16:
    case pb::VarType_Type_FP16:
      return sizeof(int16_t);

    // 32-bit integer and single-precision float.
    case pb::VarType_Type_INT32:
      return sizeof(int32_t);
    case pb::VarType_Type_FP32:
      return sizeof(float);

    // 64-bit integer and double-precision float.
    case pb::VarType_Type_INT64:
      return sizeof(int64_t);
    case pb::VarType_Type_FP64:
      return sizeof(double);

    default:
      PADDLE_THROW("Unknown data type");
  }
}
void MemoryOptimizePass::CollectVarMemorySize(
space_table_t* space_table) const {
const int fake_batch_size = 1;
@ -163,7 +137,7 @@ void MemoryOptimizePass::CollectVarMemorySize(
int size = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<int>());
(*space_table)[node->Var()->Name()] =
size * DataTypeToSpace(node->Var()->GetDataType());
size * paddle::framework::SizeOfType(node->Var()->GetDataType());
}
}
}

@ -14,15 +14,16 @@
#include <gtest/gtest.h>
#include "paddle/fluid/inference/lite/engine.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/operators/lite/ut_helper.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/lite/engine.h"
#include "paddle/fluid/operators/lite/ut_helper.h"
namespace paddle {
namespace inference {
namespace lite {

@ -65,13 +65,14 @@ class SplitFunctor {
} // namespace operators
} // namespace paddle
#define FOR_ALL_TYPES(macro) \
macro(int); \
macro(float); \
macro(double); \
macro(bool); \
macro(int64_t); \
macro(int16_t); \
macro(uint8_t); \
macro(int8_t); \
macro(::paddle::platform::float16)
#define FOR_ALL_TYPES(macro) \
macro(int); \
macro(float); \
macro(double); \
macro(bool); \
macro(int64_t); \
macro(int16_t); \
macro(uint8_t); \
macro(int8_t); \
macro(::paddle::platform::float16); \
macro(::paddle::platform::bfloat16)

@ -34,6 +34,7 @@ namespace math {
using float16 = paddle::platform::float16;
template struct SetConstant<platform::CPUDeviceContext, platform::float16>;
template struct SetConstant<platform::CPUDeviceContext, platform::bfloat16>;
template struct SetConstant<platform::CPUDeviceContext, float>;
template struct SetConstant<platform::CPUDeviceContext, double>;
template struct SetConstant<platform::CPUDeviceContext, int>;
@ -41,16 +42,18 @@ template struct SetConstant<platform::CPUDeviceContext, int64_t>;
template struct SetConstant<platform::CPUDeviceContext, bool>;
template struct SetConstant<platform::CPUDeviceContext, uint8_t>;
#define DEFINE_CPU_TRANS(RANK) \
template struct Transpose<platform::CPUDeviceContext, platform::float16, \
RANK>; \
template struct Transpose<platform::CPUDeviceContext, float, RANK>; \
template struct Transpose<platform::CPUDeviceContext, double, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int64_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, bool, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int16_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, uint8_t, RANK>; \
#define DEFINE_CPU_TRANS(RANK) \
template struct Transpose<platform::CPUDeviceContext, platform::float16, \
RANK>; \
template struct Transpose<platform::CPUDeviceContext, platform::bfloat16, \
RANK>; \
template struct Transpose<platform::CPUDeviceContext, float, RANK>; \
template struct Transpose<platform::CPUDeviceContext, double, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int64_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, bool, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int16_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, uint8_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int8_t, RANK>;
DEFINE_CPU_TRANS(1);

@ -136,6 +136,8 @@ cc_test(profiler_test SRCS profiler_test.cc DEPS profiler)
nv_test(float16_gpu_test SRCS float16_test.cu DEPS lod_tensor)
cc_test(float16_test SRCS float16_test.cc DEPS lod_tensor)
cc_test(bfloat16_test SRCS bfloat16_test.cc DEPS lod_tensor)
nv_test(test_limit_gpu_memory SRCS test_limit_gpu_memory.cu DEPS gpu_info flags)
nv_library(cuda_device_guard SRCS cuda_device_guard.cc DEPS gpu_info)

File diff suppressed because it is too large Load Diff

@ -0,0 +1,162 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/bfloat16.h"
#include <vector>
#define GLOG_NO_ABBREVIATED_SEVERITIES // msvc conflict logging with windows.h
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/init.h"
namespace paddle {
namespace platform {
using bfloat16 = paddle::platform::bfloat16;
// Checks construction, assignment and conversion of bfloat16 on the CPU.
// Stored values are verified through the `x` member against fixed 16-bit
// patterns.
TEST(bfloat16, conversion_cpu) {
  // Construct from float.
  EXPECT_EQ(bfloat16(1.0f).x, 0x3f80);
  EXPECT_EQ(bfloat16(0.5f).x, 0x3f00);
  EXPECT_EQ(bfloat16(0.33333f).x, 0x3eaa);
  EXPECT_EQ(bfloat16(0.0f).x, 0x0000);
  EXPECT_EQ(bfloat16(-0.0f).x, 0x8000);
  EXPECT_EQ(bfloat16(65504.0f).x, 0x477f);
  EXPECT_EQ(bfloat16(65536.0f).x, 0x4780);

  // Construct from double; the expected patterns match the float cases.
  EXPECT_EQ(bfloat16(1.0).x, 0x3f80);
  EXPECT_EQ(bfloat16(0.5).x, 0x3f00);
  EXPECT_EQ(bfloat16(0.33333).x, 0x3eaa);
  EXPECT_EQ(bfloat16(0.0).x, 0x0000);
  EXPECT_EQ(bfloat16(-0.0).x, 0x8000);
  EXPECT_EQ(bfloat16(65504.0).x, 0x477f);
  EXPECT_EQ(bfloat16(65536.0).x, 0x4780);

  // Construct from int.
  EXPECT_EQ(bfloat16(-1).x, 0xbf80);
  EXPECT_EQ(bfloat16(0).x, 0x0000);
  EXPECT_EQ(bfloat16(1).x, 0x3f80);
  EXPECT_EQ(bfloat16(2).x, 0x4000);
  EXPECT_EQ(bfloat16(3).x, 0x4040);

  // Construct from bool.
  EXPECT_EQ(bfloat16(true).x, 0x3f80);
  EXPECT_EQ(bfloat16(false).x, 0x0000);

  // Assign from bfloat16, float, double and int in turn.
  bfloat16 assigned;
  assigned = bfloat16(0.f);
  EXPECT_EQ(assigned.x, 0x0000);
  assigned = 0.5f;
  EXPECT_EQ(assigned.x, 0x3f00);
  assigned = 0.33333;
  EXPECT_EQ(assigned.x, 0x3eaa);
  assigned = -1;
  EXPECT_EQ(assigned.x, 0xbf80);

  // Convert back to built-in types.
  EXPECT_EQ(static_cast<float>(bfloat16(0.5f)), 0.5f);
  EXPECT_NEAR(static_cast<double>(bfloat16(0.33333)), 0.33333, 0.01);
  EXPECT_EQ(static_cast<int>(bfloat16(-1)), -1);
  EXPECT_EQ(static_cast<bool>(bfloat16(true)), true);
}
// Checks +, -, *, / and unary minus on bfloat16 values. Results are read
// back as float; exact checks are used where the test expects an exact
// result, tolerance checks elsewhere.
TEST(bfloat16, arithmetic_cpu) {
  // Reads a bfloat16 result back as float for comparison.
  auto as_float = [](bfloat16 value) { return static_cast<float>(value); };

  // Addition.
  EXPECT_NEAR(as_float(bfloat16(1) + bfloat16(1)), 2, 0.001);
  EXPECT_EQ(as_float(bfloat16(5) + bfloat16(-5)), 0);
  EXPECT_NEAR(as_float(bfloat16(0.33333f) + bfloat16(0.66667f)), 1.0f, 0.01);

  // Subtraction.
  EXPECT_EQ(as_float(bfloat16(3) - bfloat16(5)), -2);
  EXPECT_NEAR(as_float(bfloat16(0.66667f) - bfloat16(0.33333f)), 0.33334f,
              0.01);

  // Multiplication.
  EXPECT_NEAR(as_float(bfloat16(3.3f) * bfloat16(2.0f)), 6.6f, 0.01);
  EXPECT_NEAR(as_float(bfloat16(-2.1f) * bfloat16(-3.0f)), 6.3f, 0.1);

  // Division.
  EXPECT_NEAR(as_float(bfloat16(2.0f) / bfloat16(3.0f)), 0.66667f, 0.01);
  EXPECT_EQ(as_float(bfloat16(1.0f) / bfloat16(2.0f)), 0.5f);

  // Unary negation.
  EXPECT_EQ(as_float(-bfloat16(512.0f)), -512.0f);
  EXPECT_EQ(as_float(-bfloat16(-512.0f)), 512.0f);
}
// Checks the six relational operators on bfloat16, with one case expected
// to hold and, for ==, !=, < and >, one case expected not to hold.
TEST(bfloat16, comparison_cpu) {
  // Equality and inequality.
  EXPECT_TRUE(bfloat16(1.0f) == bfloat16(1.0f));
  EXPECT_FALSE(bfloat16(-1.0f) == bfloat16(-0.5f));
  EXPECT_TRUE(bfloat16(1.0f) != bfloat16(0.5f));
  EXPECT_FALSE(bfloat16(-1.0f) != bfloat16(-1.0f));

  // Less-than and less-or-equal; equal operands must not compare as less.
  EXPECT_TRUE(bfloat16(1.0f) < bfloat16(2.0f));
  EXPECT_FALSE(bfloat16(-1.0f) < bfloat16(-1.0f));
  EXPECT_TRUE(bfloat16(1.0f) <= bfloat16(1.0f));

  // Greater-than and greater-or-equal; equal operands must not compare as
  // greater.
  EXPECT_TRUE(bfloat16(2.0f) > bfloat16(1.0f));
  EXPECT_FALSE(bfloat16(-2.0f) > bfloat16(-2.0f));
  EXPECT_TRUE(bfloat16(2.0f) >= bfloat16(2.0f));
}
// Stores bfloat16 values in a CPU LoDTensor and reads them back, checking
// that the `x` member of every element survives the round trip.
TEST(bfloat16, lod_tensor_cpu) {
  framework::LoDTensor tensor;

  std::vector<bfloat16> source = {bfloat16(1.0f), bfloat16(0.5f),
                                  bfloat16(0.33333f), bfloat16(0.0f)};
  // Sanity-check the source values before they are copied.
  EXPECT_EQ(source[0].x, 0x3f80);
  EXPECT_EQ(source[1].x, 0x3f00);
  EXPECT_EQ(source[2].x, 0x3eaa);
  EXPECT_EQ(source[3].x, 0x0000);

  // Shape 4x1 with a single LoD level.
  tensor.Resize({4, 1});
  tensor.set_lod(framework::LoD({{0, 2, 4}}));

  bfloat16* buffer = tensor.mutable_data<bfloat16>(CPUPlace());
  EXPECT_NE(buffer, nullptr);
  EXPECT_EQ(source.size(), static_cast<size_t>(tensor.numel()));

  // Copy element by element and verify each stored value straight away.
  size_t index = 0;
  for (const bfloat16& value : source) {
    buffer[index] = value;
    EXPECT_EQ(buffer[index].x, value.x);
    ++index;
  }
}
// Checks that std::is_floating_point reports true for bfloat16.
TEST(bfloat16, floating) {
  // This is a runtime check, not a compile-time assert. It uses a gtest
  // expectation so that a failure is recorded as an ordinary test failure
  // instead of an enforce exception thrown out of the test body.
  // The trait value is copied into a local so the expectation macro never
  // binds a reference to the static data member.
  const bool is_floating = std::is_floating_point<bfloat16>::value;
  EXPECT_TRUE(is_floating)
      << "std::is_floating_point with bfloat16 data type "
         "should be equal to true but it is not";
}
// Smoke test: streaming a bfloat16 to std::cout must compile and run.
// The printed text itself is not checked.
TEST(bfloat16, print) {
  bfloat16 one = bfloat16(1.0f);
  std::cout << one;
  std::cout << std::endl;
}
// CPU test
// std::isinf must report true for a bfloat16 infinity however it was
// produced: from the raw 0x7f80 pattern, by construction, or by static_cast.
TEST(bfloat16, isinf) {
  bfloat16 from_bits;
  from_bits.x = 0x7f80;
  bfloat16 from_ctor = bfloat16(INFINITY);
  bfloat16 from_cast = static_cast<bfloat16>(INFINITY);

  EXPECT_EQ(std::isinf(from_bits), true);
  EXPECT_EQ(std::isinf(from_ctor), true);
  EXPECT_EQ(std::isinf(from_cast), true);
}
// std::isnan must report true for a bfloat16 NaN however it was produced:
// from the raw 0x7fff pattern, by construction, or by static_cast.
TEST(bfloat16, isnan) {
  bfloat16 from_bits;
  from_bits.x = 0x7fff;
  bfloat16 from_ctor = bfloat16(NAN);
  bfloat16 from_cast = static_cast<bfloat16>(NAN);

  EXPECT_EQ(std::isnan(from_bits), true);
  EXPECT_EQ(std::isnan(from_ctor), true);
  EXPECT_EQ(std::isnan(from_cast), true);
}
} // namespace platform
} // namespace paddle

@ -161,6 +161,12 @@ inline mkldnn::memory::data_type MKLDNNGetDataType<uint8_t>() {
return mkldnn::memory::data_type::u8;
}
// Specialization for paddle's bfloat16 element type: reports the mkldnn
// bf16 memory data type.
template <>
inline mkldnn::memory::data_type
MKLDNNGetDataType<paddle::platform::bfloat16>() {
  using MemoryDataType = mkldnn::memory::data_type;
  return MemoryDataType::bf16;
}
inline void Reorder(mkldnn::memory src, mkldnn::memory dst,
const mkldnn::engine& engine) {
auto reorder_prim = mkldnn::reorder(src, dst);

@ -26,6 +26,7 @@ limitations under the License. */
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/strided_memcpy.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
#include "pybind11/numpy.h"
@ -104,6 +105,7 @@ struct ValidDTypeToPyArrayChecker {
}
DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::float16);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::bfloat16);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(float);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(double);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(bool);
@ -119,6 +121,9 @@ inline std::string TensorDTypeToPyDTypeStr(
if (type == proto_type) { \
if (std::is_same<T, platform::float16>::value) { \
return "e"; \
} else if (std::is_same<T, platform::bfloat16>::value) { \
/* NumPy character code of uint16 due to no support for bfloat16 */ \
return "H"; \
} else { \
constexpr auto kIsValidDType = ValidDTypeToPyArrayChecker<T>::kValue; \
PADDLE_ENFORCE_EQ( \
@ -262,10 +267,10 @@ void SetTensorFromPyArray(framework::Tensor *self, const py::object &obj,
SetTensorFromPyArrayT<paddle::platform::float16, P>(self, array, place,
zero_copy);
} else if (py::isinstance<py::array_t<uint16_t>>(array)) {
// TODO(cql): temporarily keeping uint16, which was used for casting float16
// before. It should be deprecated later.
SetTensorFromPyArrayT<paddle::platform::float16, P>(self, array, place,
zero_copy);
// since there is still no support for bfloat16 in NumPy,
// uint16 is used for casting bfloat16
SetTensorFromPyArrayT<paddle::platform::bfloat16, P>(self, array, place,
zero_copy);
} else if (py::isinstance<py::array_t<bool>>(array)) {
SetTensorFromPyArrayT<bool, P>(self, array, place, zero_copy);
} else {
@ -479,6 +484,8 @@ inline framework::Tensor *_sliceTensor(const framework::Tensor &self,
switch (src_type) {
case framework::proto::VarType::FP16:
return _sliceAndConcat<paddle::platform::float16>(self, obj, dim);
case framework::proto::VarType::BF16:
return _sliceAndConcat<paddle::platform::bfloat16>(self, obj, dim);
case framework::proto::VarType::FP32:
return _sliceAndConcat<float>(self, obj, dim);
case framework::proto::VarType::FP64:

Loading…
Cancel
Save