optimize compilation time of argmin/argmax op (#29595)

* Using VisitDataTypeTiny and put CastOP after ReduceOP, test=develop

* remove changes of reduce_op.h, test=develop
revert-31562-mean
wuhuanzhou 5 years ago committed by GitHub
parent c4eb5d0378
commit e7ac74c85b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -175,12 +175,13 @@ class ArgMinMaxOpCUDAKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dtype = ctx.Attr<int>("dtype");
if (dtype < 0) {
framework::VisitDataType(static_cast<framework::proto::VarType::Type>(
framework::proto::VarType::INT64),
VisitDataCudaArgMinMaxFunctor<T, Reducer>(ctx));
framework::VisitDataTypeTiny(
static_cast<framework::proto::VarType::Type>(
framework::proto::VarType::INT64),
VisitDataCudaArgMinMaxFunctor<T, Reducer>(ctx));
return;
}
framework::VisitDataType(
framework::VisitDataTypeTiny(
static_cast<framework::proto::VarType::Type>(dtype),
VisitDataCudaArgMinMaxFunctor<T, Reducer>(ctx));
}

@ -128,13 +128,13 @@ class ArgMinMaxKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dtype = ctx.Attr<int>("dtype");
if (dtype < 0) {
framework::VisitDataType(
framework::VisitDataTypeTiny(
static_cast<framework::proto::VarType::Type>(
framework::proto::VarType::INT64),
VisitDataArgMinMaxFunctor<DeviceContext, T, EnumArgMinMaxValue>(ctx));
return;
}
framework::VisitDataType(
framework::VisitDataTypeTiny(
static_cast<framework::proto::VarType::Type>(dtype),
VisitDataArgMinMaxFunctor<DeviceContext, T, EnumArgMinMaxValue>(ctx));
}

Loading…
Cancel
Save