|
|
|
@ -332,6 +332,20 @@ void NCHW16CMulNC(const T* x, const T* y, T* z, int height, int width) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
|
|
|
|
|
PADDLE_ENFORCE(attr->type == SeqPoolType::sum, "Only support sum yet");
|
|
|
|
|
for (int w = 0; w < attr->w; ++w) {
|
|
|
|
|
const T* src = x + w;
|
|
|
|
|
T* dst = y + w;
|
|
|
|
|
*dst = static_cast<T>(0);
|
|
|
|
|
for (int h = 0; h < attr->h; ++h) {
|
|
|
|
|
*dst = *dst + *src;
|
|
|
|
|
src += attr->w;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#define DECLARE_REFER_KERNEL(name, tuples) \
|
|
|
|
|
template <typename T> \
|
|
|
|
|
class name##Kernel : public ReferKernel<tuples<T>> { \
|
|
|
|
@ -370,6 +384,8 @@ DECLARE_REFER_KERNEL(LayerNorm, LayerNormTuples);
|
|
|
|
|
|
|
|
|
|
DECLARE_REFER_KERNEL(NCHW16CMulNC, NCHW16CMulNCTuples);
|
|
|
|
|
|
|
|
|
|
DECLARE_REFER_KERNEL(SeqPool, SeqPoolTuples);
|
|
|
|
|
|
|
|
|
|
#undef DECLARE_REFER_KERNEL
|
|
|
|
|
|
|
|
|
|
} // namespace refer
|
|
|
|
|