|
|
|
@ -81,10 +81,6 @@ class TensorAddToFunctor : public boost::static_visitor<> {
|
|
|
|
|
|
|
|
|
|
} // namespace detail
|
|
|
|
|
|
|
|
|
|
template <int MajorType = Eigen::RowMajor,
|
|
|
|
|
typename IndexType = Eigen::DenseIndex>
|
|
|
|
|
using EigenVector = framework::EigenVector<float, MajorType, IndexType>;
|
|
|
|
|
|
|
|
|
|
void AddTo(Variable* src, Variable* dst, platform::Place place) {
|
|
|
|
|
framework::Tensor* dst_tensor = dst->GetMutable<framework::LoDTensor>();
|
|
|
|
|
framework::Tensor* src_tensor = src->GetMutable<framework::LoDTensor>();
|
|
|
|
@ -99,18 +95,10 @@ void AddTo(Variable* src, Variable* dst, platform::Place place) {
|
|
|
|
|
"dst_numel %lld vs. src_numel %lld", dst_tensor->numel(),
|
|
|
|
|
src_tensor->numel());
|
|
|
|
|
|
|
|
|
|
auto result = EigenVector<>::Flatten(*dst_tensor);
|
|
|
|
|
auto in_0_e = EigenVector<>::Flatten(*dst_tensor);
|
|
|
|
|
auto in_1_e = EigenVector<>::Flatten(*src_tensor);
|
|
|
|
|
platform::DeviceContext* dev_ctx =
|
|
|
|
|
platform::DeviceContextPool::Instance().Get(place);
|
|
|
|
|
platform::CPUDeviceContext* x =
|
|
|
|
|
reinterpret_cast<platform::CPUDeviceContext*>(dev_ctx);
|
|
|
|
|
result.device(*x->eigen_device()) = in_0_e + in_1_e;
|
|
|
|
|
// detail::TensorAddToFunctor<float> func(
|
|
|
|
|
// src_tensor->numel(), src_tensor->data<float>(),
|
|
|
|
|
// dst_tensor->mutable_data<float>(place));
|
|
|
|
|
// boost::apply_visitor(func, place);
|
|
|
|
|
detail::TensorAddToFunctor<float> func(
|
|
|
|
|
src_tensor->numel(), src_tensor->data<float>(),
|
|
|
|
|
dst_tensor->mutable_data<float>(place));
|
|
|
|
|
boost::apply_visitor(func, place);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
class Autograd {
|
|
|
|
@ -134,7 +122,7 @@ class Autograd {
|
|
|
|
|
std::map<std::string, std::vector<VarBase*>> input_grads =
|
|
|
|
|
ready_op->ApplyGrad();
|
|
|
|
|
|
|
|
|
|
for (auto it : input_grads) {
|
|
|
|
|
for (auto it = input_grads.rbegin(); it != input_grads.rend(); ++it) {
|
|
|
|
|
const std::vector<VarBase*>& ingrads = it.second;
|
|
|
|
|
for (int64_t i = ingrads.size() - 1; i >= 0; --i) {
|
|
|
|
|
if (!ingrads[i]) continue;
|
|
|
|
|