|
|
|
@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/softmax_op.h"
|
|
|
|
|
|
|
|
|
|
#include <string>
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
#include "paddle/fluid/platform/cudnn_helper.h"
|
|
|
|
|
#endif
|
|
|
|
@ -20,6 +23,7 @@ limitations under the License. */
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
#include "paddle/fluid/platform/mkldnn_helper.h"
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
@ -60,8 +64,8 @@ class SoftmaxOp : public framework::OperatorWithKernel {
|
|
|
|
|
auto input_data_type =
|
|
|
|
|
framework::ToDataType(ctx.Input<Tensor>("X")->type());
|
|
|
|
|
if (input_data_type == framework::proto::VarType::FP16) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN,
|
|
|
|
|
"float16 can only be used when CUDNN is used");
|
|
|
|
|
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
|
|
|
|
|
"float16 can only be used on GPU place");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string data_format = ctx.Attr<std::string>("data_format");
|
|
|
|
@ -70,6 +74,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
|
|
|
|
|
library_);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
SoftmaxOpMaker(OpProto* proto, OpAttrChecker* op_checker)
|
|
|
|
|