|
|
|
@ -361,7 +361,8 @@ class FCPrimitiveFactory {
|
|
|
|
|
|
|
|
|
|
void CacheWeightsAndBias(const MKLDNNDeviceContext& dev_ctx,
|
|
|
|
|
const ExecutionContext& ctx) {
|
|
|
|
|
const std::string key = platform::CreateKey(platform::ThreadIDasStr());
|
|
|
|
|
const std::string key =
|
|
|
|
|
platform::CreateKey(platform::ThreadIDasStr(), dev_ctx.GetKeySuffix());
|
|
|
|
|
const std::string weights_key = key + ctx.InputName("W");
|
|
|
|
|
const std::string bias_key = key + ctx.InputName("Bias");
|
|
|
|
|
dev_ctx.SetBlob(weights_key, weights_);
|
|
|
|
@ -532,8 +533,9 @@ static void ExecuteFc(const ExecutionContext& ctx, const LoDTensor* input,
|
|
|
|
|
bool fuse_relu, bool force_fp32_output) {
|
|
|
|
|
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
|
|
|
|
|
const std::string prim_key = platform::CreateKey(
|
|
|
|
|
platform::ThreadIDasStr(), input->format(), input->dims()[0],
|
|
|
|
|
framework::vectorize<int>(w->dims()), ctx.OutputName("Out"));
|
|
|
|
|
platform::ThreadIDasStr(), dev_ctx.GetKeySuffix(), input->format(),
|
|
|
|
|
input->dims()[0], framework::vectorize<int>(w->dims()),
|
|
|
|
|
ctx.OutputName("Out"));
|
|
|
|
|
constexpr bool is_int8 =
|
|
|
|
|
std::is_same<T_in, int8_t>::value || std::is_same<T_in, uint8_t>::value;
|
|
|
|
|
bool is_bfloat16 = std::is_same<T_in, paddle::platform::bfloat16>::value;
|
|
|
|
|