|
|
|
@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
#pragma once
|
|
|
|
|
#include "paddle/framework/eigen.h"
|
|
|
|
|
#include "paddle/framework/selected_rows.h"
|
|
|
|
|
#include "paddle/platform/device_context.h"
|
|
|
|
|
|
|
|
|
@ -70,13 +71,13 @@ struct Add {
|
|
|
|
|
const framework::SelectedRows& input1,
|
|
|
|
|
const framework::SelectedRows& input2,
|
|
|
|
|
framework::SelectedRows* out) {
|
|
|
|
|
out->set_rows(input1->rows());
|
|
|
|
|
out->set_height(input1->height());
|
|
|
|
|
out->mutable_value()->mutable_data<T>(input1->value().dims(),
|
|
|
|
|
out->set_rows(input1.rows());
|
|
|
|
|
out->set_height(input1.height());
|
|
|
|
|
out->mutable_value()->mutable_data<T>(input1.value().dims(),
|
|
|
|
|
context.GetPlace());
|
|
|
|
|
auto e_out = framework::EigenVector<T>::Flatten(*(out->mutable_value()));
|
|
|
|
|
auto e_in1 = framework::EigenVector<T>::Flatten(input1->value());
|
|
|
|
|
auto e_in2 = framework::EigenVector<T>::Flatten(input2->value());
|
|
|
|
|
auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
|
|
|
|
|
auto e_in2 = framework::EigenVector<T>::Flatten(input2.value());
|
|
|
|
|
e_out.device(*context.eigen_device()) = e_in1 + e_in2;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -87,13 +88,13 @@ struct Mul {
|
|
|
|
|
const framework::SelectedRows& input1,
|
|
|
|
|
const framework::SelectedRows& input2,
|
|
|
|
|
framework::SelectedRows* out) {
|
|
|
|
|
out->set_rows(input1->rows());
|
|
|
|
|
out->set_height(input1->height());
|
|
|
|
|
out->mutable_value()->mutable_data<T>(input1->value().dims(),
|
|
|
|
|
out->set_rows(input1.rows());
|
|
|
|
|
out->set_height(input1.height());
|
|
|
|
|
out->mutable_value()->mutable_data<T>(input1.value().dims(),
|
|
|
|
|
context.GetPlace());
|
|
|
|
|
auto e_out = framework::EigenVector<T>::Flatten(*(out->mutable_value()));
|
|
|
|
|
auto e_in1 = framework::EigenVector<T>::Flatten(input1->value());
|
|
|
|
|
auto e_in2 = framework::EigenVector<T>::Flatten(input2->value());
|
|
|
|
|
auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
|
|
|
|
|
auto e_in2 = framework::EigenVector<T>::Flatten(input2.value());
|
|
|
|
|
e_out.device(*context.eigen_device()) = e_in1 * e_in2;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|