|
|
|
@ -40,7 +40,12 @@ void MatMulJitCode::genCode() {
|
|
|
|
|
size_t wgt_offset = 0;
|
|
|
|
|
for (size_t g = 0; g < groups.size(); ++g) {
|
|
|
|
|
size_t x_offset = 0;
|
|
|
|
|
size_t wgt_offset_tmp = 0;
|
|
|
|
|
for (int i = 0; i < g; ++i) {
|
|
|
|
|
wgt_offset_tmp += groups[i] * block_len;
|
|
|
|
|
}
|
|
|
|
|
for (int k = 0; k < k_; ++k) {
|
|
|
|
|
wgt_offset = wgt_offset_tmp;
|
|
|
|
|
vbroadcastss(zmm_t(x_reg_idx), ptr[param_x + x_offset]);
|
|
|
|
|
// clean
|
|
|
|
|
if (k == 0) {
|
|
|
|
@ -49,7 +54,8 @@ void MatMulJitCode::genCode() {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (int i = 0; i < groups[g]; ++i) {
|
|
|
|
|
vmovups(zmm_t(w_reg_idx), ptr[reg_ptr_wgt + wgt_offset]);
|
|
|
|
|
vmovups(zmm_t(w_reg_idx),
|
|
|
|
|
ptr[reg_ptr_wgt + wgt_offset + k * n_ * sizeof(float)]);
|
|
|
|
|
vfmadd231ps(zmm_t(i), zmm_t(w_reg_idx), zmm_t(x_reg_idx));
|
|
|
|
|
wgt_offset += block_len;
|
|
|
|
|
}
|
|
|
|
|