|
|
|
@ -22,7 +22,8 @@ template <typename T>
|
|
|
|
|
struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
|
|
|
|
|
static void compute(const platform::CUDADeviceContext &context,
|
|
|
|
|
GRUMetaValue<T> value, int frame_size, int batch_size,
|
|
|
|
|
ActivationType active_node, ActivationType active_gate) {
|
|
|
|
|
const detail::ActivationType active_node,
|
|
|
|
|
const detail::ActivationType active_gate) {
|
|
|
|
|
auto stream = context.stream();
|
|
|
|
|
dim3 threads;
|
|
|
|
|
dim3 grid;
|
|
|
|
@ -89,7 +90,8 @@ struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> {
|
|
|
|
|
static void compute(const platform::CUDADeviceContext &context,
|
|
|
|
|
GRUMetaValue<T> value, GRUMetaGrad<T> grad,
|
|
|
|
|
int frame_size, int batch_size,
|
|
|
|
|
ActivationType active_node, ActivationType active_gate) {
|
|
|
|
|
const detail::ActivationType active_node,
|
|
|
|
|
const detail::ActivationType active_gate) {
|
|
|
|
|
auto stream = context.stream();
|
|
|
|
|
dim3 threads;
|
|
|
|
|
dim3 grid;
|
|
|
|
|