|
|
|
@ -88,57 +88,6 @@ struct MergeAdd {
|
|
|
|
|
framework::SelectedRows* output);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|
struct Add {
|
|
|
|
|
framework::SelectedRows operator()(const DeviceContext& context,
|
|
|
|
|
const framework::SelectedRows& input1,
|
|
|
|
|
const framework::SelectedRows& input2) {
|
|
|
|
|
framework::SelectedRows out;
|
|
|
|
|
out.set_rows(input1.rows());
|
|
|
|
|
out.set_height(input1.height());
|
|
|
|
|
out.mutable_value()->mutable_data<T>(input1.value().dims(),
|
|
|
|
|
context.GetPlace());
|
|
|
|
|
auto e_out = framework::EigenVector<T>::Flatten(*(out.mutable_value()));
|
|
|
|
|
auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
|
|
|
|
|
auto e_in2 = framework::EigenVector<T>::Flatten(input2.value());
|
|
|
|
|
e_out.device(*context.eigen_device()) = e_in1 + e_in2;
|
|
|
|
|
return out;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|
struct Mul {
|
|
|
|
|
// multiply two SelectedRows
|
|
|
|
|
framework::SelectedRows operator()(const DeviceContext& context,
|
|
|
|
|
const framework::SelectedRows& input1,
|
|
|
|
|
const framework::SelectedRows& input2) {
|
|
|
|
|
framework::SelectedRows out;
|
|
|
|
|
out.set_rows(input1.rows());
|
|
|
|
|
out.set_height(input1.height());
|
|
|
|
|
out.mutable_value()->mutable_data<T>(input1.value().dims(),
|
|
|
|
|
context.GetPlace());
|
|
|
|
|
auto e_out = framework::EigenVector<T>::Flatten(*(out.mutable_value()));
|
|
|
|
|
auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
|
|
|
|
|
auto e_in2 = framework::EigenVector<T>::Flatten(input2.value());
|
|
|
|
|
e_out.device(*context.eigen_device()) = e_in1 * e_in2;
|
|
|
|
|
return out;
|
|
|
|
|
}
|
|
|
|
|
// multiply scalar to SelectedRows
|
|
|
|
|
framework::SelectedRows operator()(const DeviceContext& context,
|
|
|
|
|
const framework::SelectedRows& input1,
|
|
|
|
|
const T input2) {
|
|
|
|
|
framework::SelectedRows out;
|
|
|
|
|
out.set_rows(input1.rows());
|
|
|
|
|
out.set_height(input1.height());
|
|
|
|
|
out.mutable_value()->mutable_data<T>(input1.value().dims(),
|
|
|
|
|
context.GetPlace());
|
|
|
|
|
auto e_out = framework::EigenVector<T>::Flatten(*(out.mutable_value()));
|
|
|
|
|
auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
|
|
|
|
|
e_out.device(*context.eigen_device()) = input2 * e_in1;
|
|
|
|
|
return out;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY };
|
|
|
|
|
|
|
|
|
|
// out = seleted_rows_in / tensor
|
|
|
|
|