@@ -14,6 +14,7 @@
 
 #include "paddle/fluid/imperative/amp_auto_cast.h"
 
+#include <algorithm>
 #include <memory>
 #include <string>
 #include <utility>
@@ -35,14 +36,29 @@ AmpOperators& AmpOperators::Instance() {
   return instance;
 }
 
-std::shared_ptr<std::unordered_set<std::string>> AmpOperators::GetAllowOps() {
+std::shared_ptr<std::unordered_set<std::string>>
+AmpOperators::GetMutableAllowOps() {
   return allow_ops_;
 }
 
-std::shared_ptr<std::unordered_set<std::string>> AmpOperators::GetBlockOps() {
+std::shared_ptr<std::unordered_set<std::string>>
+AmpOperators::GetMutableBlockOps() {
   return block_ops_;
 }
 
+std::ostream& operator<<(std::ostream& os, AmpOperators& ops) {
+  os << "allow ops: ";
+  auto allow_ops = ops.GetMutableAllowOps();
+  std::copy((*allow_ops).begin(), (*allow_ops).end(),
+            std::ostream_iterator<std::string>(os, " "));
+  os << "; ";
+  os << "block ops: ";
+  auto block_ops = ops.GetMutableBlockOps();
+  std::copy((*block_ops).begin(), (*block_ops).end(),
+            std::ostream_iterator<std::string>(os, " "));
+  return os;
+}
+
 inline std::string GetDtypeStr(
     const std::shared_ptr<imperative::VarBase>& var) {
   return framework::DataTypeToString(var->DataType());
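The operator<< added in the hunk above serializes the two op sets by streaming every element through a std::ostream_iterator with a space separator. The following standalone sketch shows only that printing pattern; the set contents are made-up op names for illustration and the code is not part of the patch.

#include <algorithm>
#include <iostream>
#include <iterator>
#include <string>
#include <unordered_set>

int main() {
  // Stand-in for the allow list held by AmpOperators; names are illustrative only.
  std::unordered_set<std::string> allow_ops{"conv2d", "matmul", "mul"};
  std::cout << "allow ops: ";
  // std::copy into an ostream_iterator writes each element followed by " ",
  // the same mechanism the new operator<< uses for allow_ops_ and block_ops_.
  std::copy(allow_ops.begin(), allow_ops.end(),
            std::ostream_iterator<std::string>(std::cout, " "));
  std::cout << std::endl;  // element order is unspecified for an unordered_set
  return 0;
}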
@@ -115,51 +131,50 @@ static inline framework::proto::VarType::Type GetPromoteType(
 
 NameVarBaseMap AutoCastInputs(const std::string& op_type,
                               const NameVarBaseMap& ins) {
-  NameVarBaseMap new_ins = {};
-  if (AmpOperators::Instance().GetAllowOps()->count(op_type)) {
-    for (const auto& pair : ins) {
+  NameVarBaseMap new_ins(ins);
+  if (AmpOperators::Instance().GetMutableAllowOps()->count(op_type)) {
+    for (auto& pair : new_ins) {
+      // NOTE(zhiqiu): batch_norm and layer_norm support only input x is fp16.
+      if ((op_type == "batch_norm" || op_type == "layer_norm") &&
+          pair.first != "X") {
+        continue;
+      }
       VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
               << GetDtypeStr(*pair.second.cbegin()) << " to float16";
-      for (const auto& var : pair.second) {
-        auto new_var = CastToFP16(var);
-        new_ins[pair.first].emplace_back(new_var);
+      for (auto& var : pair.second) {
+        var = CastToFP16(var);
       }
     }
     return new_ins;
-  } else if (AmpOperators::Instance().GetBlockOps()->count(op_type)) {
-    for (const auto& pair : ins) {
+  } else if (AmpOperators::Instance().GetMutableBlockOps()->count(op_type)) {
+    for (auto& pair : new_ins) {
       VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
               << GetDtypeStr(*pair.second.cbegin()) << " to float";
-      for (const auto& var : pair.second) {
-        auto new_var = CastToFP32(var);
-        new_ins[pair.first].emplace_back(new_var);
+      for (auto& var : pair.second) {
+        var = CastToFP32(var);
       }
     }
     return new_ins;
   } else {
     auto dst_type = GetPromoteType(ins);
 
-    for (const auto& pair : ins) {
+    for (auto& pair : new_ins) {
+      // NOTE(zhiqiu): batch_norm and layer_norm support only input x is fp16.
+      if ((op_type == "batch_norm" || op_type == "layer_norm") &&
+          pair.first == "X" && dst_type == framework::proto::VarType::FP32) {
+        continue;
+      }
       VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
               << GetDtypeStr(*pair.second.cbegin()) << " to "
               << framework::DataTypeToString(dst_type);
-      for (const auto& var : pair.second) {
-        // NOTE(zhiqiu): Conv + BN always occur together, we needn't
-        // cast X of batch_norm to FP32, which is produced by conv as FP16 type.
-        if (op_type == "batch_norm" && pair.first == "X" &&
-            dst_type == framework::proto::VarType::FP32) {
-          new_ins[pair.first].emplace_back(var);
-          continue;
-        }
-        auto new_var = dst_type == framework::proto::VarType::FP32
-                           ? CastToFP32(var)
-                           : CastToFP16(var);
-        new_ins[pair.first].emplace_back(new_var);
+      for (auto& var : pair.second) {
+        var = (dst_type == framework::proto::VarType::FP32 ? CastToFP32(var)
+                                                           : CastToFP16(var));
       }
     }
     return new_ins;
   }
-  return ins;
+  return new_ins;
 }
 
 }  // namespace imperative
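The rewritten AutoCastInputs above copies the whole input map once (NameVarBaseMap new_ins(ins)) and then rewrites each variable slot in place, instead of building new_ins entry by entry with emplace_back; since the copy only duplicates shared_ptr handles, inputs that need no cast keep pointing at the original variables. The sketch below mirrors that copy-then-mutate pattern; the toy Var type, the dtype strings, and the CastToFP16 stand-in are assumptions for illustration, not Paddle's real VarBase or cast helpers.

#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Toy variable type standing in for imperative::VarBase (an assumption, not Paddle code).
struct Var {
  std::string name;
  std::string dtype;  // "fp32" or "fp16" in this toy model
};
using VarPtr = std::shared_ptr<Var>;
using NameVarMap = std::map<std::string, std::vector<VarPtr>>;

// Stand-in for CastToFP16: returns a fresh variable holding the casted value,
// leaving the original untouched (a cast op produces a new output variable).
VarPtr CastToFP16(const VarPtr& var) {
  return std::make_shared<Var>(Var{var->name + ".cast_fp16", "fp16"});
}

NameVarMap AutoCastInputsSketch(const NameVarMap& ins) {
  NameVarMap new_ins(ins);      // one copy of the whole map up front
  for (auto& pair : new_ins) {  // then mutate the copied slots in place
    for (auto& var : pair.second) {
      if (var->dtype == "fp32") var = CastToFP16(var);
    }
  }
  return new_ins;  // the caller's `ins` still points at the fp32 variables
}

int main() {
  NameVarMap ins{{"X", {std::make_shared<Var>(Var{"x", "fp32"})}}};
  NameVarMap outs = AutoCastInputsSketch(ins);
  std::cout << outs["X"][0]->name << ": " << outs["X"][0]->dtype << "\n";  // x.cast_fp16: fp16
  std::cout << ins["X"][0]->name << ": " << ins["X"][0]->dtype << "\n";    // x: fp32
  return 0;
}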