use mat attr and refine test (#15448)
* use mat attr and refine test test=develop * add matmul jitcode test=develop * fix mac compile test=develop
revert-15296-async_double_buffered_py_reader
parent
b5ebca47a3
commit
ba02ac4692
@ -0,0 +1,128 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/jit/gen/matmul.h"
|
||||
#include <stddef.h> // offsetof
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/operators/jit/registry.h"
|
||||
#include "paddle/fluid/platform/cpu_info.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace jit {
|
||||
namespace gen {
|
||||
|
||||
void MatMulJitCode::genCode() {
|
||||
preCode();
|
||||
int block, rest;
|
||||
const auto groups = packed_groups(n_, k_, &block, &rest);
|
||||
PADDLE_ENFORCE_GT(groups.front(), 0);
|
||||
|
||||
const int block_len = sizeof(float) * block;
|
||||
const int x_reg_idx = (block == ZMM_FLOAT_BLOCK ? 32 : 16) - 1;
|
||||
const int w_reg_idx = x_reg_idx - 1;
|
||||
// from packed mov(reg_ptr_wgt, ptr[param_attr + offsetof(matmul_attr_t,
|
||||
// packed_weight)]);
|
||||
mov(reg_ptr_wgt, param_y);
|
||||
size_t z_offset = 0;
|
||||
size_t wgt_offset = 0;
|
||||
for (size_t g = 0; g < groups.size(); ++g) {
|
||||
size_t x_offset = 0;
|
||||
for (int k = 0; k < k_; ++k) {
|
||||
vbroadcastss(zmm_t(x_reg_idx), ptr[param_x + x_offset]);
|
||||
// clean
|
||||
if (k == 0) {
|
||||
for (int i = 0; i < groups[g]; ++i) {
|
||||
vxorps(zmm_t(i), zmm_t(i), zmm_t(i));
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < groups[g]; ++i) {
|
||||
vmovups(zmm_t(w_reg_idx), ptr[reg_ptr_wgt + wgt_offset]);
|
||||
vfmadd231ps(zmm_t(i), zmm_t(w_reg_idx), zmm_t(x_reg_idx));
|
||||
wgt_offset += block_len;
|
||||
}
|
||||
// last one, save
|
||||
if (k == k_ - 1) {
|
||||
for (int i = 0; i < groups[g]; ++i) {
|
||||
// only rest save should be careful
|
||||
if (rest != 0 && g == groups.size() - 1 && i == groups[g] - 1) {
|
||||
break;
|
||||
}
|
||||
vmovups(ptr[param_z + z_offset + i * block_len], zmm_t(i));
|
||||
}
|
||||
}
|
||||
x_offset += sizeof(float);
|
||||
}
|
||||
z_offset += block_len * groups[g];
|
||||
}
|
||||
|
||||
if (rest != 0) {
|
||||
// below should refine with mask
|
||||
int reg_idx = groups.back() - 1;
|
||||
z_offset = (n_ - rest) * sizeof(float);
|
||||
int inner_block = 8;
|
||||
while (rest > 0) {
|
||||
if (rest >= 8) {
|
||||
inner_block = 8;
|
||||
vmovups(ptr[param_z + z_offset], ymm_t(reg_idx));
|
||||
// shift zmm of inner_block, change reg_idx if update
|
||||
} else if (rest >= 4) {
|
||||
inner_block = 4;
|
||||
vmovups(ptr[param_z + z_offset], xmm_t(reg_idx));
|
||||
} else if (rest >= 2) {
|
||||
inner_block = 2;
|
||||
vmovq(ptr[param_z + z_offset], xmm_t(reg_idx));
|
||||
} else {
|
||||
inner_block = 1;
|
||||
vmovss(ptr[param_z + z_offset], xmm_t(reg_idx));
|
||||
}
|
||||
z_offset += inner_block * sizeof(float);
|
||||
rest -= inner_block;
|
||||
}
|
||||
}
|
||||
|
||||
postCode();
|
||||
}
|
||||
|
||||
// Creator for the generated matmul kernel: decides whether the JIT kernel is
// applicable for a given attribute set, estimates the code buffer size, and
// instantiates the code generator.
class MatMulCreator : public JitCodeCreator<matmul_attr_t> {
 public:
  // Returns true only when all of the following hold: m == 1, the CPU
  // reports AVX512F, n is a multiple of ZMM_FLOAT_BLOCK, and k < 512.
  bool UseMe(const matmul_attr_t& attr) const override {
    if (attr.m != 1) {
      return false;
    }
    if (!platform::MayIUse(platform::avx512f)) {
      return false;
    }
    if (attr.n % ZMM_FLOAT_BLOCK != 0) {
      return false;
    }
    return attr.k < 512;
  }

  // Size estimate (in bytes) handed to the code generator: a fixed 96 bytes
  // plus a term proportional to k times the number of vector blocks in n
  // (rounded up by one).
  size_t CodeSize(const matmul_attr_t& attr) const override {
    const int block = platform::MayIUse(platform::avx512f) ? ZMM_FLOAT_BLOCK
                                                           : YMM_FLOAT_BLOCK;
    return 96 + 4 * attr.k * (attr.n / block + 1) * 8;
  }

  // Validates that m, n and k are all positive, then builds the generator
  // with a buffer sized by CodeSize().
  std::unique_ptr<GenBase> CreateJitCode(
      const matmul_attr_t& attr) const override {
    PADDLE_ENFORCE_GT(attr.m, 0);
    PADDLE_ENFORCE_GT(attr.n, 0);
    PADDLE_ENFORCE_GT(attr.k, 0);
    const size_t code_bytes = CodeSize(attr);
    return make_unique<MatMulJitCode>(attr, code_bytes);
  }
};
|
||||
|
||||
} // namespace gen
|
||||
} // namespace jit
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
// Short alias for the generator namespace, used by the registration below.
namespace gen = paddle::operators::jit::gen;
|
||||
|
||||
// Registers MatMulCreator under the kMatMul kernel key (presumably so the JIT
// kernel lookup can find it -- see REGISTER_JITKERNEL_GEN in registry.h).
REGISTER_JITKERNEL_GEN(kMatMul, gen::MatMulCreator);
|
@ -0,0 +1,62 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <stdlib.h> // for malloc and free
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "glog/logging.h"
|
||||
#include "paddle/fluid/operators/jit/gen/jitcode.h"
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace jit {
|
||||
namespace gen {
|
||||
|
||||
class MatMulJitCode : public JitCode {
|
||||
public:
|
||||
explicit MatMulJitCode(const matmul_attr_t& attr,
|
||||
size_t code_size = 256 * 1024,
|
||||
void* code_ptr = nullptr)
|
||||
: JitCode(code_size, code_ptr), m_(attr.m), n_(attr.n), k_(attr.k) {
|
||||
PADDLE_ENFORCE_EQ(m_, 1, "Only support m==1 yet");
|
||||
this->genCode();
|
||||
}
|
||||
|
||||
virtual const char* name() const {
|
||||
std::string base = "MatMulJitCode";
|
||||
base = base + "_M" + std::to_string(m_) + "_N" + std::to_string(n_) + "_K" +
|
||||
std::to_string(k_);
|
||||
return base.c_str();
|
||||
}
|
||||
void genCode() override;
|
||||
|
||||
private:
|
||||
int m_, n_, k_;
|
||||
|
||||
reg64_t param_x{abi_param1};
|
||||
reg64_t param_y{abi_param2};
|
||||
reg64_t param_z{abi_param3};
|
||||
reg64_t param_attr{abi_param4};
|
||||
reg64_t reg_tmp{rax};
|
||||
|
||||
reg64_t reg_ptr_wgt{r10};
|
||||
};
|
||||
|
||||
} // namespace gen
|
||||
} // namespace jit
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
Loading…
Reference in new issue