|
|
|
@ -13,8 +13,10 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
#include <glog/logging.h>
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <iterator>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "paddle/fluid/framework/eigen.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
@ -92,8 +94,11 @@ class RowwiseTransformIterator;
|
|
|
|
|
template <typename T, typename DeviceContext>
|
|
|
|
|
class MidWiseTransformIterator;
|
|
|
|
|
|
|
|
|
|
// NOTE(dzhwinter): ptrdiff_t in iterator is deprecated in C++17
|
|
|
|
|
template <typename T>
|
|
|
|
|
class RowwiseTransformIterator<T, platform::CPUDeviceContext> {
|
|
|
|
|
class RowwiseTransformIterator<T, platform::CPUDeviceContext>
|
|
|
|
|
: public std::iterator<std::random_access_iterator_tag, typename T,
|
|
|
|
|
std::ptrdiff_t, typename T*, typename T&> {
|
|
|
|
|
public:
|
|
|
|
|
RowwiseTransformIterator(const T* ptr, int n) : ptr_(ptr), i_(0), n_(n) {}
|
|
|
|
|
|
|
|
|
@ -124,7 +129,9 @@ class RowwiseTransformIterator<T, platform::CPUDeviceContext> {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class MidWiseTransformIterator<T, platform::CPUDeviceContext> {
|
|
|
|
|
class MidWiseTransformIterator<T, platform::CPUDeviceContext>
|
|
|
|
|
: public std::iterator<std::random_access_iterator_tag, T, std::ptrdiff_t,
|
|
|
|
|
T*, T&> {
|
|
|
|
|
public:
|
|
|
|
|
MidWiseTransformIterator(const T* ptr, int n, int post)
|
|
|
|
|
: ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}
|
|
|
|
@ -473,8 +480,13 @@ void ElemwiseGradComputeNoBroadcast(
|
|
|
|
|
const framework::Tensor& dout, int axis, framework::Tensor* dx,
|
|
|
|
|
framework::Tensor* dy, DX_OP dx_op, DY_OP dy_op) {
|
|
|
|
|
size_t N = static_cast<size_t>(framework::product(x_dim));
|
|
|
|
|
#if !defined(_WIN32)
|
|
|
|
|
platform::ForRange<DeviceContext> for_range(
|
|
|
|
|
ctx.template device_context<DeviceContext>(), N);
|
|
|
|
|
#else
|
|
|
|
|
platform::ForRange<DeviceContext> for_range(
|
|
|
|
|
ctx.device_context<DeviceContext>(), N);
|
|
|
|
|
#endif // !_WIN32
|
|
|
|
|
for_range(ElemwiseGradNoBroadcast<T, DX_OP, DY_OP>{
|
|
|
|
|
x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), dx_op, dy_op,
|
|
|
|
|
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
|
|
|
|
@ -631,9 +643,13 @@ void ElementwiseComputeEx(const framework::ExecutionContext& ctx,
|
|
|
|
|
const framework::Tensor* x,
|
|
|
|
|
const framework::Tensor* y, int axis, Functor func,
|
|
|
|
|
framework::Tensor* z) {
|
|
|
|
|
#if !defined(_WIN32)
|
|
|
|
|
TransformFunctor<Functor, T, DeviceContext, OutType> functor(
|
|
|
|
|
x, y, z, ctx.template device_context<DeviceContext>(), func);
|
|
|
|
|
|
|
|
|
|
#else
|
|
|
|
|
TransformFunctor<Functor, T, DeviceContext, OutType> functor(
|
|
|
|
|
x, y, z, ctx.device_context<DeviceContext>(), func);
|
|
|
|
|
#endif // !_WIN32
|
|
|
|
|
auto x_dims = x->dims();
|
|
|
|
|
auto y_dims_untrimed = y->dims();
|
|
|
|
|
PADDLE_ENFORCE_GE(x_dims.size(), y_dims_untrimed.size(),
|
|
|
|
|