|
|
|
@ -26,8 +26,18 @@ class SizeKernel : public framework::OpKernel<T> {
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
auto* in_t = ctx.Input<Tensor>("Input");
|
|
|
|
|
auto* out_t = ctx.Output<Tensor>("Out");
|
|
|
|
|
auto out_data = out_t->mutable_data<int64_t>(platform::CPUPlace());
|
|
|
|
|
out_data[0] = in_t->numel();
|
|
|
|
|
auto place = ctx.GetPlace();
|
|
|
|
|
auto out_data = out_t->mutable_data<int64_t>(place);
|
|
|
|
|
auto cpu_place = platform::CPUPlace();
|
|
|
|
|
if (place == cpu_place) {
|
|
|
|
|
out_data[0] = in_t->numel();
|
|
|
|
|
} else {
|
|
|
|
|
Tensor cpu_tensor;
|
|
|
|
|
auto cpu_data =
|
|
|
|
|
cpu_tensor.mutable_data<int64_t>(out_t->dims(), cpu_place);
|
|
|
|
|
cpu_data[0] = in_t->numel();
|
|
|
|
|
TensorCopy(cpu_tensor, place, out_t);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace operators
|
|
|
|
|