Merge pull request #14958 from tensor-tang/refine/jit
	
		
	
				
					
				
			enhance jitrevert-15207-remove_op_handle_lock_and_fix_var
						commit
						693e5e65ce
					
				| @ -0,0 +1,25 @@ | ||||
| 
 | ||||
| set(jit_file ${PADDLE_BINARY_DIR}/paddle/fluid/operators/jit/kernels.h) | ||||
| file(WRITE ${jit_file} "// Generated by the paddle/fluid/operators/jit/CMakeLists.txt.  DO NOT EDIT!\n\n") | ||||
| file(APPEND ${jit_file} "\#pragma once\n") | ||||
| file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/helper.h\"\n") | ||||
| file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/registry.h\"\n\n") | ||||
| 
 | ||||
| set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce place) | ||||
| 
 | ||||
| file(GLOB jit_kernel_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc") | ||||
| list(REMOVE_ITEM jit_kernel_cc_srcs test.cc benchmark.cc) | ||||
| cc_library(jit_kernel_base SRCS ${jit_kernel_cc_srcs} DEPS ${JIT_KERNEL_DEPS}) | ||||
| 
 | ||||
| # The refer subdirectory must be added before the others. | ||||
| add_subdirectory(refer) | ||||
| add_subdirectory(more) | ||||
| if(WITH_XBYAK) | ||||
|     add_subdirectory(gen) | ||||
| endif() | ||||
| 
 | ||||
| cc_library(jit_kernel_helper SRCS ${jit_kernel_cc_srcs} DEPS ${JIT_KERNEL_DEPS}) | ||||
| cc_test(jit_kernel_test SRCS test.cc DEPS jit_kernel_helper) | ||||
| if(NOT WIN32) | ||||
|     cc_binary(jit_kernel_benchmark SRCS benchmark.cc DEPS jit_kernel_helper) | ||||
| endif() | ||||
| @ -0,0 +1,231 @@ | ||||
| /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  * http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. */ | ||||
| 
 | ||||
| #include <iostream> | ||||
| #include <random> | ||||
| #include <string> | ||||
| #include <vector> | ||||
| #include "gflags/gflags.h" | ||||
| #include "glog/logging.h" | ||||
| #include "paddle/fluid/operators/jit/kernels.h" | ||||
| #include "paddle/fluid/platform/device_tracer.h" | ||||
| #include "paddle/fluid/platform/place.h" | ||||
| #include "paddle/fluid/platform/port.h" | ||||
| 
 | ||||
| DEFINE_int32(burning, 10, "Burning times."); | ||||
| DEFINE_int32(repeat, 3000, "Repeat times."); | ||||
| DEFINE_int32(max_size, 1000, "The Max size would be tested."); | ||||
| 
 | ||||
// Fills a[0..n) with pseudo-random values.
//
// Each element is a uniform draw from [0, 1) scaled into the range between
// `lower` and `upper`. The generator is re-seeded with `seed` on every call,
// so identical arguments always produce identical contents.
template <typename T>
void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f),
               const T upper = static_cast<T>(20.f), unsigned int seed = 100) {
  std::mt19937 engine(seed);
  std::uniform_real_distribution<double> unit_dist(0, 1);
  int idx = 0;
  while (idx < n) {
    const double u = unit_dist(engine);
    a[idx] = static_cast<T>(u * (upper - lower) + lower);
    ++idx;
  }
}
| 
 | ||||
| std::vector<int> TestSizes() { | ||||
|   std::vector<int> s; | ||||
|   for (int i = 1; i <= FLAGS_max_size; ++i) { | ||||
|     s.push_back(i); | ||||
|   } | ||||
|   return s; | ||||
| } | ||||
| 
 | ||||
| template <typename KernelTuples, typename... Args> | ||||
| struct BenchFunc { | ||||
|   // return this function avg time
 | ||||
|   double operator()(const typename KernelTuples::func_type tgt, Args... args) { | ||||
|     for (int i = 0; i < FLAGS_burning; ++i) { | ||||
|       tgt(args...); | ||||
|     } | ||||
|     auto start = paddle::platform::PosixInNsec() / 1e-3; | ||||
|     for (int i = 0; i < FLAGS_repeat; ++i) { | ||||
|       tgt(args...); | ||||
|     } | ||||
|     auto end = paddle::platform::PosixInNsec() / 1e-3; | ||||
|     return static_cast<double>(end - start) / FLAGS_repeat; | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| namespace jit = paddle::operators::jit; | ||||
| 
 | ||||
| template <jit::KernelType KT, typename KernelTuples, typename PlaceType, | ||||
|           typename... Args> | ||||
| void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) { | ||||
|   BenchFunc<KernelTuples, Args...> benchmark; | ||||
|   std::vector<std::pair<std::string, double>> infos; | ||||
|   // test refer
 | ||||
|   auto refer = jit::GetRefer<KT, KernelTuples>(); | ||||
|   if (!refer) { | ||||
|     LOG(FATAL) << "Refer can not be empty!"; | ||||
|   } | ||||
|   infos.push_back(std::make_pair("Refer", benchmark(refer, args...))); | ||||
| 
 | ||||
|   // test jitcode
 | ||||
|   auto jitcode = jit::GetJitCode<KT, KernelTuples, PlaceType>(attr); | ||||
|   if (jitcode) { | ||||
|     infos.push_back(std::make_pair("JitCode", benchmark(jitcode, args...))); | ||||
|   } | ||||
|   // test all impls in more
 | ||||
|   jit::KernelKey kkey(KT, PlaceType()); | ||||
|   auto& pool = jit::KernelPool().Instance().AllKernels(); | ||||
|   auto iter = pool.find(kkey); | ||||
|   if (iter != pool.end()) { | ||||
|     auto& impls = iter->second; | ||||
|     for (auto& impl : impls) { | ||||
|       auto i = dynamic_cast<const jit::KernelMore<KernelTuples>*>(impl.get()); | ||||
|       if (i && i->UseMe(attr)) { | ||||
|         auto more = i->GetFunc(); | ||||
|         infos.push_back( | ||||
|             std::make_pair(i->ImplType(), benchmark(more, args...))); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|   // Test result from Get function
 | ||||
|   auto tgt = jit::Get<KT, KernelTuples, PlaceType>(attr); | ||||
|   if (!tgt) { | ||||
|     LOG(FATAL) << "Target can not be empty!"; | ||||
|   } | ||||
|   infos.push_back(std::make_pair("Target", benchmark(tgt, args...))); | ||||
| 
 | ||||
|   // print
 | ||||
|   std::ostringstream loginfos; | ||||
|   loginfos << "Kernel Type " << jit::to_string(KT) << ": " << attr << ": "; | ||||
|   for (auto pair : infos) { | ||||
|     loginfos << pair.first << " takes " << pair.second << " us; "; | ||||
|   } | ||||
|   LOG(INFO) << loginfos.str(); | ||||
| } | ||||
| 
 | ||||
| template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType> | ||||
| void BenchXYZNKernel() { | ||||
|   for (int d : TestSizes()) { | ||||
|     std::vector<T> x(d), y(d), z(d); | ||||
|     RandomVec<T>(d, x.data()); | ||||
|     RandomVec<T>(d, y.data()); | ||||
|     BenchAllImpls<KT, jit::XYZNTuples<T>, PlaceType>(d, x.data(), y.data(), | ||||
|                                                      z.data(), d); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType> | ||||
| void BenchAXYNKernel() { | ||||
|   for (int d : TestSizes()) { | ||||
|     const T a = static_cast<T>(3); | ||||
|     std::vector<T> x(d), y(d); | ||||
|     RandomVec<T>(d, x.data()); | ||||
|     BenchAllImpls<KT, jit::AXYNTuples<T>, PlaceType>(d, &a, x.data(), y.data(), | ||||
|                                                      d); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType> | ||||
| void BenchXYNKernel() { | ||||
|   for (int d : TestSizes()) { | ||||
|     std::vector<T> x(d), y(d); | ||||
|     RandomVec<T>(d, x.data()); | ||||
|     BenchAllImpls<KT, jit::XYNTuples<T>, PlaceType>(d, x.data(), y.data(), d); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType> | ||||
| void BenchLSTMKernel() { | ||||
|   for (bool use_peephole : {true, false}) { | ||||
|     for (int d : TestSizes()) { | ||||
|       const jit::lstm_attr_t attr(d, jit::kVSigmoid, jit::kVTanh, jit::kVTanh, | ||||
|                                   use_peephole); | ||||
|       std::vector<T> x(4 * d), ct_1(d), ct(d), ht(d), wp(3 * d), checked(2 * d); | ||||
|       RandomVec<T>(4 * d, x.data(), -2.f, 2.f); | ||||
|       RandomVec<T>(3 * d, wp.data(), -2.f, 2.f); | ||||
|       RandomVec<T>(d, ct_1.data(), -2.f, 2.f); | ||||
|       const T* ct_1_data = ct_1.data(); | ||||
|       const T* wp_data = wp.data(); | ||||
|       T* x_data = x.data(); | ||||
|       T* checked_data = checked.data(); | ||||
|       T* ct_data = ct.data(); | ||||
|       T* ht_data = ht.data(); | ||||
|       jit::lstm_t step; | ||||
|       step.gates = x_data; | ||||
|       step.ct_1 = ct_1_data; | ||||
|       step.ct = ct_data; | ||||
|       step.ht = ht_data; | ||||
|       if (use_peephole) { | ||||
|         step.wp = wp_data; | ||||
|         step.checked = checked_data; | ||||
|       } | ||||
|       BenchAllImpls<KT, jit::LSTMTuples<T>, PlaceType>(attr, &step, &attr); | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType> | ||||
| void BenchGRUKernel() { | ||||
|   for (int d : TestSizes()) { | ||||
|     const jit::gru_attr_t attr(d, jit::kVSigmoid, jit::kVTanh); | ||||
|     std::vector<T> x(3 * d), ht_1(d), ht(d); | ||||
|     RandomVec<T>(3 * d, x.data(), -2.f, 2.f); | ||||
|     RandomVec<T>(d, ht_1.data(), -2.f, 2.f); | ||||
|     const T* ht_1_data = ht_1.data(); | ||||
|     T* x_data = x.data(); | ||||
|     T* ht_data = ht.data(); | ||||
|     jit::gru_t step; | ||||
|     step.gates = x_data; | ||||
|     step.ht_1 = ht_1_data; | ||||
|     step.ht = ht_data; | ||||
|     BenchAllImpls<KT, jit::GRUTuples<T>, PlaceType>(attr, &step, &attr); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| // Benchmark all jit kernels including jitcode, mkl and refer.
 | ||||
| // To use this tool, run command: ./benchmark [options...]
 | ||||
| // Options:
 | ||||
| //     --burning: the burning time before count
 | ||||
| //     --repeat: the repeat times
 | ||||
| //     --max_size: the max size would be tested
 | ||||
| int main(int argc, char* argv[]) { | ||||
|   gflags::ParseCommandLineFlags(&argc, &argv, true); | ||||
|   google::InitGoogleLogging(argv[0]); | ||||
|   LOG(INFO) << "Burning " << FLAGS_burning << " times, Repeat " << FLAGS_repeat | ||||
|             << " times."; | ||||
|   using T = float; | ||||
|   using PlaceType = paddle::platform::CPUPlace; | ||||
|   // xyzn
 | ||||
|   BenchXYZNKernel<jit::kVMul, T, PlaceType>(); | ||||
|   BenchXYZNKernel<jit::kVAdd, T, PlaceType>(); | ||||
|   BenchXYZNKernel<jit::kVAddRelu, T, PlaceType>(); | ||||
|   BenchXYZNKernel<jit::kVSub, T, PlaceType>(); | ||||
| 
 | ||||
|   // axyn
 | ||||
|   BenchAXYNKernel<jit::kVScal, T, PlaceType>(); | ||||
|   BenchAXYNKernel<jit::kVAddBias, T, PlaceType>(); | ||||
| 
 | ||||
|   // xyn
 | ||||
|   BenchXYNKernel<jit::kVRelu, T, PlaceType>(); | ||||
|   BenchXYNKernel<jit::kVIdentity, T, PlaceType>(); | ||||
|   BenchXYNKernel<jit::kVExp, T, PlaceType>(); | ||||
|   BenchXYNKernel<jit::kVSigmoid, T, PlaceType>(); | ||||
|   BenchXYNKernel<jit::kVTanh, T, PlaceType>(); | ||||
| 
 | ||||
|   // lstm and peephole
 | ||||
|   BenchLSTMKernel<jit::kLSTMCtHt, T, PlaceType>(); | ||||
|   BenchLSTMKernel<jit::kLSTMC1H1, T, PlaceType>(); | ||||
| 
 | ||||
|   // gru functions
 | ||||
|   BenchGRUKernel<jit::kGRUH1, T, PlaceType>(); | ||||
|   BenchGRUKernel<jit::kGRUHtPart1, T, PlaceType>(); | ||||
|   BenchGRUKernel<jit::kGRUHtPart2, T, PlaceType>(); | ||||
| } | ||||
| @ -0,0 +1,28 @@ | ||||
| 
 | ||||
| file(GLOB jitcode_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc") | ||||
| 
 | ||||
| cc_library(jit_kernel_jitcode SRCS ${jitcode_cc_srcs} DEPS jit_kernel_base xbyak) | ||||
| set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} xbyak jit_kernel_jitcode PARENT_SCOPE) | ||||
| 
 | ||||
# Appends a `USE_JITKERNEL_GEN(<TARGET>);` line to the file named by
# ${jit_file}. jit_file is not set in this file; it is expected to come from
# the including scope.
| function(USE_JITKERNEL_GEN TARGET) | ||||
|     file(APPEND ${jit_file} "USE_JITKERNEL_GEN(${TARGET});\n") | ||||
| endfunction() | ||||
| 
 | ||||
| # use gen jitcode kernel by name | ||||
| USE_JITKERNEL_GEN(kVMul) | ||||
| USE_JITKERNEL_GEN(kVAdd) | ||||
| #USE_JITKERNEL_GEN(kVSub) # TODO(TJ): enable me | ||||
| USE_JITKERNEL_GEN(kVAddRelu) | ||||
| USE_JITKERNEL_GEN(kVScal) | ||||
| USE_JITKERNEL_GEN(kVAddBias) | ||||
| USE_JITKERNEL_GEN(kVRelu) | ||||
| USE_JITKERNEL_GEN(kVIdentity) | ||||
| USE_JITKERNEL_GEN(kVExp) | ||||
| USE_JITKERNEL_GEN(kVSigmoid) | ||||
| USE_JITKERNEL_GEN(kVTanh) | ||||
| USE_JITKERNEL_GEN(kLSTMCtHt) | ||||
| USE_JITKERNEL_GEN(kLSTMC1H1) | ||||
| USE_JITKERNEL_GEN(kGRUH1) | ||||
| USE_JITKERNEL_GEN(kGRUHtPart1) | ||||
| USE_JITKERNEL_GEN(kGRUHtPart2) | ||||
| USE_JITKERNEL_GEN(kNCHW16CMulNC) | ||||
| @ -0,0 +1,135 @@ | ||||
| /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  * http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. */ | ||||
| 
 | ||||
| #include "paddle/fluid/operators/jit/gen/act.h" | ||||
| #include "paddle/fluid/operators/jit/registry.h" | ||||
| #include "paddle/fluid/platform/cpu_info.h" | ||||
| 
 | ||||
| namespace paddle { | ||||
| namespace operators { | ||||
| namespace jit { | ||||
| namespace gen { | ||||
| 
 | ||||
// Table of float constants read by generated code through byte offsets such
// as OFFSET_EXP_ONE (GRUJitCode::genCode loads from it with vmovaps).
// Each entry is wrapped in REPEAT_8TIMES - presumably 8 copies, one per float
// lane of a ymm register; confirm against the macro definition.
// ALIGN32_BEG / ALIGN32_END presumably request 32-byte alignment, which an
// aligned 256-bit load such as vmovaps requires - TODO confirm.
// NOTE(review): the order of the entries presumably has to match the OFFSET_*
// constants defined elsewhere; verify before inserting or reordering.
| const float ALIGN32_BEG exp_float_consts[] ALIGN32_END = { | ||||
|     REPEAT_8TIMES(1.f), | ||||
|     REPEAT_8TIMES(2.f), | ||||
|     REPEAT_8TIMES(0.5f), | ||||
|     REPEAT_8TIMES(EXP_HIG), | ||||
|     REPEAT_8TIMES(EXP_LOW), | ||||
|     REPEAT_8TIMES(CEPHES_LOG2EF), | ||||
|     REPEAT_8TIMES(CEPHES_EXP_C1), | ||||
|     REPEAT_8TIMES(CEPHES_EXP_C2), | ||||
|     REPEAT_8TIMES(CEPHES_EXP_P0), | ||||
|     REPEAT_8TIMES(CEPHES_EXP_P1), | ||||
|     REPEAT_8TIMES(CEPHES_EXP_P2), | ||||
|     REPEAT_8TIMES(CEPHES_EXP_P3), | ||||
|     REPEAT_8TIMES(CEPHES_EXP_P4), | ||||
|     REPEAT_8TIMES(CEPHES_EXP_P5), | ||||
|     REPEAT_8TIMES(EXP_MAX_INPUT), | ||||
|     REPEAT_8TIMES(SIGMOID_THRESHOLD_MAX), | ||||
|     REPEAT_8TIMES(SIGMOID_THRESHOLD_MIN)}; | ||||
| 
 | ||||
// exp_int_0x7f: the integer 0x7f wrapped in REPEAT_8TIMES (presumably 8
// copies). NOTE(review): 0x7f is 127, the float exponent bias, so this is
// probably used when building 2^n in the exp code - confirm against its users.
// g_tmp_mem: zero-initialised, writable array of 16 ints; its users are not
// visible in this file.
| const int ALIGN32_BEG exp_int_0x7f[] ALIGN32_END = {REPEAT_8TIMES(0x7f)}; | ||||
| int ALIGN32_BEG g_tmp_mem[16] ALIGN32_END = {0}; | ||||
| 
 | ||||
| void VActJitCode::genCode() { | ||||
|   int offset = 0; | ||||
|   for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) { | ||||
|     vmovups(ymm_src, ptr[param1 + offset]); | ||||
|     act<ymm_t>(ymm_dst, ymm_src, type_); | ||||
|     vmovups(ptr[param2 + offset], ymm_dst); | ||||
|     offset += sizeof(float) * YMM_FLOAT_BLOCK; | ||||
|   } | ||||
|   int rest = num_ % YMM_FLOAT_BLOCK; | ||||
|   while (rest > 0) { | ||||
|     int block = XMM_FLOAT_BLOCK; | ||||
|     if (rest >= 4) { | ||||
|       block = 4; | ||||
|       vmovups(xmm_src, ptr[param1 + offset]); | ||||
|     } else if (rest >= 2) { | ||||
|       block = 2; | ||||
|       vmovq(xmm_src, ptr[param1 + offset]); | ||||
|     } else { | ||||
|       block = 1; | ||||
|       vmovss(xmm_src, ptr[param1 + offset]); | ||||
|     } | ||||
|     act<xmm_t>(xmm_dst, xmm_src, type_); | ||||
|     if (rest >= 4) { | ||||
|       vmovups(ptr[param2 + offset], xmm_dst); | ||||
|     } else if (rest >= 2) { | ||||
|       vmovq(ptr[param2 + offset], xmm_dst); | ||||
|     } else { | ||||
|       vmovss(ptr[param2 + offset], xmm_dst); | ||||
|     } | ||||
|     offset += sizeof(float) * block; | ||||
|     rest -= block; | ||||
|   } | ||||
|   ret(); | ||||
| } | ||||
| 
 | ||||
// Declares class <name>Creator, a JitCodeCreator<int> for one activation.
//   UseMe:         true when platform::MayIUse(platform::avx) is true.
//   CodeSize:      declared only; each creator defines it out of line.
//   CreateJitCode: builds <name>JitCode from the attribute and CodeSize(attr).
// The expansion has no trailing semicolon; the invocation supplies it.
#define DECLARE_ACT_CREATOR(name)                                            \
  class name##Creator : public JitCodeCreator<int> {                         \
   public:                                                                   \
    bool UseMe(const int& attr) const override {                             \
      const bool has_avx = platform::MayIUse(platform::avx);                 \
      return has_avx;                                                        \
    }                                                                        \
    size_t CodeSize(const int& d) const override;                            \
    std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
      const size_t code_bytes = CodeSize(attr);                              \
      return make_unique<name##JitCode>(attr, code_bytes);                   \
    }                                                                        \
  }
| 
 | ||||
| DECLARE_ACT_CREATOR(VRelu); | ||||
| DECLARE_ACT_CREATOR(VIdentity); | ||||
| DECLARE_ACT_CREATOR(VExp); | ||||
| DECLARE_ACT_CREATOR(VSigmoid); | ||||
| DECLARE_ACT_CREATOR(VTanh); | ||||
| 
 | ||||
| // TODO(TJ): tuning use me
 | ||||
| size_t VReluCreator::CodeSize(const int& d) const { | ||||
|   return 96 /* init size */ + | ||||
|          (d / YMM_FLOAT_BLOCK + 3) * 4 /* instructions */ * | ||||
|              8 /* average bytes for each instruction */; | ||||
| } | ||||
| 
 | ||||
| size_t VIdentityCreator::CodeSize(const int& d) const { | ||||
|   return 96 + (d / YMM_FLOAT_BLOCK + 3) * 4 * 8; | ||||
| } | ||||
| 
 | ||||
| size_t VExpCreator::CodeSize(const int& d) const { | ||||
|   return 96 + (d / YMM_FLOAT_BLOCK + 3) * 70 * 8; | ||||
| } | ||||
| 
 | ||||
| size_t VSigmoidCreator::CodeSize(const int& d) const { | ||||
|   return 96 + (d / YMM_FLOAT_BLOCK + 3) * 82 * 8; | ||||
| } | ||||
| 
 | ||||
| size_t VTanhCreator::CodeSize(const int& d) const { | ||||
|   return 96 + (d / YMM_FLOAT_BLOCK + 3) * 84 * 8; | ||||
| } | ||||
| 
 | ||||
| #undef DECLARE_ACT_CREATOR | ||||
| 
 | ||||
| }  // namespace gen
 | ||||
| }  // namespace jit
 | ||||
| }  // namespace operators
 | ||||
| }  // namespace paddle
 | ||||
| 
 | ||||
| namespace gen = paddle::operators::jit::gen; | ||||
| 
 | ||||
// Pair each activation kernel type with its creator class through
// REGISTER_JITKERNEL_GEN (the macro is defined outside this file).
| REGISTER_JITKERNEL_GEN(kVRelu, gen::VReluCreator); | ||||
| REGISTER_JITKERNEL_GEN(kVIdentity, gen::VIdentityCreator); | ||||
| REGISTER_JITKERNEL_GEN(kVExp, gen::VExpCreator); | ||||
| REGISTER_JITKERNEL_GEN(kVSigmoid, gen::VSigmoidCreator); | ||||
| REGISTER_JITKERNEL_GEN(kVTanh, gen::VTanhCreator); | ||||
											
												
													File diff suppressed because it is too large
													Load Diff
												
											
										
									
								| @ -0,0 +1,186 @@ | ||||
| /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  * http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. */ | ||||
| 
 | ||||
| #include "paddle/fluid/operators/jit/gen/blas.h" | ||||
| #include "paddle/fluid/operators/jit/registry.h" | ||||
| #include "paddle/fluid/platform/cpu_info.h" | ||||
| 
 | ||||
| namespace paddle { | ||||
| namespace operators { | ||||
| namespace jit { | ||||
| namespace gen { | ||||
| 
 | ||||
| void VXXJitCode::genCode() { | ||||
|   // do not need push stack, and do not need save avx512reg if do not use avx512
 | ||||
|   int offset = 0; | ||||
|   if (with_relu_) { | ||||
|     vxorps(ymm_zero, ymm_zero, ymm_zero); | ||||
|   } | ||||
|   if (scalar_index_ == 1) { | ||||
|     vbroadcastss(ymm_src1, ptr[param1]); | ||||
|   } else if (scalar_index_ == 2) { | ||||
|     vbroadcastss(ymm_src2, ptr[param2]); | ||||
|   } | ||||
|   for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) { | ||||
|     if (scalar_index_ != 1) { | ||||
|       vmovups(ymm_src1, ptr[param1 + offset]); | ||||
|     } | ||||
|     if (scalar_index_ != 2) { | ||||
|       vmovups(ymm_src2, ptr[param2 + offset]); | ||||
|     } | ||||
|     if (type_ == operand_type::MUL) { | ||||
|       vmulps(ymm_dst, ymm_src1, ymm_src2); | ||||
|     } else if (type_ == operand_type::ADD) { | ||||
|       vaddps(ymm_dst, ymm_src1, ymm_src2); | ||||
|     } | ||||
|     if (with_relu_) { | ||||
|       vmaxps(ymm_dst, ymm_zero, ymm_dst); | ||||
|     } | ||||
|     vmovups(ptr[param3 + offset], ymm_dst); | ||||
|     offset += sizeof(float) * YMM_FLOAT_BLOCK; | ||||
|   } | ||||
|   int rest = num_ % YMM_FLOAT_BLOCK; | ||||
|   while (rest > 0) { | ||||
|     int block = XMM_FLOAT_BLOCK; | ||||
|     if (rest >= 4) { | ||||
|       block = 4; | ||||
|       if (scalar_index_ != 1) { | ||||
|         vmovups(xmm_src1, ptr[param1 + offset]); | ||||
|       } | ||||
|       if (scalar_index_ != 2) { | ||||
|         vmovups(xmm_src2, ptr[param2 + offset]); | ||||
|       } | ||||
|     } else if (rest >= 2) { | ||||
|       block = 2; | ||||
|       if (scalar_index_ != 1) { | ||||
|         vmovq(xmm_src1, ptr[param1 + offset]); | ||||
|       } | ||||
|       if (scalar_index_ != 2) { | ||||
|         vmovq(xmm_src2, ptr[param2 + offset]); | ||||
|       } | ||||
|     } else { | ||||
|       block = 1; | ||||
|       if (scalar_index_ != 1) { | ||||
|         vmovss(xmm_src1, ptr[param1 + offset]); | ||||
|       } | ||||
|       if (scalar_index_ != 2) { | ||||
|         vmovss(xmm_src2, ptr[param2 + offset]); | ||||
|       } | ||||
|     } | ||||
|     switch (type_) { | ||||
|       case operand_type::MUL: | ||||
|         vmulps(xmm_dst, xmm_src1, xmm_src2); | ||||
|         break; | ||||
|       case operand_type::ADD: | ||||
|         vaddps(xmm_dst, xmm_src1, xmm_src2); | ||||
|         break; | ||||
|       default: | ||||
|         break; | ||||
|     } | ||||
|     if (with_relu_) { | ||||
|       vmaxps(xmm_dst, xmm_zero, xmm_dst); | ||||
|     } | ||||
|     if (rest >= 4) { | ||||
|       vmovups(ptr[param3 + offset], xmm_dst); | ||||
|     } else if (rest >= 2) { | ||||
|       vmovq(ptr[param3 + offset], xmm_dst); | ||||
|     } else { | ||||
|       vmovss(ptr[param3 + offset], xmm_dst); | ||||
|     } | ||||
|     offset += sizeof(float) * block; | ||||
|     rest -= block; | ||||
|   } | ||||
|   ret(); | ||||
| } | ||||
| 
 | ||||
| void NCHW16CMulNCJitCode::genCode() { | ||||
|   // RDI is ptr x_input
 | ||||
|   // RSI is ptr y_input
 | ||||
|   // RDX is ptr output
 | ||||
|   // RCX is height
 | ||||
|   // r8 is width
 | ||||
| 
 | ||||
|   push(rbx); | ||||
| 
 | ||||
|   xor_(rax, rax); | ||||
|   xor_(r10, r10); | ||||
|   vmovups(zmm3, ptr[rsi]); | ||||
| 
 | ||||
|   L("h_loop"); | ||||
|   xor_(rbx, rbx); | ||||
|   L("w_loop"); | ||||
|   vmovups(zmm2, ptr[rdi + rax]); | ||||
|   vmulps(zmm1, zmm2, zmm3); | ||||
|   vmovups(ptr[rdx + rax], zmm1); | ||||
|   add(rax, 64); | ||||
|   inc(rbx); | ||||
|   cmp(r8, rbx); | ||||
|   jnz("w_loop"); | ||||
|   inc(r10); | ||||
|   cmp(r10, rcx); | ||||
|   jnz("h_loop"); | ||||
| 
 | ||||
|   pop(rbx); | ||||
|   ret(); | ||||
| } | ||||
| 
 | ||||
| class NCHW16CMulNCCreator : public JitCodeCreator<int> { | ||||
|  public: | ||||
|   bool UseMe(const int& attr) const override { | ||||
|     return platform::MayIUse(platform::avx512f); | ||||
|   } | ||||
|   size_t CodeSize(const int& d) const override { return 256 * 1024; } | ||||
|   std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { | ||||
|     return make_unique<NCHW16CMulNCJitCode>(attr, CodeSize(attr)); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
// Declares class <name>Creator, a JitCodeCreator<int> for one blas kernel.
//   UseMe:         true when platform::MayIUse(platform::avx) is true.
//   CodeSize:      96 bytes plus 4 * 8 bytes per full YMM_FLOAT_BLOCK group.
//   CreateJitCode: builds <name>JitCode from the attribute and CodeSize(attr).
// NOTE(review): the estimate has no extra allowance for the tail elements
// (d % YMM_FLOAT_BLOCK); confirm the buffer is always large enough.
// The expansion has no trailing semicolon; the invocation supplies it.
#define DECLARE_BLAS_CREATOR(name)                                           \
  class name##Creator : public JitCodeCreator<int> {                         \
   public:                                                                   \
    bool UseMe(const int& attr) const override {                             \
      const bool has_avx = platform::MayIUse(platform::avx);                 \
      return has_avx;                                                        \
    }                                                                        \
    size_t CodeSize(const int& d) const override {                           \
      return 96 + (d / YMM_FLOAT_BLOCK) * 4 * 8;                             \
    }                                                                        \
    std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
      const size_t code_bytes = CodeSize(attr);                              \
      return make_unique<name##JitCode>(attr, code_bytes);                   \
    }                                                                        \
  }
| 
 | ||||
| DECLARE_BLAS_CREATOR(VMul); | ||||
| DECLARE_BLAS_CREATOR(VAdd); | ||||
| DECLARE_BLAS_CREATOR(VSub); | ||||
| DECLARE_BLAS_CREATOR(VAddRelu); | ||||
| DECLARE_BLAS_CREATOR(VScal); | ||||
| DECLARE_BLAS_CREATOR(VAddBias); | ||||
| 
 | ||||
| #undef DECLARE_BLAS_CREATOR | ||||
| 
 | ||||
| }  // namespace gen
 | ||||
| }  // namespace jit
 | ||||
| }  // namespace operators
 | ||||
| }  // namespace paddle
 | ||||
| 
 | ||||
| namespace gen = paddle::operators::jit::gen; | ||||
| 
 | ||||
// Pair each blas kernel type with its creator class through
// REGISTER_JITKERNEL_GEN (the macro is defined outside this file).
// kVSub stays unregistered: VXXJitCode only accepts MUL and ADD operands.
| REGISTER_JITKERNEL_GEN(kVMul, gen::VMulCreator); | ||||
| REGISTER_JITKERNEL_GEN(kVAdd, gen::VAddCreator); | ||||
| // TODO(TJ): enable sub
 | ||||
| // REGISTER_JITKERNEL_GEN(kVSub, gen::VSubCreator);
 | ||||
| REGISTER_JITKERNEL_GEN(kVAddRelu, gen::VAddReluCreator); | ||||
| REGISTER_JITKERNEL_GEN(kVScal, gen::VScalCreator); | ||||
| REGISTER_JITKERNEL_GEN(kVAddBias, gen::VAddBiasCreator); | ||||
| REGISTER_JITKERNEL_GEN(kNCHW16CMulNC, gen::NCHW16CMulNCCreator); | ||||
| @ -0,0 +1,117 @@ | ||||
| /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  * http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. */ | ||||
| 
 | ||||
| #pragma once | ||||
| 
 | ||||
| #include <string> | ||||
| #include "glog/logging.h" | ||||
| #include "paddle/fluid/operators/jit/gen/jitcode.h" | ||||
| 
 | ||||
| namespace paddle { | ||||
| namespace operators { | ||||
| namespace jit { | ||||
| namespace gen { | ||||
| 
 | ||||
| // function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
 | ||||
| class VXXJitCode : public JitCode { | ||||
|  public: | ||||
|   explicit VXXJitCode(int d, operand_type type, int scalar_index, | ||||
|                       bool with_relu, size_t code_size = 256 * 1024, | ||||
|                       void* code_ptr = nullptr) | ||||
|       : JitCode(code_size, code_ptr), | ||||
|         num_(d), | ||||
|         type_(type), | ||||
|         scalar_index_(scalar_index), | ||||
|         with_relu_(with_relu) { | ||||
|     if (!(type_ == operand_type::MUL || type_ == operand_type::ADD)) { | ||||
|       LOG(FATAL) << "Do not support this operand type: " << type_; | ||||
|     } | ||||
|     this->genCode(); | ||||
|   } | ||||
| 
 | ||||
|   virtual const char* name() const { | ||||
|     std::string base = "VXXJitCode"; | ||||
|     if (scalar_index_ == 1) { | ||||
|       base += "_Scalar"; | ||||
|     } else { | ||||
|       base += "_Vec"; | ||||
|     } | ||||
|     if (type_ == operand_type::MUL) { | ||||
|       base += "_Mul"; | ||||
|     } else if (type_ == operand_type::ADD) { | ||||
|       base += "_Add"; | ||||
|     } | ||||
|     if (scalar_index_ == 2) { | ||||
|       base += "_Scalar"; | ||||
|     } else { | ||||
|       base += "_Vec"; | ||||
|     } | ||||
|     base += (with_relu_ ? "_Relu" : ""); | ||||
|     return base.c_str(); | ||||
|   } | ||||
|   void genCode() override; | ||||
| 
 | ||||
|  private: | ||||
|   int num_; | ||||
|   operand_type type_; | ||||
|   int scalar_index_; | ||||
|   bool with_relu_; | ||||
|   reg64_t param1{abi_param1}; | ||||
|   reg64_t param2{abi_param2}; | ||||
|   reg64_t param3{abi_param3}; | ||||
| 
 | ||||
|   xmm_t xmm_src1 = xmm_t(0); | ||||
|   xmm_t xmm_src2 = xmm_t(1); | ||||
|   xmm_t xmm_dst = xmm_t(2); | ||||
|   xmm_t xmm_zero = xmm_t(3); | ||||
| 
 | ||||
|   ymm_t ymm_src1 = ymm_t(0); | ||||
|   ymm_t ymm_src2 = ymm_t(1); | ||||
|   ymm_t ymm_dst = ymm_t(2); | ||||
|   ymm_t ymm_zero = ymm_t(3); | ||||
| }; | ||||
| 
 | ||||
// Declares class <kernel>JitCode, a VXXJitCode with a fixed operand type,
// scalar position and relu flag, so only the size d, the code size and an
// optional code pointer remain as constructor arguments.
// Note: VXXJitCode's constructor aborts through LOG(FATAL) for operand types
// other than MUL and ADD, so a class declared with another type cannot be
// constructed.
#define DECLARE_BLAS_JITCODE(kernel, op, scalar_pos, relu)                     \
  class kernel##JitCode : public VXXJitCode {                                  \
   public:                                                                     \
    explicit kernel##JitCode(int d, size_t code_size,                          \
                             void* code_ptr = nullptr)                         \
        : VXXJitCode(d, op, scalar_pos, relu, code_size, code_ptr) {}          \
  };
| 
 | ||||
| DECLARE_BLAS_JITCODE(VMul, operand_type::MUL, 0, false); | ||||
| DECLARE_BLAS_JITCODE(VAdd, operand_type::ADD, 0, false); | ||||
| DECLARE_BLAS_JITCODE(VSub, operand_type::SUB, 0, false); | ||||
| DECLARE_BLAS_JITCODE(VAddRelu, operand_type::ADD, 0, true); | ||||
| DECLARE_BLAS_JITCODE(VScal, operand_type::MUL, 1, false); | ||||
| DECLARE_BLAS_JITCODE(VAddBias, operand_type::ADD, 1, false); | ||||
| 
 | ||||
| #undef DECLARE_BLAS_JITCODE | ||||
| 
 | ||||
| // nChw16c = nChw16c .* NC
 | ||||
| class NCHW16CMulNCJitCode : public JitCode { | ||||
|  public: | ||||
|   DECLARE_JIT_CODE(NCHW16CMulNCJitCode); | ||||
|   explicit NCHW16CMulNCJitCode(int d /*unused*/, size_t code_size, | ||||
|                                void* code_ptr = nullptr) | ||||
|       : JitCode(code_size, code_ptr) { | ||||
|     this->genCode(); | ||||
|   } | ||||
|   void genCode() override; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace gen
 | ||||
| }  // namespace jit
 | ||||
| }  // namespace operators
 | ||||
| }  // namespace paddle
 | ||||
| @ -0,0 +1,116 @@ | ||||
| /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  * http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. */ | ||||
| 
 | ||||
| #include "paddle/fluid/operators/jit/gen/gru.h" | ||||
| #include <stddef.h>  // offsetof | ||||
| #include "paddle/fluid/operators/jit/registry.h" | ||||
| #include "paddle/fluid/platform/cpu_info.h" | ||||
| 
 | ||||
| namespace paddle { | ||||
| namespace operators { | ||||
| namespace jit { | ||||
| namespace gen { | ||||
| 
 | ||||
| void GRUJitCode::genCode() { | ||||
|   reg64_t reg_ptr_gates = rax; | ||||
|   reg64_t reg_ptr_ht_1 = r9; | ||||
|   reg64_t reg_ptr_ht = r10; | ||||
|   mov(reg_ptr_gates, ptr[param1 + offsetof(gru_t, gates)]); | ||||
|   mov(reg_ptr_ht_1, ptr[param1 + offsetof(gru_t, ht_1)]); | ||||
|   mov(reg_ptr_ht, ptr[param1 + offsetof(gru_t, ht)]); | ||||
|   ymm_t ymm_one = ymm_t(0); | ||||
| 
 | ||||
|   if (id_ == 2) { | ||||
|     reg64_t reg_ptr_tmp = r11; | ||||
|     mov(reg_ptr_tmp, reinterpret_cast<size_t>(exp_float_consts)); | ||||
|     vmovaps(ymm_one, ptr[reg_ptr_tmp + OFFSET_EXP_ONE]); | ||||
|   } | ||||
|   int offset = 0; | ||||
|   int d = num_ * sizeof(float); | ||||
|   for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) { | ||||
|     ymm_t ymm_u = ymm_t(1); | ||||
|     ymm_t ymm_r = ymm_t(2); | ||||
|     ymm_t ymm_s = ymm_t(3); | ||||
|     ymm_t ymm_ht_1 = ymm_t(4); | ||||
|     // W: {W_update, W_reset; W_state}
 | ||||
|     if (id_ == 0 || id_ == 2) { | ||||
|       vmovups(ymm_u, ptr[reg_ptr_gates + offset]); | ||||
|       vmovups(ymm_s, ptr[reg_ptr_gates + offset + 2 * d]); | ||||
|     } | ||||
|     if (id_ == 1) { | ||||
|       vmovups(ymm_r, ptr[reg_ptr_gates + offset + d]); | ||||
|     } | ||||
|     if (id_ == 1 || id_ == 2) { | ||||
|       vmovups(ymm_ht_1, ptr[reg_ptr_ht_1 + offset]); | ||||
|     } | ||||
| 
 | ||||
|     if (id_ == 0) { | ||||
|       // ht = act_gate(u) * act_cand(s)
 | ||||
|       act<ymm_t>(ymm_u, ymm_u, act_gate_); | ||||
|       act<ymm_t>(ymm_s, ymm_s, act_cand_); | ||||
|       vmulps(ymm_s, ymm_s, ymm_u); | ||||
|       vmovups(ptr[reg_ptr_ht + offset], ymm_s); | ||||
|     } else if (id_ == 1) { | ||||
|       // ht = act_gate(r) * ht_1
 | ||||
|       act<ymm_t>(ymm_r, ymm_r, act_gate_); | ||||
|       vmulps(ymm_r, ymm_r, ymm_ht_1); | ||||
|       vmovups(ptr[reg_ptr_ht + offset], ymm_r); | ||||
|     } else if (id_ == 2) { | ||||
|       // ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1
 | ||||
|       ymm_t ymm_one_inner = ymm_t(ymm_one.getIdx()); | ||||
|       act<ymm_t>(ymm_u, ymm_u, act_gate_); | ||||
|       act<ymm_t>(ymm_s, ymm_s, act_cand_); | ||||
|       vmulps(ymm_s, ymm_s, ymm_u); | ||||
|       vsubps(ymm_u, ymm_one_inner, ymm_u); | ||||
|       vmulps(ymm_u, ymm_ht_1, ymm_u); | ||||
|       vaddps(ymm_u, ymm_s, ymm_u); | ||||
|       vmovups(ptr[reg_ptr_ht + offset], ymm_u); | ||||
|     } | ||||
|     offset += sizeof(float) * YMM_FLOAT_BLOCK; | ||||
|   } | ||||
|   ret(); | ||||
| } | ||||
| 
 | ||||
| #define DECLARE_GRU_CREATOR(name)                                 \ | ||||
|   class name##Creator : public JitCodeCreator<gru_attr_t> {       \ | ||||
|    public:                                                        \ | ||||
|     /* TODO(TJ): enable more */                                   \ | ||||
|     bool UseMe(const gru_attr_t& attr) const override {           \ | ||||
|       return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \ | ||||
|     }                                                             \ | ||||
|     size_t CodeSize(const gru_attr_t& attr) const override {      \ | ||||
|       return 96 + attr.d / YMM_FLOAT_BLOCK * 96 * 2 * 8;          \ | ||||
|     }                                                             \ | ||||
|     std::unique_ptr<GenBase> CreateJitCode(                       \ | ||||
|         const gru_attr_t& attr) const override {                  \ | ||||
|       return make_unique<name##JitCode>(attr, CodeSize(attr));    \ | ||||
|     }                                                             \ | ||||
|   } | ||||
| 
 | ||||
| DECLARE_GRU_CREATOR(GRUH1); | ||||
| DECLARE_GRU_CREATOR(GRUHtPart1); | ||||
| DECLARE_GRU_CREATOR(GRUHtPart2); | ||||
| 
 | ||||
| #undef DECLARE_GRU_CREATOR | ||||
| 
 | ||||
| }  // namespace gen
 | ||||
| }  // namespace jit
 | ||||
| }  // namespace operators
 | ||||
| }  // namespace paddle
 | ||||
| 
 | ||||
| namespace gen = paddle::operators::jit::gen; | ||||
| 
 | ||||
| REGISTER_JITKERNEL_GEN(kGRUH1, gen::GRUH1Creator); | ||||
| REGISTER_JITKERNEL_GEN(kGRUHtPart1, gen::GRUHtPart1Creator); | ||||
| REGISTER_JITKERNEL_GEN(kGRUHtPart2, gen::GRUHtPart2Creator); | ||||
| @ -0,0 +1,113 @@ | ||||
| /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  * http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. */ | ||||
| 
 | ||||
| #pragma once | ||||
| 
 | ||||
| #include <string> | ||||
| #include "glog/logging.h" | ||||
| #include "paddle/fluid/operators/jit/gen/act.h" | ||||
| #include "paddle/fluid/operators/jit/gen/jitcode.h" | ||||
| 
 | ||||
| namespace paddle { | ||||
| namespace operators { | ||||
| namespace jit { | ||||
| namespace gen { | ||||
| 
 | ||||
| class GRUJitCode : public VActFunc { | ||||
|  public: | ||||
|   explicit GRUJitCode(int id, const gru_attr_t& attr, size_t code_size, | ||||
|                       void* code_ptr = nullptr) | ||||
|       : VActFunc(code_size, code_ptr), id_(id), num_(attr.d) { | ||||
|     auto typeExchange = [](KernelType type) -> gen::operand_type { | ||||
|       if (type == KernelType::kVSigmoid) { | ||||
|         return operand_type::SIGMOID; | ||||
|       } else if (type == KernelType::kVRelu) { | ||||
|         return operand_type::RELU; | ||||
|       } else if (type == KernelType::kVTanh) { | ||||
|         return operand_type::TANH; | ||||
|       } else if (type == KernelType::kVIdentity) { | ||||
|         return operand_type::IDENTITY; | ||||
|       } else { | ||||
|         LOG(FATAL) << "Do not support this jit::KernelType: " << type; | ||||
|       } | ||||
|       return operand_type::IDENTITY; | ||||
|     }; | ||||
|     act_gate_ = typeExchange(attr.act_gate); | ||||
|     act_cand_ = typeExchange(attr.act_cand); | ||||
| 
 | ||||
|     this->genCode(); | ||||
|   } | ||||
| 
 | ||||
|   const char* name() const override { | ||||
|     std::string base = "GRUJitCode"; | ||||
|     if (id_ == 0) { | ||||
|       base += "_H1"; | ||||
|     } else if (id_ == 1) { | ||||
|       base += "_HtPart1"; | ||||
|     } else if (id_ == 2) { | ||||
|       base += "_HtPart2"; | ||||
|     } | ||||
|     auto AddTypeStr = [&](operand_type type) { | ||||
|       switch (type) { | ||||
|         case operand_type::RELU: | ||||
|           base += "_Relu"; | ||||
|           break; | ||||
|         case operand_type::EXP: | ||||
|           base += "_Exp"; | ||||
|           break; | ||||
|         case operand_type::SIGMOID: | ||||
|           base += "_Sigmoid"; | ||||
|           break; | ||||
|         case operand_type::TANH: | ||||
|           base += "_Tanh"; | ||||
|           break; | ||||
|         case operand_type::IDENTITY: | ||||
|           base += "_Identity"; | ||||
|           break; | ||||
|         default: | ||||
|           break; | ||||
|       } | ||||
|     }; | ||||
|     AddTypeStr(act_gate_); | ||||
|     AddTypeStr(act_cand_); | ||||
|     return base.c_str(); | ||||
|   } | ||||
|   void genCode() override; | ||||
| 
 | ||||
|  protected: | ||||
|   int id_; | ||||
|   int num_; | ||||
|   operand_type act_gate_; | ||||
|   operand_type act_cand_; | ||||
|   reg64_t param1{abi_param1}; | ||||
| }; | ||||
| 
 | ||||
| #define DECLARE_GRU_JITCODE(name, id)                                \ | ||||
|   class name##JitCode : public GRUJitCode {                          \ | ||||
|    public:                                                           \ | ||||
|     explicit name##JitCode(const gru_attr_t& attr, size_t code_size, \ | ||||
|                            void* code_ptr = nullptr)                 \ | ||||
|         : GRUJitCode(id, attr, code_size, code_ptr) {}               \ | ||||
|   }; | ||||
| 
 | ||||
| DECLARE_GRU_JITCODE(GRUH1, 0); | ||||
| DECLARE_GRU_JITCODE(GRUHtPart1, 1); | ||||
| DECLARE_GRU_JITCODE(GRUHtPart2, 2); | ||||
| 
 | ||||
| #undef DECLARE_GRU_JITCODE | ||||
| 
 | ||||
| }  // namespace gen
 | ||||
| }  // namespace jit
 | ||||
| }  // namespace operators
 | ||||
| }  // namespace paddle
 | ||||
| @ -0,0 +1,126 @@ | ||||
| /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  * http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. */ | ||||
| 
 | ||||
| #pragma once | ||||
| 
 | ||||
| #include <type_traits> | ||||
| #include "paddle/fluid/operators/jit/gen_base.h" | ||||
| #include "paddle/fluid/platform/cpu_info.h" | ||||
| 
 | ||||
| #define XBYAK_USE_MMAP_ALLOCATOR | ||||
| #include "xbyak/xbyak.h" | ||||
| #include "xbyak/xbyak_util.h" | ||||
| 
 | ||||
| namespace paddle { | ||||
| namespace operators { | ||||
| namespace jit { | ||||
| namespace gen { | ||||
| 
 | ||||
| // Application Binary Interface
 | ||||
| constexpr Xbyak::Operand::Code abi_param1(Xbyak::Operand::RDI), | ||||
|     abi_param2(Xbyak::Operand::RSI), abi_param3(Xbyak::Operand::RDX), | ||||
|     abi_param4(Xbyak::Operand::RCX); | ||||
| 
 | ||||
| constexpr Xbyak::Operand::Code g_abi_regs[] = { | ||||
|     Xbyak::Operand::RBX, Xbyak::Operand::RBP, Xbyak::Operand::R12, | ||||
|     Xbyak::Operand::R13, Xbyak::Operand::R14, Xbyak::Operand::R15}; | ||||
| 
 | ||||
| constexpr int num_g_abi_regs = sizeof(g_abi_regs) / sizeof(g_abi_regs[0]); | ||||
| 
 | ||||
| using reg64_t = const Xbyak::Reg64; | ||||
| using reg32_t = const Xbyak::Reg32; | ||||
| using xmm_t = const Xbyak::Xmm; | ||||
| using ymm_t = const Xbyak::Ymm; | ||||
| using zmm_t = const Xbyak::Zmm; | ||||
| using Label = Xbyak::Label; | ||||
| 
 | ||||
| typedef enum { | ||||
|   MUL = 0, | ||||
|   ADD, | ||||
|   SUB, | ||||
|   RELU, | ||||
|   EXP, | ||||
|   SIGMOID, | ||||
|   TANH, | ||||
|   IDENTITY | ||||
| } operand_type; | ||||
| 
 | ||||
| #define DECLARE_JIT_CODE(codename) \ | ||||
|   const char* name() const override { return #codename; } | ||||
| 
 | ||||
| class JitCode : public GenBase, public Xbyak::CodeGenerator { | ||||
|  public: | ||||
|   explicit JitCode(size_t code_size, void* code_ptr = nullptr) | ||||
|       : Xbyak::CodeGenerator( | ||||
|             (code_size % 4096 != 0 ? (code_size / 4096 + 1) * 4096 : code_size), | ||||
|             code_ptr) {} | ||||
| 
 | ||||
|   virtual const char* name() const = 0; | ||||
|   virtual void genCode() = 0; | ||||
| 
 | ||||
|   size_t getSize() const override { return CodeGenerator::getSize(); } | ||||
|   const unsigned char* getCodeInternal() override { | ||||
|     const Xbyak::uint8* code = CodeGenerator::getCode(); | ||||
|     return code; | ||||
|   } | ||||
| 
 | ||||
|  protected: | ||||
|   Xbyak::Reg64 param1{abi_param1}; | ||||
|   const int EVEX_max_8b_offt = 0x200; | ||||
|   const Xbyak::Reg64 reg_EVEX_max_8b_offt = rbp; | ||||
| 
 | ||||
|   virtual void preCode() { | ||||
|     for (int i = 0; i < num_g_abi_regs; ++i) { | ||||
|       push(Xbyak::Reg64(g_abi_regs[i])); | ||||
|     } | ||||
|     if (platform::MayIUse(platform::avx512f)) { | ||||
|       mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt); | ||||
|     } | ||||
|   } | ||||
|   virtual void postCode() { | ||||
|     for (int i = 0; i < num_g_abi_regs; ++i) { | ||||
|       pop(Xbyak::Reg64(g_abi_regs[num_g_abi_regs - 1 - i])); | ||||
|     } | ||||
|     ret(); | ||||
|   } | ||||
|   void L(const char* label) { Xbyak::CodeGenerator::L(label); } | ||||
|   void L(const Xbyak::Label& label) { Xbyak::CodeGenerator::L(label); } | ||||
|   // Enhanced vector extension
 | ||||
|   Xbyak::Address EVEX_compress_addr(Xbyak::Reg64 base, int offt, | ||||
|                                     bool bcast = false) { | ||||
|     int scale = 0; | ||||
|     // Learn from https://github.com/intel/mkl-dnn
 | ||||
|     if (EVEX_max_8b_offt <= offt && offt < 3 * EVEX_max_8b_offt) { | ||||
|       offt = offt - 2 * EVEX_max_8b_offt; | ||||
|       scale = 1; | ||||
|     } else if (3 * EVEX_max_8b_offt <= offt && offt < 5 * EVEX_max_8b_offt) { | ||||
|       offt = offt - 4 * EVEX_max_8b_offt; | ||||
|       scale = 2; | ||||
|     } | ||||
|     auto re = Xbyak::RegExp() + base + offt; | ||||
|     if (scale) { | ||||
|       re = re + reg_EVEX_max_8b_offt * scale; | ||||
|     } | ||||
|     if (bcast) { | ||||
|       return zword_b[re]; | ||||
|     } else { | ||||
|       return zword[re]; | ||||
|     } | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| }  // namespace gen
 | ||||
| }  // namespace jit
 | ||||
| }  // namespace operators
 | ||||
| }  // namespace paddle
 | ||||
| @ -0,0 +1,142 @@ | ||||
| /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  * http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. */ | ||||
| 
 | ||||
| #include "paddle/fluid/operators/jit/gen/lstm.h" | ||||
| #include <stddef.h>  // offsetof | ||||
| #include "paddle/fluid/operators/jit/registry.h" | ||||
| #include "paddle/fluid/platform/cpu_info.h" | ||||
| 
 | ||||
| namespace paddle { | ||||
| namespace operators { | ||||
| namespace jit { | ||||
| namespace gen { | ||||
| 
 | ||||
| void LSTMJitCode::genCode() { | ||||
|   if (use_peephole_) { | ||||
|     preCode(); | ||||
|   } | ||||
|   reg64_t reg_ptr_gates = rax; | ||||
|   reg64_t reg_ptr_ct_1 = r9; | ||||
|   reg64_t reg_ptr_ct = r10; | ||||
|   reg64_t reg_ptr_ht = r11; | ||||
|   reg64_t reg_ptr_wp = r12; | ||||
|   mov(reg_ptr_gates, ptr[param1 + offsetof(lstm_t, gates)]); | ||||
|   mov(reg_ptr_ct_1, ptr[param1 + offsetof(lstm_t, ct_1)]); | ||||
|   mov(reg_ptr_ct, ptr[param1 + offsetof(lstm_t, ct)]); | ||||
|   mov(reg_ptr_ht, ptr[param1 + offsetof(lstm_t, ht)]); | ||||
|   if (use_peephole_) { | ||||
|     mov(reg_ptr_wp, ptr[param1 + offsetof(lstm_t, wp)]); | ||||
|   } | ||||
| 
 | ||||
|   int offset = 0; | ||||
|   int d = num_ * sizeof(float); | ||||
|   for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) { | ||||
|     /* gates: W_ch, W_ih, W_fh, W_oh */ | ||||
|     ymm_t ymm_c = ymm_t(0); | ||||
|     ymm_t ymm_i = ymm_t(1); | ||||
|     ymm_t ymm_f = ymm_t(2); | ||||
|     ymm_t ymm_o = ymm_t(3); | ||||
|     ymm_t ymm_ct_1 = ymm_t(4); | ||||
|     ymm_t ymm_wp0 = ymm_t(5); | ||||
|     ymm_t ymm_wp1 = ymm_t(6); | ||||
|     ymm_t ymm_wp2 = ymm_t(7); | ||||
|     vmovups(ymm_c, ptr[reg_ptr_gates + offset]); | ||||
|     vmovups(ymm_i, ptr[reg_ptr_gates + offset + d]); | ||||
|     vmovups(ymm_f, ptr[reg_ptr_gates + offset + 2 * d]); | ||||
|     vmovups(ymm_o, ptr[reg_ptr_gates + offset + 3 * d]); | ||||
|     if (!compute_c1h1_) { | ||||
|       vmovups(ymm_ct_1, ptr[reg_ptr_ct_1 + offset]); | ||||
|     } | ||||
|     if (use_peephole_) { | ||||
|       vmovups(ymm_wp0, ptr[reg_ptr_wp + offset]); | ||||
|       vmovups(ymm_wp1, ptr[reg_ptr_wp + offset + d]); | ||||
|       vmovups(ymm_wp2, ptr[reg_ptr_wp + offset + 2 * d]); | ||||
|     } | ||||
|     /* C_t = act_cand(c) * act_gate(i) + C_t-1 * act_gate(f) */ | ||||
|     // act_cand(c)
 | ||||
|     act<ymm_t>(ymm_c, ymm_c, act_cand_); | ||||
|     // act_gate(i) or act_gate(ct_1 * wp0 + i)
 | ||||
|     if (!compute_c1h1_ && use_peephole_) { | ||||
|       vmulps(ymm_wp0, ymm_ct_1, ymm_wp0); | ||||
|       vaddps(ymm_i, ymm_i, ymm_wp0); | ||||
|     } | ||||
|     act<ymm_t>(ymm_i, ymm_i, act_gate_); | ||||
|     vmulps(ymm_c, ymm_c, ymm_i); | ||||
|     if (!compute_c1h1_) { | ||||
|       // act_gate(f) or act_gate(ct_1 * wp1 + f)
 | ||||
|       if (use_peephole_) { | ||||
|         vmulps(ymm_wp1, ymm_ct_1, ymm_wp1); | ||||
|         vaddps(ymm_f, ymm_f, ymm_wp1); | ||||
|       } | ||||
|       act<ymm_t>(ymm_f, ymm_f, act_gate_); | ||||
|       // ct
 | ||||
|       vmulps(ymm_f, ymm_f, ymm_ct_1); | ||||
|       vaddps(ymm_f, ymm_f, ymm_c); | ||||
|     } | ||||
|     /* H_t = act_cell(C_t) * act_gate(o) */ | ||||
|     // act_cell(C_t)
 | ||||
|     ymm_t ymm_ct = compute_c1h1_ ? ymm_c : ymm_f; | ||||
|     ymm_t ymm_tmp = ymm_i; | ||||
|     act<ymm_t>(ymm_tmp, ymm_ct, act_cell_); | ||||
|     // act_gate(o) or act_gate(ct * wp2 + o)
 | ||||
|     if (use_peephole_) { | ||||
|       vmulps(ymm_wp2, ymm_ct, ymm_wp2); | ||||
|       vaddps(ymm_o, ymm_o, ymm_wp2); | ||||
|     } | ||||
|     act<ymm_t>(ymm_o, ymm_o, act_gate_); | ||||
|     // ht
 | ||||
|     vmulps(ymm_o, ymm_o, ymm_tmp); | ||||
|     // save ct and ht
 | ||||
|     vmovups(ptr[reg_ptr_ct + offset], ymm_ct); | ||||
|     vmovups(ptr[reg_ptr_ht + offset], ymm_o); | ||||
|     offset += sizeof(float) * YMM_FLOAT_BLOCK; | ||||
|   } | ||||
| 
 | ||||
|   if (use_peephole_) { | ||||
|     postCode(); | ||||
|   } else { | ||||
|     ret(); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| #define DECLARE_LSTM_CREATOR(name)                                \ | ||||
|   class name##Creator : public JitCodeCreator<lstm_attr_t> {      \ | ||||
|    public:                                                        \ | ||||
|     /* TODO(TJ): enable more */                                   \ | ||||
|     bool UseMe(const lstm_attr_t& attr) const override {          \ | ||||
|       return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \ | ||||
|     }                                                             \ | ||||
|     size_t CodeSize(const lstm_attr_t& attr) const override {     \ | ||||
|       return 96 + attr.d / YMM_FLOAT_BLOCK * 90 * 4 * 8;          \ | ||||
|     }                                                             \ | ||||
|     std::unique_ptr<GenBase> CreateJitCode(                       \ | ||||
|         const lstm_attr_t& attr) const override {                 \ | ||||
|       return make_unique<name##JitCode>(attr, CodeSize(attr));    \ | ||||
|     }                                                             \ | ||||
|   } | ||||
| 
 | ||||
| DECLARE_LSTM_CREATOR(LSTMCtHt); | ||||
| DECLARE_LSTM_CREATOR(LSTMC1H1); | ||||
| 
 | ||||
| #undef DECLARE_LSTM_CREATOR | ||||
| 
 | ||||
| }  // namespace gen
 | ||||
| }  // namespace jit
 | ||||
| }  // namespace operators
 | ||||
| }  // namespace paddle
 | ||||
| 
 | ||||
| namespace gen = paddle::operators::jit::gen; | ||||
| 
 | ||||
| REGISTER_JITKERNEL_GEN(kLSTMCtHt, gen::LSTMCtHtCreator); | ||||
| REGISTER_JITKERNEL_GEN(kLSTMC1H1, gen::LSTMC1H1Creator); | ||||
| @ -0,0 +1,118 @@ | ||||
| /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  * http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. */ | ||||
| 
 | ||||
| #pragma once | ||||
| 
 | ||||
| #include <string> | ||||
| #include "glog/logging.h" | ||||
| #include "paddle/fluid/operators/jit/gen/act.h" | ||||
| #include "paddle/fluid/operators/jit/gen/jitcode.h" | ||||
| 
 | ||||
| namespace paddle { | ||||
| namespace operators { | ||||
| namespace jit { | ||||
| namespace gen { | ||||
| 
 | ||||
| class LSTMJitCode : public VActFunc { | ||||
|  public: | ||||
|   explicit LSTMJitCode(bool compute_c1h1, const lstm_attr_t& attr, | ||||
|                        size_t code_size, void* code_ptr = nullptr) | ||||
|       : VActFunc(code_size, code_ptr), | ||||
|         num_(attr.d), | ||||
|         compute_c1h1_(compute_c1h1), | ||||
|         use_peephole_(attr.use_peephole) { | ||||
|     auto typeExchange = [](KernelType type) -> gen::operand_type { | ||||
|       if (type == KernelType::kVSigmoid) { | ||||
|         return operand_type::SIGMOID; | ||||
|       } else if (type == KernelType::kVRelu) { | ||||
|         return operand_type::RELU; | ||||
|       } else if (type == KernelType::kVTanh) { | ||||
|         return operand_type::TANH; | ||||
|       } else if (type == KernelType::kVIdentity) { | ||||
|         return operand_type::IDENTITY; | ||||
|       } else { | ||||
|         LOG(FATAL) << "Do not support this jit::KernelType: " << type; | ||||
|       } | ||||
|       return operand_type::IDENTITY; | ||||
|     }; | ||||
|     act_gate_ = typeExchange(attr.act_gate); | ||||
|     act_cand_ = typeExchange(attr.act_cand); | ||||
|     act_cell_ = typeExchange(attr.act_cell); | ||||
| 
 | ||||
|     this->genCode(); | ||||
|   } | ||||
| 
 | ||||
|   const char* name() const override { | ||||
|     std::string base = "LSTMJitCode"; | ||||
|     if (use_peephole_) { | ||||
|       base += "_Peephole"; | ||||
|     } | ||||
|     if (compute_c1h1_) { | ||||
|       base += "_C1H1"; | ||||
|     } | ||||
|     auto AddTypeStr = [&](operand_type type) { | ||||
|       switch (type) { | ||||
|         case operand_type::RELU: | ||||
|           base += "_Relu"; | ||||
|           break; | ||||
|         case operand_type::EXP: | ||||
|           base += "_Exp"; | ||||
|           break; | ||||
|         case operand_type::SIGMOID: | ||||
|           base += "_Sigmoid"; | ||||
|           break; | ||||
|         case operand_type::TANH: | ||||
|           base += "_Tanh"; | ||||
|           break; | ||||
|         case operand_type::IDENTITY: | ||||
|           base += "_Identity"; | ||||
|           break; | ||||
|         default: | ||||
|           break; | ||||
|       } | ||||
|     }; | ||||
|     AddTypeStr(act_gate_); | ||||
|     AddTypeStr(act_cand_); | ||||
|     AddTypeStr(act_cell_); | ||||
|     return base.c_str(); | ||||
|   } | ||||
|   void genCode() override; | ||||
| 
 | ||||
|  protected: | ||||
|   int num_; | ||||
|   bool compute_c1h1_; | ||||
|   bool use_peephole_; | ||||
|   operand_type act_gate_; | ||||
|   operand_type act_cand_; | ||||
|   operand_type act_cell_; | ||||
|   reg64_t param1{abi_param1}; | ||||
| }; | ||||
| 
 | ||||
| #define DECLARE_LSTM_JITCODE(name, compute_c1h1)                      \ | ||||
|   class name##JitCode : public LSTMJitCode {                          \ | ||||
|    public:                                                            \ | ||||
|     explicit name##JitCode(const lstm_attr_t& attr, size_t code_size, \ | ||||
|                            void* code_ptr = nullptr)                  \ | ||||
|         : LSTMJitCode(compute_c1h1, attr, code_size, code_ptr) {}     \ | ||||
|   }; | ||||
| 
 | ||||
| DECLARE_LSTM_JITCODE(LSTMCtHt, false); | ||||
| DECLARE_LSTM_JITCODE(LSTMC1H1, true); | ||||
| 
 | ||||
| #undef DECLARE_LSTM_JITCODE | ||||
| 
 | ||||
| }  // namespace gen
 | ||||
| }  // namespace jit
 | ||||
| }  // namespace operators
 | ||||
| }  // namespace paddle
 | ||||
| @ -0,0 +1,43 @@ | ||||
| /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  * http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. */ | ||||
| 
 | ||||
| #include "paddle/fluid/operators/jit/gen_base.h" | ||||
| #include <fstream> | ||||
| #include <iostream> | ||||
| #include <sstream> | ||||
| 
 | ||||
| DEFINE_bool(dump_jitcode, false, "Whether to dump the jitcode to file"); | ||||
| 
 | ||||
| namespace paddle { | ||||
| namespace operators { | ||||
| namespace jit { | ||||
| 
 | ||||
| // refer do not need useme, it would be the last one.
 | ||||
| void GenBase::dumpCode(const unsigned char* code) const { | ||||
|   if (code) { | ||||
|     static int counter = 0; | ||||
|     std::ostringstream filename; | ||||
|     filename << "paddle_jitcode_" << name() << "." << counter << ".bin"; | ||||
|     counter++; | ||||
|     std::ofstream fout(filename.str(), std::ios::out); | ||||
|     if (fout.is_open()) { | ||||
|       fout.write(reinterpret_cast<const char*>(code), this->getSize()); | ||||
|       fout.close(); | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| }  // namespace jit
 | ||||
| }  // namespace operators
 | ||||
| }  // namespace paddle
 | ||||
| @ -0,0 +1,70 @@ | ||||
| /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  * http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. */ | ||||
| 
 | ||||
| #pragma once | ||||
| 
 | ||||
| #include <gflags/gflags.h> | ||||
| #include <memory>  // for unique_ptr | ||||
| #include "paddle/fluid/operators/jit/kernel_base.h" | ||||
| 
 | ||||
| DECLARE_bool(dump_jitcode); | ||||
| 
 | ||||
| namespace paddle { | ||||
| namespace operators { | ||||
| namespace jit { | ||||
| 
 | ||||
| class GenBase : public Kernel { | ||||
|  public: | ||||
|   virtual ~GenBase() = default; | ||||
|   virtual const char* name() const = 0; | ||||
|   virtual size_t getSize() const = 0; | ||||
|   virtual const unsigned char* getCodeInternal() = 0; | ||||
|   template <typename Func> | ||||
|   Func getCode() { | ||||
|     const unsigned char* code = this->getCodeInternal(); | ||||
|     if (FLAGS_dump_jitcode) { | ||||
|       this->dumpCode(code); | ||||
|     } | ||||
|     return reinterpret_cast<Func>(const_cast<unsigned char*>(code)); | ||||
|   } | ||||
| 
 | ||||
|  protected: | ||||
|   void dumpCode(const unsigned char* code) const; | ||||
| }; | ||||
| 
 | ||||
| // Creator is used to creat the jitcode and save in pool.
 | ||||
| // Every JitCode should have one creator.
 | ||||
| class GenCreator { | ||||
|  public: | ||||
|   virtual ~GenCreator() = default; | ||||
| }; | ||||
| 
 | ||||
| template <typename Attr> | ||||
| class JitCodeCreator : public GenCreator { | ||||
|  public: | ||||
|   virtual ~JitCodeCreator() = default; | ||||
| 
 | ||||
|   // condition when this jit code can be used.
 | ||||
|   virtual bool UseMe(const Attr& attr) const = 0; | ||||
| 
 | ||||
|   // estimate this code size
 | ||||
|   virtual size_t CodeSize(const Attr& attr) const = 0; | ||||
| 
 | ||||
|   // create this code
 | ||||
|   virtual std::unique_ptr<GenBase> CreateJitCode(const Attr& attr) const = 0; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace jit
 | ||||
| }  // namespace operators
 | ||||
| }  // namespace paddle
 | ||||
| @ -0,0 +1,76 @@ | ||||
| /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  * http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. */ | ||||
| 
 | ||||
| #include "paddle/fluid/operators/jit/helper.h" | ||||
| #include <algorithm>  // tolower | ||||
| #include "paddle/fluid/platform/enforce.h" | ||||
| 
 | ||||
| namespace paddle { | ||||
| namespace operators { | ||||
| namespace jit { | ||||
| 
 | ||||
| #define ONE_CASE(key) \ | ||||
|   case key:           \ | ||||
|     return #key | ||||
| 
 | ||||
| const char* to_string(KernelType kt) { | ||||
|   switch (kt) { | ||||
|     ONE_CASE(kVMul); | ||||
|     ONE_CASE(kVAdd); | ||||
|     ONE_CASE(kVAddRelu); | ||||
|     ONE_CASE(kVSub); | ||||
|     ONE_CASE(kVScal); | ||||
|     ONE_CASE(kVAddBias); | ||||
|     ONE_CASE(kVRelu); | ||||
|     ONE_CASE(kVIdentity); | ||||
|     ONE_CASE(kVExp); | ||||
|     ONE_CASE(kVSigmoid); | ||||
|     ONE_CASE(kVTanh); | ||||
|     ONE_CASE(kLSTMCtHt); | ||||
|     ONE_CASE(kLSTMC1H1); | ||||
|     ONE_CASE(kGRUH1); | ||||
|     ONE_CASE(kGRUHtPart1); | ||||
|     ONE_CASE(kGRUHtPart2); | ||||
|     ONE_CASE(kCRFDecoding); | ||||
|     ONE_CASE(kLayerNorm); | ||||
|     ONE_CASE(kNCHW16CMulNC); | ||||
|     default: | ||||
|       PADDLE_THROW("Not support type: %d, or forget to add it.", kt); | ||||
|       return "NOT JITKernel"; | ||||
|   } | ||||
|   return nullptr; | ||||
| } | ||||
| #undef ONE_CASE | ||||
| 
 | ||||
| KernelType to_kerneltype(const std::string& act) { | ||||
|   std::string lower = act; | ||||
|   std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower); | ||||
|   if (lower == "relu" || lower == "vrelu") { | ||||
|     return kVRelu; | ||||
|   } else if (lower == "identity" || lower == "videntity" || lower == "") { | ||||
|     return kVIdentity; | ||||
|   } else if (lower == "exp" || lower == "vexp") { | ||||
|     return kVExp; | ||||
|   } else if (lower == "sigmoid" || lower == "vsigmoid") { | ||||
|     return kVSigmoid; | ||||
|   } else if (lower == "tanh" || lower == "vtanh") { | ||||
|     return kVTanh; | ||||
|   } | ||||
|   PADDLE_THROW("Not support type: %s, or forget to add this case", act); | ||||
|   return kNone; | ||||
| } | ||||
| 
 | ||||
| }  // namespace jit
 | ||||
| }  // namespace operators
 | ||||
| }  // namespace paddle
 | ||||
| @ -0,0 +1,140 @@ | ||||
| /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  * http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. */ | ||||
| 
 | ||||
| #pragma once | ||||
| 
 | ||||
| #include <iostream> | ||||
| #include <string> | ||||
| #include <vector> | ||||
| #include "paddle/fluid/operators/jit/gen_base.h" | ||||
| #include "paddle/fluid/operators/jit/kernel_base.h" | ||||
| #include "paddle/fluid/operators/jit/kernel_key.h" | ||||
| #include "paddle/fluid/operators/jit/kernel_pool.h" | ||||
| #include "paddle/fluid/platform/place.h" | ||||
| 
 | ||||
| namespace paddle { | ||||
| namespace operators { | ||||
| namespace jit { | ||||
| 
 | ||||
| template <KernelType KT, typename KernelTuples, typename PlaceType> | ||||
| inline typename std::enable_if< | ||||
|     std::is_same<typename KernelTuples::data_type, float>::value && | ||||
|         std::is_same<PlaceType, platform::CPUPlace>::value, | ||||
|     typename KernelTuples::func_type>::type | ||||
| GetJitCode(const typename KernelTuples::attr_type& attr) { | ||||
|   using Func = typename KernelTuples::func_type; | ||||
|   using Attr = typename KernelTuples::attr_type; | ||||
|   size_t key = JitCodeKey<Attr>(attr); | ||||
|   auto& codes = JitCodePool<KT>().Instance(); | ||||
|   if (codes.Has(key)) { | ||||
|     return codes.AllKernels().at(key)->template getCode<Func>(); | ||||
|   } | ||||
| 
 | ||||
|   // creator is not related with attr, so can use KernelKey as key
 | ||||
|   KernelKey kkey(KT, PlaceType()); | ||||
|   // pool: (KernelKey(type, place), vector<GenCreatorPtr>)
 | ||||
|   auto& creator_map = JitCodeCreatorPool().Instance().AllCreators(); | ||||
|   auto iter = creator_map.find(kkey); | ||||
|   if (iter != creator_map.end()) { | ||||
|     auto& creators = iter->second; | ||||
|     for (auto& cur : creators) { | ||||
|       auto i = dynamic_cast<const JitCodeCreator<Attr>*>(cur.get()); | ||||
|       if (i && i->UseMe(attr)) { | ||||
|         auto p = i->CreateJitCode(attr); | ||||
|         if (p) { | ||||
|           auto f = p->template getCode<Func>(); | ||||
|           codes.Insert(key, std::move(p)); | ||||
|           return f; | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|   return nullptr; | ||||
| } | ||||
| 
 | ||||
| template <KernelType KT, typename KernelTuples, typename PlaceType> | ||||
| inline typename std::enable_if< | ||||
|     !std::is_same<typename KernelTuples::data_type, float>::value || | ||||
|         !std::is_same<PlaceType, platform::CPUPlace>::value, | ||||
|     typename KernelTuples::func_type>::type | ||||
| GetJitCode(const typename KernelTuples::attr_type& attr) { | ||||
|   return nullptr; | ||||
| } | ||||
| 
 | ||||
| // Refer code do not related with attr, which is just for cast
 | ||||
| // Refer is always on CPUPlace
 | ||||
| template <KernelType KT, typename KernelTuples> | ||||
| inline typename KernelTuples::func_type GetRefer() { | ||||
|   auto& ref_pool = ReferKernelPool().Instance().AllKernels(); | ||||
|   KernelKey kkey(KT, platform::CPUPlace()); | ||||
|   auto ref_iter = ref_pool.find(kkey); | ||||
|   PADDLE_ENFORCE(ref_iter != ref_pool.end(), | ||||
|                  "Every Kernel should have reference function."); | ||||
|   auto& ref_impls = ref_iter->second; | ||||
|   for (auto& impl : ref_impls) { | ||||
|     auto i = dynamic_cast<const ReferKernel<KernelTuples>*>(impl.get()); | ||||
|     if (i) { | ||||
|       return i->GetFunc(); | ||||
|     } | ||||
|   } | ||||
|   return nullptr; | ||||
| } | ||||
| 
 | ||||
| template <KernelType KT, typename KernelTuples, | ||||
|           typename PlaceType = platform::CPUPlace> | ||||
| typename KernelTuples::func_type Get( | ||||
|     const typename KernelTuples::attr_type& attr) { | ||||
|   auto jitfunc = GetJitCode<KT, KernelTuples, PlaceType>(attr); | ||||
|   if (jitfunc) { | ||||
|     return jitfunc; | ||||
|   } | ||||
| 
 | ||||
|   // pool: (KernelKey(type, place), vector<KernelPtr>)
 | ||||
|   KernelKey kkey(KT, PlaceType()); | ||||
|   auto& pool = KernelPool().Instance().AllKernels(); | ||||
|   auto iter = pool.find(kkey); | ||||
|   if (iter != pool.end()) { | ||||
|     auto& impls = iter->second; | ||||
|     for (auto& impl : impls) { | ||||
|       auto i = dynamic_cast<const KernelMore<KernelTuples>*>(impl.get()); | ||||
|       if (i && i->UseMe(attr)) { | ||||
|         return i->GetFunc(); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   // The last implementation should be reference function on CPUPlace.
 | ||||
|   return GetRefer<KT, KernelTuples>(); | ||||
| } | ||||
| 
 | ||||
| const char* to_string(KernelType kt); | ||||
| 
 | ||||
| KernelType to_kerneltype(const std::string& act); | ||||
| 
 | ||||
| inline std::ostream& operator<<(std::ostream& os, const lstm_attr_t& attr) { | ||||
|   os << "dim_size[" << attr.d << "],act_gate[" << to_string(attr.act_gate) | ||||
|      << "],act_cand[" << to_string(attr.act_cand) << "],act_cell[" | ||||
|      << to_string(attr.act_cell) << "],use_peephole[" | ||||
|      << (attr.use_peephole ? "True" : "False") << "]"; | ||||
|   return os; | ||||
| } | ||||
| inline std::ostream& operator<<(std::ostream& os, const gru_attr_t& attr) { | ||||
|   os << "dim_size[" << attr.d << "],act_gate[" << to_string(attr.act_gate) | ||||
|      << "],act_cand[" << to_string(attr.act_cand) << "]"; | ||||
|   return os; | ||||
| } | ||||
| 
 | ||||
| }  // namespace jit
 | ||||
| }  // namespace operators
 | ||||
| }  // namespace paddle
 | ||||
| @ -0,0 +1,172 @@ | ||||
| /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  * http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. */ | ||||
| 
 | ||||
| #pragma once | ||||
| #include "paddle/fluid/operators/jit/macro.h" | ||||
| #include "paddle/fluid/platform/macros.h" | ||||
| 
 | ||||
| namespace paddle { | ||||
| namespace operators { | ||||
| namespace jit { | ||||
| 
 | ||||
// Identifier of every kernel known to the jit module. The numeric values are
// spelled out because they are used directly when hashing KernelKey and when
// packing activation types into JitCode keys.
typedef enum {
  kNone = 0,
  kVMul = 1,
  kVAdd = 2,
  kVAddRelu = 3,
  kVSub = 4,
  kVScal = 5,
  kVAddBias = 6,
  kVRelu = 7,
  kVIdentity = 8,
  kVExp = 9,
  kVSigmoid = 10,
  kVTanh = 11,
  kLSTMCtHt = 12,
  kLSTMC1H1 = 13,
  kGRUH1 = 14,
  kGRUHtPart1 = 15,
  kGRUHtPart2 = 16,
  kCRFDecoding = 17,
  kLayerNorm = 18,
  kNCHW16CMulNC = 19,
} KernelType;
| 
 | ||||
// Type bundle for kernels whose function takes two const T* inputs, one T*
// output and an int; the attribute is an int.
template <typename T>
struct XYZNTuples {
  using data_type = T;
  using attr_type = int;
  using func_type = void (*)(const T*, const T*, T*, int);
};
| 
 | ||||
| template <typename T> | ||||
| struct AXYNTuples : public XYZNTuples<T> {}; | ||||
| 
 | ||||
// Type bundle for kernels whose function takes one const T* input, one T*
// output and an int; the attribute is an int.
template <typename T>
struct XYNTuples {
  using data_type = T;
  using attr_type = int;
  using func_type = void (*)(const T*, T*, int);
};
| 
 | ||||
// Untyped buffer pointers handed to an LSTM kernel function.
//
// Declared as a named struct: an unnamed struct that gets its name from a
// typedef may not contain default member initializers since C++20. Every
// pointer starts out as nullptr so a default-constructed lstm_t never carries
// indeterminate values.
struct lstm_t {
  void* gates{nullptr};  // gates: x_ch, x_ih, x_fh, x_oh
  const void* ct_1{nullptr};
  void* ct{nullptr};
  void* ht{nullptr};
  /* weight_peephole and checked data are only used in peephole */
  const void* wp{nullptr};  // W_ic, W_fc, W_oc
  void* checked{nullptr};   // size: 2 * d
};
| 
 | ||||
// Untyped buffer pointers handed to a GRU kernel function.
struct gru_t {
  void* gates;  // gates: {x_update, x_reset; x_state}
  const void* ht_1;
  void* ht;
};
| 
 | ||||
| struct rnn_attr_s { | ||||
|   int d; | ||||
|   KernelType act_gate, act_cand; | ||||
|   rnn_attr_s() = default; | ||||
|   explicit rnn_attr_s(int _d, KernelType _act_gate, KernelType _act_cand) | ||||
|       : d(_d), act_gate(_act_gate), act_cand(_act_cand) {} | ||||
| }; | ||||
| 
 | ||||
| struct lstm_attr_s : public rnn_attr_s { | ||||
|   bool use_peephole; | ||||
|   KernelType act_cell; | ||||
|   lstm_attr_s() = default; | ||||
|   explicit lstm_attr_s(int _d, KernelType _act_gate, KernelType _act_cand, | ||||
|                        KernelType _act_cell, bool _use_peephole = false) | ||||
|       : rnn_attr_s(_d, _act_gate, _act_cand), | ||||
|         use_peephole(_use_peephole), | ||||
|         act_cell(_act_cell) {} | ||||
| }; | ||||
| 
 | ||||
| typedef struct rnn_attr_s gru_attr_t; | ||||
| typedef struct lstm_attr_s lstm_attr_t; | ||||
| 
 | ||||
| template <typename T> | ||||
| struct LSTMTuples { | ||||
|   typedef T data_type; | ||||
|   typedef lstm_attr_t attr_type; | ||||
|   typedef void (*func_type)(lstm_t*, const lstm_attr_t*); | ||||
| }; | ||||
| 
 | ||||
| template <typename T> | ||||
| struct GRUTuples { | ||||
|   typedef T data_type; | ||||
|   typedef gru_attr_t attr_type; | ||||
|   typedef void (*func_type)(gru_t*, const gru_attr_t*); | ||||
| }; | ||||
| 
 | ||||
// Type bundle for the CRF decoding kernel; the attribute is an int.
template <typename T>
struct CRFDecodingTuples {
  using data_type = T;
  using attr_type = int;
  using func_type = void (*)(const int, const T*, const T*, T*, int*, int);
};
| 
 | ||||
// Type bundle for the layer-norm kernel; the attribute is an int.
template <typename T>
struct LayerNormTuples {
  using data_type = T;
  using attr_type = int;
  using func_type = void (*)(T*, T*, T*, T*, const T*, const T*, int,
                             const float, int);
};
| 
 | ||||
// Type bundle for the kernel computing nChw16c = nChw16c .* NC; the attribute
// is an int.
template <typename T>
struct NCHW16CMulNCTuples {
  using data_type = T;
  using attr_type = int;
  using func_type = void (*)(const T*, const T*, T*, int, int);
};
| 
 | ||||
| // Non-template root type for kernel objects, so that they can be added to a
// kernel pool without that pool depending on any KernelTuples type.
 | ||||
| class Kernel { | ||||
|  public: | ||||
|   Kernel() = default; | ||||
// Virtual destructor, so a derived kernel deleted through a Kernel pointer
// runs its own destructor.
|   virtual ~Kernel() = default; | ||||
// NOTE(review): the macro comes from paddle/fluid/platform/macros.h, which is
// not visible here; presumably it removes copy construction and copy
// assignment - confirm. Keep it last in case it changes the access level.
|   DISABLE_COPY_AND_ASSIGN(Kernel); | ||||
| }; | ||||
| 
 | ||||
| template <typename KernelTuples> | ||||
| class KernelMore : public Kernel { | ||||
|  public: | ||||
|   using T = typename KernelTuples::data_type; | ||||
|   using Func = typename KernelTuples::func_type; | ||||
|   using Attr = typename KernelTuples::attr_type; | ||||
|   virtual Func GetFunc() const { return func; } | ||||
|   virtual bool UseMe(const Attr& attr) const = 0; | ||||
|   virtual const char* ImplType() const = 0; | ||||
| 
 | ||||
|  protected: | ||||
|   Func func{nullptr}; | ||||
| }; | ||||
| 
 | ||||
| template <typename KernelTuples> | ||||
| class ReferKernel : public KernelMore<KernelTuples> { | ||||
|  public: | ||||
|   // Refer code can always be used
 | ||||
|   bool UseMe(const typename KernelTuples::attr_type& attr) const override { | ||||
|     return true; | ||||
|   } | ||||
|   const char* ImplType() const override { return "Refer"; } | ||||
| }; | ||||
| 
 | ||||
| }  // namespace jit
 | ||||
| }  // namespace operators
 | ||||
| }  // namespace paddle
 | ||||
| @ -0,0 +1,47 @@ | ||||
| /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  * http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. */ | ||||
| 
 | ||||
| #include "paddle/fluid/operators/jit/kernel_key.h" | ||||
| 
 | ||||
| namespace paddle { | ||||
| namespace operators { | ||||
| namespace jit { | ||||
| 
 | ||||
| template <> | ||||
| size_t JitCodeKey<int>(const int& d) { | ||||
|   return d; | ||||
| } | ||||
| 
 | ||||
// Number of bits reserved for each activation type when attributes are packed
// into a JitCode key. The packed fields hold raw KernelType values, whose
// largest enumerator is 19, so five bits are needed to keep neighbouring
// fields from overlapping.
constexpr int act_type_shift = 5;  // supports 2^5 KernelType values
| 
 | ||||
| template <> | ||||
| size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) { | ||||
|   size_t key = attr.d; | ||||
|   int gate_key = static_cast<int>(attr.act_gate) << 1; | ||||
|   int cand_key = static_cast<int>(attr.act_cand) << (1 + act_type_shift); | ||||
|   int cell_key = static_cast<int>(attr.act_cell) << (1 + act_type_shift * 2); | ||||
|   return (key << (1 + act_type_shift * 3)) + gate_key + cand_key + cell_key + | ||||
|          attr.use_peephole; | ||||
| } | ||||
| 
 | ||||
| template <> | ||||
| size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) { | ||||
|   size_t key = attr.d; | ||||
|   return (key << (act_type_shift * 2)) + static_cast<int>(attr.act_gate) + | ||||
|          (static_cast<int>(attr.act_cand) << act_type_shift); | ||||
| } | ||||
| 
 | ||||
| }  // namespace jit
 | ||||
| }  // namespace operators
 | ||||
| }  // namespace paddle
 | ||||
| @ -0,0 +1,53 @@ | ||||
| /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  * http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. */ | ||||
| 
 | ||||
| #pragma once | ||||
| #include "paddle/fluid/operators/jit/kernel_base.h" | ||||
| #include "paddle/fluid/platform/place.h" | ||||
| 
 | ||||
| namespace paddle { | ||||
| namespace operators { | ||||
| namespace jit { | ||||
| 
 | ||||
| struct KernelKey { | ||||
|   struct Hash { | ||||
|     size_t operator()(const KernelKey& key) const { | ||||
|       int place = key.place_.which();               // less than 2^8
 | ||||
|       int type = static_cast<int>(key.type_) << 8;  // less than 2^(32-8)
 | ||||
|       std::hash<int> hasher; | ||||
|       return hasher(place + type); | ||||
|     } | ||||
|   }; | ||||
| 
 | ||||
|   KernelType type_; | ||||
|   platform::Place place_; | ||||
| 
 | ||||
|   KernelKey(KernelType type, platform::Place place) | ||||
|       : type_(type), place_(place) {} | ||||
|   size_t hash_key() const { return Hash()(*this); } | ||||
| 
 | ||||
|   bool operator==(const KernelKey& o) const { | ||||
|     return platform::places_are_same_class(place_, o.place_) && | ||||
|            type_ == o.type_; | ||||
|   } | ||||
|   bool operator!=(const KernelKey& o) const { return !(*this == o); } | ||||
| }; | ||||
| 
 | ||||
| // Every JitCode must provide a way to derive its key from its attribute.
 | ||||
| template <typename Attr> | ||||
| size_t JitCodeKey(const Attr& attr); | ||||
| 
 | ||||
| }  // namespace jit
 | ||||
| }  // namespace operators
 | ||||
| }  // namespace paddle
 | ||||
Some files were not shown because too many files have changed in this diff Show More
					Loading…
					
					
				
		Reference in new issue