Merge pull request #12904 from tensor-tang/refine/jit

optimize cpu vec activations
tensor-tang authored 7 years ago, committed by GitHub
commit 36363292c3

@@ -232,40 +232,28 @@ use lstm_x_t as input and compute as standard LSTM.
 template <typename T>
 inline void bias_relu(const int n, const T* x, const T* bias, T* y) {
   if (bias) {
-    for (int i = 0; i < n; ++i) {
-      y[i] = x[i] + bias[0];
-    }
-    math::vec_relu<T>(n, y, y);
+    math::vec_add_bias<T, platform::jit::avx>(n, *bias, x, y);
+    math::vec_relu<T, platform::jit::avx>(n, y, y);
   } else {
-    math::vec_relu<T>(n, x, y);
+    math::vec_relu<T, platform::jit::avx>(n, x, y);
   }
 }
 
-template <typename DeviceContext, typename T>
-inline void vec_softmax(const math::BlasT<DeviceContext, T>& blas, const int n,
-                        const T* x, T* y) {
+template <typename T>
+inline void vec_softmax(const int n, const T* x, T* y) {
   T scalar = x[0];
   // max
   for (int i = 1; i < n; ++i) {
     scalar = scalar < x[i] ? x[i] : scalar;
   }
-  // sub
-  for (int i = 0; i < n; ++i) {
-    y[i] = x[i] - scalar;
-  }
-  // exp
-  blas.VEXP(n, y, y);
+  math::vec_add_bias<T, platform::jit::avx>(n, -scalar, x, y);  // sub
+  math::vec_exp<T>(n, y, y);                                    // exp
   // sum
   scalar = T(0);
   for (int i = 0; i < n; ++i) {
     scalar += y[i];
   }
-  // scale
-  blas.SCAL(n, static_cast<T>(1) / scalar, y);
+  math::vec_scal<T>(n, static_cast<T>(1) / scalar, y);  // scale
 }
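For reference, the scalar semantics of the cpu_vec.h helpers called above (a minimal sketch; names mirror the calls in the diff, while the real header provides AVX/AVX2/AVX512 specializations selected by the cpu_isa_t template argument and its diff is suppressed further below):

```cpp
#include <cmath>

// Illustrative scalar fallbacks, not Paddle's implementations.
template <typename T>
void vec_add_bias_ref(const int n, const T a, const T* x, T* y) {
  for (int i = 0; i < n; ++i) y[i] = x[i] + a;  // broadcast add: +bias above, -max in softmax
}

template <typename T>
void vec_exp_ref(const int n, const T* x, T* y) {
  for (int i = 0; i < n; ++i) y[i] = std::exp(x[i]);
}

template <typename T>
void vec_scal_ref(const int n, const T a, T* x) {
  for (int i = 0; i < n; ++i) x[i] = a * x[i];  // in-place scale, here by 1/sum
}
```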
template <typename T> template <typename T>
@@ -311,11 +299,21 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D);
     fc_out->Resize({max_seq_len, 1});
 
-    math::VecActivations<T> act_functor;
     std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand;
-    act_gate = act_functor(ctx.Attr<std::string>("gate_activation"));
-    act_cell = act_functor(ctx.Attr<std::string>("cell_activation"));
-    act_cand = act_functor(ctx.Attr<std::string>("candidate_activation"));
+    auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
+    auto& act_cell_str = ctx.Attr<std::string>("cell_activation");
+    auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");
+    if (platform::jit::MayIUse(platform::jit::avx)) {
+      math::VecActivations<T, platform::jit::avx> act_functor;
+      act_gate = act_functor(act_gate_str);
+      act_cell = act_functor(act_cell_str);
+      act_cand = act_functor(act_cand_str);
+    } else {
+      math::VecActivations<T, platform::jit::isa_any> act_functor;
+      act_gate = act_functor(act_gate_str);
+      act_cell = act_functor(act_cell_str);
+      act_cand = act_functor(act_cand_str);
+    }
 
     const T* x_data = x->data<T>();
     const T* h0_data = h0 ? h0->data<T>() : NULL;

@@ -363,7 +361,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
                         fc_out_data);
       }
       // 1d. softmax
-      vec_softmax<DeviceContext, T>(blas, seq_len, fc_out_data, fc_out_data);
+      vec_softmax<T>(seq_len, fc_out_data, fc_out_data);
       // mul x(seq_len*M) and sum pool
       math::FCCompute<DeviceContext, T>(blas, 1, M, seq_len, fc_out_data,
                                         cur_x_data, lstm_x_data);
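The gate/cell/candidate functors are now chosen per ISA at run time. A standalone sketch of that dispatch pattern, using the GCC/Clang builtin in place of platform::jit::MayIUse and illustrative kernel names:

```cpp
#include <cstdio>
#include <functional>
#include <vector>

// Generic fallback kernel.
static void relu_generic(const int n, const float* x, float* y) {
  for (int i = 0; i < n; ++i) y[i] = x[i] > 0.f ? x[i] : 0.f;
}

// Stand-in for an AVX-specialized kernel; a real build would use intrinsics here.
static void relu_avx(const int n, const float* x, float* y) { relu_generic(n, x, y); }

// Bind the widest supported variant once, then call through std::function,
// mirroring how act_gate/act_cell/act_cand are bound above.
static std::function<void(const int, const float*, float*)> pick_relu() {
  return __builtin_cpu_supports("avx") ? relu_avx : relu_generic;
}

int main() {
  std::vector<float> x = {-1.f, 2.f, -3.f, 4.f}, y(x.size());
  pick_relu()(static_cast<int>(x.size()), x.data(), y.data());
  for (float v : y) std::printf("%.1f ", v);  // 0.0 2.0 0.0 4.0
  return 0;
}
```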

@@ -65,3 +65,4 @@ if(WITH_GPU)
     nv_test(selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu DEPS selected_rows_functor math_function)
 endif()
 cc_test(concat_test SRCS concat_test.cc DEPS concat)
+cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)

File diff suppressed because it is too large

@@ -0,0 +1,202 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <sys/time.h>
#include <cmath>
#include <cstring>
#include <functional>  // std::function used by TestAndBench/TestInplace
#include <random>      // std::mt19937, std::uniform_real_distribution in RandomVec
#include <vector>
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
inline double GetCurrentUS() {
struct timeval time;
gettimeofday(&time, NULL);
return 1e+6 * time.tv_sec + time.tv_usec;
}
constexpr int repeat = 1000;
template <typename T>
inline T _sigmoid(T x) {
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
T tmp = (x < min) ? min : ((x > max) ? max : x);
return static_cast<T>(1) / (static_cast<T>(1) + std::exp(-tmp));
}
template <typename T>
inline T _tanh(T x) {
return static_cast<T>(2) * _sigmoid<T>(static_cast<T>(2) * x) -
static_cast<T>(1);
}
template <typename T>
void ref_sigmoid(const int n, const T* x, T* y) {
for (int i = 0; i < n; ++i) {
y[i] = _sigmoid(x[i]);
}
}
template <typename T>
void ref_tanh(const int n, const T* x, T* y) {
for (int i = 0; i < n; ++i) {
y[i] = _tanh(x[i]);
}
}
template <typename T>
void ref_relu(const int n, const T* x, T* y) {
for (int i = 0; i < n; ++i) {
y[i] = x[i] > 0 ? x[i] : 0;
}
}
template <typename T>
void RandomVec(const int n, T* a) {
static unsigned int seed = 100;
std::mt19937 rng(seed++);
std::uniform_real_distribution<double> uniform_dist(0, 1);
const T lower = static_cast<T>(-20.f);
const T upper = static_cast<T>(20.f);
for (int i = 0; i < n; ++i) {
a[i] = static_cast<T>(uniform_dist(rng) * (upper - lower) + lower);
}
}
template <typename T>
void TestAndBench(const int n, std::function<void(const int, const T*, T*)> tgt,
std::function<void(const int, const T*, T*)> ref) {
std::vector<T> x(n);
std::vector<T> ytgt(n), yref(n);
RandomVec<T>(n, x.data());
const T* x_data = x.data();
T* ytgt_data = ytgt.data();
T* yref_data = yref.data();
auto st = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
tgt(n, x_data, ytgt_data);
}
auto mt = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ref(n, x_data, yref_data);
}
auto et = GetCurrentUS();
VLOG(3) << "Vec size " << n << ": refer takes: " << (et - mt) / repeat
<< " us, tgt takes: " << (mt - st) / repeat;
for (int i = 0; i < n; ++i) {
EXPECT_NEAR(ytgt_data[i], yref_data[i], 1e-3);
}
}
TEST(CpuVecTest, sigmoid) {
namespace jit = paddle::platform::jit;
using namespace paddle::operators::math; // NOLINT
for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
TestAndBench<float>(sz, vec_sigmoid<float>, ref_sigmoid<float>);
TestAndBench<float>(sz, vec_sigmoid<float, jit::avx>, ref_sigmoid<float>);
TestAndBench<float>(sz, vec_sigmoid<float, jit::avx2>, ref_sigmoid<float>);
TestAndBench<float>(sz, vec_sigmoid<float, jit::avx512_common>,
ref_sigmoid<float>);
}
TestAndBench<double>(30, vec_sigmoid<double>, ref_sigmoid<double>);
}
TEST(CpuVecTest, tanh) {
namespace jit = paddle::platform::jit;
using namespace paddle::operators::math; // NOLINT
for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
TestAndBench<float>(sz, vec_tanh<float>, ref_tanh<float>);
TestAndBench<float>(sz, vec_tanh<float, jit::avx>, ref_tanh<float>);
TestAndBench<float>(sz, vec_tanh<float, jit::avx2>, ref_tanh<float>);
TestAndBench<float>(sz, vec_tanh<float, jit::avx512_common>,
ref_tanh<float>);
}
TestAndBench<double>(30, vec_tanh<double>, ref_tanh<double>);
}
TEST(CpuVecTest, relu) {
namespace jit = paddle::platform::jit;
using namespace paddle::operators::math; // NOLINT
for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
TestAndBench<float>(sz, vec_relu<float>, ref_relu<float>);
TestAndBench<float>(sz, vec_relu<float, jit::avx>, ref_relu<float>);
TestAndBench<float>(sz, vec_relu<float, jit::avx2>, ref_relu<float>);
TestAndBench<float>(sz, vec_relu<float, jit::avx512_common>,
ref_relu<float>);
}
TestAndBench<double>(30, vec_relu<double>, ref_relu<double>);
}
template <typename T>
void TestInplace(const int n, std::function<void(const int, const T*, T*)> tgt,
std::function<void(const int, const T*, T*)> ref) {
std::vector<T> x(n);
std::vector<T> ytgt(n), yref(n);
RandomVec<T>(n, x.data());
const T* x_data = x.data();
T* yref_data = yref.data();
T* ytgt_data = ytgt.data();
std::memcpy(yref_data, x_data, sizeof(T) * n);
std::memcpy(ytgt_data, x_data, sizeof(T) * n);
ref(n, yref_data, yref_data);
tgt(n, ytgt_data, ytgt_data);
for (int i = 0; i < n; ++i) {
EXPECT_NEAR(ytgt_data[i], yref_data[i], 1e-3);
}
}
TEST(CpuVecTest, inplace_sigmoid) {
namespace jit = paddle::platform::jit;
using namespace paddle::operators::math; // NOLINT
for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
TestInplace<float>(sz, vec_sigmoid<float>, ref_sigmoid<float>);
TestInplace<float>(sz, vec_sigmoid<float, jit::avx>, ref_sigmoid<float>);
TestInplace<float>(sz, vec_sigmoid<float, jit::avx2>, ref_sigmoid<float>);
TestInplace<float>(sz, vec_sigmoid<float, jit::avx512_common>,
ref_sigmoid<float>);
}
TestInplace<double>(30, vec_sigmoid<double>, ref_sigmoid<double>);
}
TEST(CpuVecTest, inplace_tanh) {
namespace jit = paddle::platform::jit;
using namespace paddle::operators::math; // NOLINT
for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
TestInplace<float>(sz, vec_tanh<float>, ref_tanh<float>);
TestInplace<float>(sz, vec_tanh<float, jit::avx>, ref_tanh<float>);
TestInplace<float>(sz, vec_tanh<float, jit::avx2>, ref_tanh<float>);
TestInplace<float>(sz, vec_tanh<float, jit::avx512_common>,
ref_tanh<float>);
}
TestInplace<double>(30, vec_tanh<double>, ref_tanh<double>);
}
TEST(CpuVecTest, inplace_relu) {
namespace jit = paddle::platform::jit;
using namespace paddle::operators::math; // NOLINT
for (auto sz : {1, 2, 15, 16, 30, 32, 128, 200, 512}) {
TestInplace<float>(sz, vec_relu<float>, ref_relu<float>);
TestInplace<float>(sz, vec_relu<float, jit::avx>, ref_relu<float>);
TestInplace<float>(sz, vec_relu<float, jit::avx2>, ref_relu<float>);
TestInplace<float>(sz, vec_relu<float, jit::avx512_common>,
ref_relu<float>);
}
TestInplace<double>(30, vec_relu<double>, ref_relu<double>);
}
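Note the listing has no main(); in Paddle, the cc_test target added above links the shared test runner. For building the file standalone, a minimal driver along these lines would work (hypothetical, not part of the PR); setting FLAGS_v = 3 makes the VLOG(3) timing lines visible:

```cpp
#include "glog/logging.h"
#include "gtest/gtest.h"

int main(int argc, char** argv) {
  testing::InitGoogleTest(&argc, argv);
  google::InitGoogleLogging(argv[0]);
  FLAGS_logtostderr = true;  // print to the console rather than log files
  FLAGS_v = 3;               // surface the VLOG(3) benchmark output
  return RUN_ALL_TESTS();
}
```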

@@ -50,7 +50,7 @@ ENDIF()
 # memcpy depends on device_context, here add deps individually for
 # avoiding cycle dependencies
 cc_library(device_context SRCS device_context.cc init.cc DEPS malloc
-    place eigen3 stringpiece cpu_helper framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS})
+    place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS})
 nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info)
 cc_test(init_test SRCS init_test.cc DEPS device_context)

@@ -51,7 +51,7 @@ typedef enum {
 } cpu_isa_t;  // Instruction set architecture
 
 // May I use some instruction
-inline bool MayIUse(const cpu_isa_t cpu_isa);
+bool MayIUse(const cpu_isa_t cpu_isa);
 
 }  // namespace jit

@@ -18,6 +18,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/cpu_helper.h"
+#include "paddle/fluid/platform/cpu_info.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/init.h"
 #include "paddle/fluid/platform/place.h"

@@ -120,6 +121,22 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
 #ifndef PADDLE_WITH_MKLDNN
   platform::SetNumThreads(FLAGS_paddle_num_threads);
 #endif
+
+  if (platform::jit::MayIUse(platform::jit::avx512_common)) {
+#ifndef __AVX512F__
+    LOG(WARNING) << "AVX512F is available, Please re-compile on local machine";
+#endif
+  }
+  if (platform::jit::MayIUse(platform::jit::avx2)) {
+#ifndef __AVX2__
+    LOG(WARNING) << "AVX2 is available, Please re-compile on local machine";
+#endif
+  }
+  if (platform::jit::MayIUse(platform::jit::avx)) {
+#ifndef __AVX__
+    LOG(WARNING) << "AVX is available, Please re-compile on local machine";
+#endif
+  }
 }
 
 void InitGLOG(const std::string &prog_name) {
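These warnings fire when the host CPU supports an ISA the binary was not built for: the __AVX512F__/__AVX2__/__AVX__ macros reflect compile-time flags, while MayIUse probes the running hardware. A self-contained illustration of the same check, using the GCC/Clang builtin instead of Paddle's detection:

```cpp
#include <cstdio>

int main() {
#ifndef __AVX__
  // Built without AVX: warn if the machine we are running on could use it.
  if (__builtin_cpu_supports("avx")) {
    std::puts("WARNING: AVX is available, please re-compile with -mavx");
  }
#endif
  return 0;
}
```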
