|
|
|
@ -232,18 +232,18 @@ template <typename DeviceContext, typename T>
|
|
|
|
|
typename std::enable_if<
|
|
|
|
|
std::is_floating_point<T>::value &&
|
|
|
|
|
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
|
|
|
|
|
elementwise_add(const DeviceContext& ctx, size_t data_len, const T* in,
|
|
|
|
|
T* out) {
|
|
|
|
|
auto blas = math::GetBlas<DeviceContext, T>(ctx);
|
|
|
|
|
blas.AXPY(data_len, 1., in, out);
|
|
|
|
|
elementwise_add(const DeviceContext& ctx, BlasT<DeviceContext, T>* blas,
|
|
|
|
|
size_t data_len, const T* in, T* out) {
|
|
|
|
|
// auto blas = math::GetBlas<DeviceContext, T>(ctx);
|
|
|
|
|
blas->AXPY(data_len, 1., in, out);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|
typename std::enable_if<
|
|
|
|
|
!std::is_floating_point<T>::value &&
|
|
|
|
|
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
|
|
|
|
|
elementwise_add(const DeviceContext& ctx, size_t data_len, const T* in,
|
|
|
|
|
T* out) {
|
|
|
|
|
elementwise_add(const DeviceContext& ctx, BlasT<DeviceContext, T>* blas,
|
|
|
|
|
size_t data_len, const T* in, T* out) {
|
|
|
|
|
for (int64_t i = 0; i < data_len; i++) {
|
|
|
|
|
out[i] += in[i];
|
|
|
|
|
}
|
|
|
|
@ -305,10 +305,11 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
|
|
|
|
|
auto* input_data = input->value().data<T>();
|
|
|
|
|
auto& input_rows = input->rows();
|
|
|
|
|
|
|
|
|
|
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
|
|
|
|
|
for (size_t i = 0; i < input_rows.size(); i++) {
|
|
|
|
|
size_t out_i = rows_to_id[input_rows[i]];
|
|
|
|
|
elementwise_add<platform::CPUDeviceContext, T>(
|
|
|
|
|
context, static_cast<size_t>(input_width),
|
|
|
|
|
context, &blas, static_cast<size_t>(input_width),
|
|
|
|
|
&input_data[i * input_width], &out_data[out_i * input_width]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|