|
|
|
@ -98,7 +98,7 @@ class TransposeCUDAKernel : public framework::OpKernel {
|
|
|
|
|
"It must use GPUPlace.");
|
|
|
|
|
auto* in = context.Input<framework::Tensor>("X");
|
|
|
|
|
auto* out = context.Output<framework::Tensor>("Out");
|
|
|
|
|
auto axis = context.GetAttr<std::vector<int>>("axis");
|
|
|
|
|
auto axis = context.Attr<std::vector<int>>("axis");
|
|
|
|
|
TransposeCUDA<T>(context, *in, *out, axis);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -111,7 +111,7 @@ class TransposeGradCUDAKernel : public framework::OpKernel {
|
|
|
|
|
"It must use GPUPlace.");
|
|
|
|
|
auto* in = context.Input<framework::Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
auto* out = context.Output<framework::Tensor>(framework::GradVarName("X"));
|
|
|
|
|
auto axis_temp = context.GetAttr<std::vector<int>>("axis");
|
|
|
|
|
auto axis_temp = context.Attr<std::vector<int>>("axis");
|
|
|
|
|
|
|
|
|
|
std::vector<int> axis(axis_temp);
|
|
|
|
|
|
|
|
|
|