fix jitcodekey and refine test

test=develop
revert-16045-imperative_remove_desc
tensor-tang 6 years ago committed by ceci3
parent ce4cc482a4
commit 0eefad0a2d

@ -13,6 +13,7 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/kernel_key.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
@ -23,14 +24,30 @@ size_t JitCodeKey<int>(const int& d) {
return d;
}
// TODO(TJ): refine and benchmark JitCodeKey generatation
constexpr int act_type_shift = 3; // suppot 2^3 act types
static inline int act_type_convert(KernelType type) {
if (type == kVIdentity) {
return 0;
} else if (type == kVExp) {
return 1;
} else if (type == kVRelu) {
return 2;
} else if (type == kVSigmoid) {
return 3;
} else if (type == kVTanh) {
return 4;
}
PADDLE_THROW("Unsupported act type %d", type);
return 0;
}
template <>
size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
size_t key = attr.d;
int gate_key = static_cast<int>(attr.act_gate) << 1;
int cand_key = static_cast<int>(attr.act_cand) << (1 + act_type_shift);
int cell_key = static_cast<int>(attr.act_cell) << (1 + act_type_shift * 2);
int gate_key = act_type_convert(attr.act_gate) << 1;
int cand_key = act_type_convert(attr.act_cand) << (1 + act_type_shift);
int cell_key = act_type_convert(attr.act_cell) << (1 + act_type_shift * 2);
return (key << (1 + act_type_shift * 3)) + gate_key + cand_key + cell_key +
attr.use_peephole;
}
@ -38,8 +55,8 @@ size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
template <>
size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
size_t key = attr.d;
return (key << (act_type_shift * 2)) + static_cast<int>(attr.act_gate) +
(static_cast<int>(attr.act_cand) << act_type_shift);
return (key << (act_type_shift * 2)) + act_type_convert(attr.act_gate) +
(act_type_convert(attr.act_cand) << act_type_shift);
}
template <>

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save