@ -660,19 +660,26 @@ template <typename DeviceContext, typename T>
class TransposeGPUKernel : public framework::OpKernel<T> {
class TransposeGPUKernel : public framework::OpKernel<T> {
public:
public:
void Compute(const framework::ExecutionContext& context) const override {
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<framework::Tensor>("X");
auto* x = context.InputVar("X");
auto* out = context.Output<framework::Tensor>("Out");
auto* out = context.OutputVar("Out");
out->mutable_data<T>(context.GetPlace());
if (out->numel() == 0) {
const framework::Tensor* x_tensor =
GetLoDTensorOrSelectedRowsValueFromVar(*x);
framework::Tensor* out_tensor =
GetMutableLoDTensorOrSelectedRowsValueFromVar(out);
out_tensor->mutable_data<T>(context.GetPlace());
if (out_tensor->numel() == 0) {
return;
return;
}
}
std::vector<int> axis = context.Attr<std::vector<int>>("axis");
std::vector<int> axis = context.Attr<std::vector<int>>("axis");
int ndims = axis.size();
int ndims = axis.size();
const auto& dev_ctx = context.template device_context<DeviceContext>();
const auto& dev_ctx = context.template device_context<DeviceContext>();
auto ret = TransposeSimple<T>::run(dev_ctx, *x, axis, out);
auto ret = TransposeSimple<T>::run(dev_ctx, *x_tensor , axis, out_tensor );
if (!ret) {
if (!ret) {
TransCompute<DeviceContext, T>(ndims, dev_ctx, *x, out, axis);
TransCompute<DeviceContext, T>(ndims, dev_ctx, *x_tensor, out_tensor,
axis);
}
}
}
}
};
};
@ -680,14 +687,19 @@ template <typename DeviceContext, typename T>
class TransposeGradGPUKernel : public framework::OpKernel<T> {
class TransposeGradGPUKernel : public framework::OpKernel<T> {
public:
public:
void Compute(const framework::ExecutionContext& context) const override {
void Compute(const framework::ExecutionContext& context) const override {
auto* out_grad =
auto* out_grad = context.InputVar(framework::GradVarName("Out"));
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* x_grad = context.OutputVar(framework::GradVarName("X"));
auto* x_grad =
if (!x_grad) {
context.Output<framework::Tensor>(framework::GradVarName("X"));
return;
if (!x_grad) return;
}
x_grad->mutable_data<T>(context.GetPlace());
const framework::Tensor* out_grad_tensor =
if (x_grad->numel() == 0) {
GetLoDTensorOrSelectedRowsValueFromVar(*out_grad);
framework::Tensor* x_grad_tensor =
GetMutableLoDTensorOrSelectedRowsValueFromVar(x_grad);
x_grad_tensor->mutable_data<T>(context.GetPlace());
if (x_grad_tensor->numel() == 0) {
return;
return;
}
}
std::vector<int> axis = context.Attr<std::vector<int>>("axis");
std::vector<int> axis = context.Attr<std::vector<int>>("axis");
@ -699,11 +711,11 @@ class TransposeGradGPUKernel : public framework::OpKernel<T> {
int ndims = axis.size();
int ndims = axis.size();
const auto& dev_ctx = context.template device_context<DeviceContext>();
const auto& dev_ctx = context.template device_context<DeviceContext>();
auto ret =
auto ret = TransposeSimple<T>::run(dev_ctx, *out_grad_tensor, reversed_axis,
TransposeSimple<T>::run(dev_ctx, *out_grad, reversed_axis, x_grad );
x_grad_tensor );
if (!ret) {
if (!ret) {
TransCompute<DeviceContext, T>(ndims, dev_ctx, *out_grad, x_grad ,
TransCompute<DeviceContext, T>(ndims, dev_ctx, *out_grad_tensor ,
reversed_axis);
x_grad_tensor, reversed_axis);
}
}
}
}
};
};