|
|
|
@ -257,6 +257,9 @@ void MatrixMultiplyVec(const float32x4_t *matrix_a, const float32x4_t *matrix_b,
|
|
|
|
|
|
|
|
|
|
int WinogradWeightTransform(const float *weight_data, float *winograd_data, float *matrix_g, const float *matrix_gt,
|
|
|
|
|
int oc_block, int input_unit, int kernel_unit, int channel, int batch, bool pack) {
|
|
|
|
|
if (oc_block == 0) {
|
|
|
|
|
return NNACL_PARAM_INVALID;
|
|
|
|
|
}
|
|
|
|
|
// original weight format : ohwi
|
|
|
|
|
int oc_block_num = UP_DIV(batch, oc_block);
|
|
|
|
|
int block_stride = channel * oc_block;
|
|
|
|
|