Merge pull request #15563 from tensor-tang/jit/softmax
refine softmax kernel
revert-15296-async_double_buffered_py_reader
commit
c7449227e8
@ -0,0 +1,103 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/fluid/operators/jit/gen/hopv.h"
|
||||||
|
#include "paddle/fluid/operators/jit/registry.h"
|
||||||
|
#include "paddle/fluid/platform/cpu_info.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
namespace jit {
|
||||||
|
namespace gen {
|
||||||
|
|
||||||
|
void HOPVJitCode::genCode() {
|
||||||
|
const int num_blocks = num_ / YMM_FLOAT_BLOCK;
|
||||||
|
int offset = 0;
|
||||||
|
|
||||||
|
if (num_blocks > 0) {
|
||||||
|
// load one firstly
|
||||||
|
vmovups(ymm_tmp, ptr[param_src]);
|
||||||
|
offset += sizeof(float) * YMM_FLOAT_BLOCK;
|
||||||
|
for (int i = 1; i < num_blocks; ++i) {
|
||||||
|
vmovups(ymm_src, ptr[param_src + offset]);
|
||||||
|
process(ymm_tmp, ymm_src, ymm_tmp);
|
||||||
|
offset += sizeof(float) * YMM_FLOAT_BLOCK;
|
||||||
|
}
|
||||||
|
vextractf128(xmm_dst, ymm_tmp, 1);
|
||||||
|
process(xmm_dst, xmm_dst, xmm_tmp);
|
||||||
|
} else {
|
||||||
|
if (type_ == operand_type::MAX) {
|
||||||
|
vbroadcastss(ymm_dst, ptr[param_src]);
|
||||||
|
} else if (type_ == operand_type::ADD) {
|
||||||
|
vxorps(ymm_dst, ymm_dst, ymm_dst);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int rest = num_ % YMM_FLOAT_BLOCK;
|
||||||
|
if (rest >= 4) {
|
||||||
|
vmovups(xmm_src, ptr[param_src + offset]);
|
||||||
|
offset += sizeof(float) * 4;
|
||||||
|
rest -= 4;
|
||||||
|
process(xmm_dst, xmm_dst, xmm_src);
|
||||||
|
}
|
||||||
|
|
||||||
|
vpermilps(xmm_tmp, xmm_dst, 16 + 8 + 3);
|
||||||
|
process(xmm_dst, xmm_dst, xmm_tmp);
|
||||||
|
|
||||||
|
if (rest >= 2) {
|
||||||
|
vmovq(xmm_src, ptr[param_src + offset]);
|
||||||
|
offset += sizeof(float) * 2;
|
||||||
|
rest -= 2;
|
||||||
|
process(xmm_dst, xmm_dst, xmm_src);
|
||||||
|
}
|
||||||
|
|
||||||
|
vpermilps(xmm_tmp, xmm_dst, 1);
|
||||||
|
process(xmm_dst, xmm_dst, xmm_tmp);
|
||||||
|
|
||||||
|
if (rest >= 1) {
|
||||||
|
vmovss(xmm_src, ptr[param_src + offset]);
|
||||||
|
process(xmm_dst, xmm_dst, xmm_src);
|
||||||
|
}
|
||||||
|
vmovss(ptr[param_dst], xmm_dst);
|
||||||
|
ret();
|
||||||
|
}
|
||||||
|
|
||||||
|
#define DECLARE_HOP_CREATOR(name) \
|
||||||
|
class name##Creator : public JitCodeCreator<int> { \
|
||||||
|
public: \
|
||||||
|
bool UseMe(const int& attr) const override { \
|
||||||
|
return platform::MayIUse(platform::avx); \
|
||||||
|
} \
|
||||||
|
size_t CodeSize(const int& d) const override { \
|
||||||
|
return 96 + d / YMM_FLOAT_BLOCK * 4 * 8; \
|
||||||
|
} \
|
||||||
|
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
|
||||||
|
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
|
||||||
|
} \
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_HOP_CREATOR(HMax);
|
||||||
|
DECLARE_HOP_CREATOR(HSum);
|
||||||
|
|
||||||
|
#undef DECLARE_HOP_CREATOR
|
||||||
|
|
||||||
|
} // namespace gen
|
||||||
|
} // namespace jit
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
// Short alias so the registrations below can refer to the creator classes as
// gen::<Name>Creator instead of spelling out the full namespace path.
namespace gen = paddle::operators::jit::gen;
|
||||||
|
|
||||||
|
// Associates the horizontal-max code creator with the kHMax kernel key.
// NOTE(review): REGISTER_JITKERNEL_GEN is not defined in this file (presumably
// it comes from registry.h) — confirm its exact effect there.
REGISTER_JITKERNEL_GEN(kHMax, gen::HMaxCreator);
|
||||||
|
// Associates the horizontal-sum code creator with the kHSum kernel key.
// NOTE(review): REGISTER_JITKERNEL_GEN is not defined in this file (presumably
// it comes from registry.h) — confirm its exact effect there.
REGISTER_JITKERNEL_GEN(kHSum, gen::HSumCreator);
|
@ -0,0 +1,90 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include "glog/logging.h"
|
||||||
|
#include "paddle/fluid/operators/jit/gen/jitcode.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
namespace jit {
|
||||||
|
namespace gen {
|
||||||
|
|
||||||
|
// horizontal operand vector
|
||||||
|
class HOPVJitCode : public JitCode {
|
||||||
|
public:
|
||||||
|
explicit HOPVJitCode(int d, operand_type type, size_t code_size = 256 * 1024,
|
||||||
|
void* code_ptr = nullptr)
|
||||||
|
: JitCode(code_size, code_ptr), num_(d), type_(type) {
|
||||||
|
if (!(type_ == operand_type::MAX || type_ == operand_type::ADD)) {
|
||||||
|
LOG(FATAL) << "Do not support this operand type: " << type_;
|
||||||
|
}
|
||||||
|
this->genCode();
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual const char* name() const {
|
||||||
|
std::string base = "VXXJitCode";
|
||||||
|
if (type_ == operand_type::MAX) {
|
||||||
|
base += "_MAX";
|
||||||
|
} else {
|
||||||
|
base += "_SUM";
|
||||||
|
}
|
||||||
|
return base.c_str();
|
||||||
|
}
|
||||||
|
void genCode() override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
template <typename JMM>
|
||||||
|
void process(JMM& dst, JMM& src1, JMM& src2) { // NOLINT
|
||||||
|
if (type_ == operand_type::MAX) {
|
||||||
|
vmaxps(dst, src1, src2);
|
||||||
|
} else if (type_ == operand_type::ADD) {
|
||||||
|
vaddps(dst, src1, src2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int num_;
|
||||||
|
operand_type type_;
|
||||||
|
reg64_t param_src{abi_param1};
|
||||||
|
reg64_t param_dst{abi_param2};
|
||||||
|
reg64_t param_attr{abi_param3};
|
||||||
|
|
||||||
|
ymm_t ymm_tmp = ymm_t(0);
|
||||||
|
ymm_t ymm_src = ymm_t(1);
|
||||||
|
ymm_t ymm_dst = ymm_t(2);
|
||||||
|
|
||||||
|
xmm_t xmm_tmp = xmm_t(0);
|
||||||
|
xmm_t xmm_src = xmm_t(1);
|
||||||
|
xmm_t xmm_dst = xmm_t(2);
|
||||||
|
};
|
||||||
|
|
||||||
|
#define DECLARE_HOP_JITCODE(name, op_type) \
|
||||||
|
class name##JitCode : public HOPVJitCode { \
|
||||||
|
public: \
|
||||||
|
explicit name##JitCode(int d, size_t code_size, void* code_ptr = nullptr) \
|
||||||
|
: HOPVJitCode(d, op_type, code_size, code_ptr) {} \
|
||||||
|
};
|
||||||
|
|
||||||
|
DECLARE_HOP_JITCODE(HMax, operand_type::MAX);
|
||||||
|
DECLARE_HOP_JITCODE(HSum, operand_type::ADD);
|
||||||
|
|
||||||
|
#undef DECLARE_HOP_JITCODE
|
||||||
|
|
||||||
|
} // namespace gen
|
||||||
|
} // namespace jit
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue