|
|
|
@ -16,6 +16,7 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include <glog/logging.h>
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <iterator>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "paddle/fluid/framework/eigen.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
@ -94,8 +95,11 @@ class RowwiseTransformIterator;
|
|
|
|
|
template <typename T, typename DeviceContext>
|
|
|
|
|
class MidWiseTransformIterator;
|
|
|
|
|
|
|
|
|
|
// NOTE(dzhwinter): ptrdiff_t in iterator is deperecated in c++17
|
|
|
|
|
template <typename T>
|
|
|
|
|
class RowwiseTransformIterator<T, platform::CPUDeviceContext> {
|
|
|
|
|
class RowwiseTransformIterator<T, platform::CPUDeviceContext>
|
|
|
|
|
: public std::iterator<std::random_access_iterator_tag, T, std::ptrdiff_t,
|
|
|
|
|
T *, T &> {
|
|
|
|
|
public:
|
|
|
|
|
RowwiseTransformIterator(const T *ptr, int n) : ptr_(ptr), i_(0), n_(n) {}
|
|
|
|
|
|
|
|
|
@ -126,7 +130,9 @@ class RowwiseTransformIterator<T, platform::CPUDeviceContext> {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class MidWiseTransformIterator<T, platform::CPUDeviceContext> {
|
|
|
|
|
class MidWiseTransformIterator<T, platform::CPUDeviceContext>
|
|
|
|
|
: public std::iterator<std::random_access_iterator_tag, T, std::ptrdiff_t,
|
|
|
|
|
T *, T &> {
|
|
|
|
|
public:
|
|
|
|
|
MidWiseTransformIterator(const T *ptr, int n, int post)
|
|
|
|
|
: ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}
|
|
|
|
@ -479,8 +485,13 @@ void ElemwiseGradComputeNoBroadcast(
|
|
|
|
|
const framework::Tensor &dout, int axis, framework::Tensor *dx,
|
|
|
|
|
framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) {
|
|
|
|
|
size_t N = static_cast<size_t>(framework::product(x_dim));
|
|
|
|
|
#if !defined(_WIN32)
|
|
|
|
|
platform::ForRange<DeviceContext> for_range(
|
|
|
|
|
ctx.template device_context<DeviceContext>(), N);
|
|
|
|
|
#else
|
|
|
|
|
platform::ForRange<DeviceContext> for_range(
|
|
|
|
|
ctx.device_context<DeviceContext>(), N);
|
|
|
|
|
#endif // !_WIN32
|
|
|
|
|
for_range(ElemwiseGradNoBroadcast<T, DX_OP, DY_OP>{
|
|
|
|
|
x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), dx_op, dy_op,
|
|
|
|
|
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
|
|
|
|
@ -633,13 +644,13 @@ void ElementwiseGradCompute(const framework::ExecutionContext &ctx,
|
|
|
|
|
|
|
|
|
|
template <typename Functor, typename DeviceContext, typename T,
|
|
|
|
|
typename OutType = T>
|
|
|
|
|
|
|
|
|
|
void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
|
|
|
|
|
const framework::Tensor *x,
|
|
|
|
|
const framework::Tensor *y, int axis, Functor func,
|
|
|
|
|
framework::Tensor *z) {
|
|
|
|
|
TransformFunctor<Functor, T, DeviceContext, OutType> functor(
|
|
|
|
|
x, y, z, ctx.template device_context<DeviceContext>(), func);
|
|
|
|
|
|
|
|
|
|
auto x_dims = x->dims();
|
|
|
|
|
auto y_dims_untrimed = y->dims();
|
|
|
|
|
PADDLE_ENFORCE_GE(x_dims.size(), y_dims_untrimed.size(),
|
|
|
|
|