diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc index cef21348e4..7d38d51172 100644 --- a/paddle/fluid/operators/math/jit_kernel_blas.cc +++ b/paddle/fluid/operators/math/jit_kernel_blas.cc @@ -65,8 +65,9 @@ class VMulKernelImpl : public VMulKernel<T> { explicit VMulKernelImpl(int d) : VMulKernel<T>() { if (useJIT(d)) { - constexpr size_t sz = 256 * 1024; // TODO(TJ): should be related with d - jitcode_.reset(new gen::VMulJitCode(d, sz)); + // roughly estimate the size of code + size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; + jitcode_.reset(new gen::VMulJitCode(d, sz > 4096 ? sz : 4096)); this->Compute = jitcode_->getCode<void (*)(const T*, const T*, T*, int)>(); return; diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc index 593209d42b..667a95fe1a 100644 --- a/paddle/fluid/operators/math/jit_kernel_test.cc +++ b/paddle/fluid/operators/math/jit_kernel_test.cc @@ -800,7 +800,7 @@ TEST(JitKernel, pool) { EXPECT_TRUE(std::dynamic_pointer_cast<const jit::Kernel>(pvmul_f) != std::dynamic_pointer_cast<const jit::Kernel>(pvmul_d)); - const auto& pvmul_from_key = jit::KernelPool::Instance().Get("vmulfany"); + const auto& pvmul_from_key = jit::KernelPool::Instance().Get("vmulfjit4"); EXPECT_EQ(pvmul_f, pvmul_from_key); const auto& pvmul_from_key2 = jit::KernelPool::Instance().Get("vmulfjit"); EXPECT_TRUE(pvmul_from_key2 == nullptr);