del_some_in_makelist
typhoonzero 7 years ago
parent d48a0e4eae
commit 74b122889c

@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/selected_rows.h" #include "paddle/framework/selected_rows.h"
#include "paddle/platform/device_context.h" #include "paddle/platform/device_context.h"
@ -70,13 +71,13 @@ struct Add {
const framework::SelectedRows& input1, const framework::SelectedRows& input1,
const framework::SelectedRows& input2, const framework::SelectedRows& input2,
framework::SelectedRows* out) { framework::SelectedRows* out) {
out->set_rows(input1->rows()); out->set_rows(input1.rows());
out->set_height(input1->height()); out->set_height(input1.height());
out->mutable_value()->mutable_data<T>(input1->value().dims(), out->mutable_value()->mutable_data<T>(input1.value().dims(),
context.GetPlace()); context.GetPlace());
auto e_out = framework::EigenVector<T>::Flatten(*(out->mutable_value())); auto e_out = framework::EigenVector<T>::Flatten(*(out->mutable_value()));
auto e_in1 = framework::EigenVector<T>::Flatten(input1->value()); auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
auto e_in2 = framework::EigenVector<T>::Flatten(input2->value()); auto e_in2 = framework::EigenVector<T>::Flatten(input2.value());
e_out.device(*context.eigen_device()) = e_in1 + e_in2; e_out.device(*context.eigen_device()) = e_in1 + e_in2;
} }
}; };
@ -87,13 +88,13 @@ struct Mul {
const framework::SelectedRows& input1, const framework::SelectedRows& input1,
const framework::SelectedRows& input2, const framework::SelectedRows& input2,
framework::SelectedRows* out) { framework::SelectedRows* out) {
out->set_rows(input1->rows()); out->set_rows(input1.rows());
out->set_height(input1->height()); out->set_height(input1.height());
out->mutable_value()->mutable_data<T>(input1->value().dims(), out->mutable_value()->mutable_data<T>(input1.value().dims(),
context.GetPlace()); context.GetPlace());
auto e_out = framework::EigenVector<T>::Flatten(*(out->mutable_value())); auto e_out = framework::EigenVector<T>::Flatten(*(out->mutable_value()));
auto e_in1 = framework::EigenVector<T>::Flatten(input1->value()); auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
auto e_in2 = framework::EigenVector<T>::Flatten(input2->value()); auto e_in2 = framework::EigenVector<T>::Flatten(input2.value());
e_out.device(*context.eigen_device()) = e_in1 * e_in2; e_out.device(*context.eigen_device()) = e_in1 * e_in2;
} }
}; };

Loading…
Cancel
Save