fix shape in floats

fix_recordio_link
seiriosPlus 6 years ago
parent 318ba99124
commit 06de824ba8

@ -22,9 +22,9 @@ class SplitSelectedRowsOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override { void Make() override {
AddInput("X", "The input SelectedRows."); AddInput("X", "The input SelectedRows.");
AddOutput("Out", "The outputs of the input SelectedRows.").AsDuplicable(); AddOutput("Out", "The outputs of the input SelectedRows.").AsDuplicable();
AddAttr<std::vector<int>>("height_sections", AddAttr<std::vector<int64_t>>("height_sections",
"Height for each output SelectedRows.") "Height for each output SelectedRows.")
.SetDefault(std::vector<int>({})); .SetDefault(std::vector<int64_t>({}));
AddComment(R"DOC( AddComment(R"DOC(
Split a SelectedRows with a specified rows section. Split a SelectedRows with a specified rows section.

@ -21,7 +21,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
static int FindOutIdx(int row, const std::vector<int>& abs_sections) { static int FindOutIdx(int row, const std::vector<int64_t>& abs_sections) {
for (size_t i = 1; i < abs_sections.size(); ++i) { for (size_t i = 1; i < abs_sections.size(); ++i) {
if (row < abs_sections[i]) { if (row < abs_sections[i]) {
return i - 1; return i - 1;
@ -30,9 +30,9 @@ static int FindOutIdx(int row, const std::vector<int>& abs_sections) {
return abs_sections.size() - 1; return abs_sections.size() - 1;
} }
static std::vector<int> ToAbsoluteSection( static std::vector<int64_t> ToAbsoluteSection(
const std::vector<int>& height_sections) { const std::vector<int64_t>& height_sections) {
std::vector<int> abs_sections; std::vector<int64_t> abs_sections;
abs_sections.resize(height_sections.size()); abs_sections.resize(height_sections.size());
abs_sections[0] = 0; abs_sections[0] = 0;
for (size_t i = 1; i < height_sections.size(); ++i) { for (size_t i = 1; i < height_sections.size(); ++i) {
@ -47,7 +47,7 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::SelectedRows>("X"); auto* x = ctx.Input<framework::SelectedRows>("X");
auto outs = ctx.MultiOutput<framework::SelectedRows>("Out"); auto outs = ctx.MultiOutput<framework::SelectedRows>("Out");
auto height_sections = ctx.Attr<std::vector<int>>("height_sections"); auto height_sections = ctx.Attr<std::vector<int64_t>>("height_sections");
auto abs_sections = ToAbsoluteSection(height_sections); auto abs_sections = ToAbsoluteSection(height_sections);

@ -48,7 +48,7 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> {
if (out_var->IsType<framework::LoDTensor>()) { if (out_var->IsType<framework::LoDTensor>()) {
tensor = out_var->GetMutable<framework::LoDTensor>(); tensor = out_var->GetMutable<framework::LoDTensor>();
} else if (out_var->IsType<framework::SelectedRows>()) { } else if (out_var->IsType<framework::SelectedRows>()) {
auto shape = context.Attr<std::vector<int>>("shape"); auto shape = context.Attr<std::vector<int64_t>>("shape");
tensor = out_var->GetMutable<framework::SelectedRows>()->mutable_value(); tensor = out_var->GetMutable<framework::SelectedRows>()->mutable_value();
tensor->Resize(framework::make_ddim(shape)); tensor->Resize(framework::make_ddim(shape));
} else { } else {

@ -57,6 +57,18 @@ struct variant_caster<V<Ts...>> {
auto caster = make_caster<T>(); auto caster = make_caster<T>();
if (!load_success_ && caster.load(src, convert)) { if (!load_success_ && caster.load(src, convert)) {
load_success_ = true; load_success_ = true;
if (std::is_same<T, std::vector<float>>::value) {
auto caster_ints = make_caster<std::vector<int64_t>>();
if (caster_ints.load(src, convert)) {
VLOG(4) << "This value are floats and int64_ts satisfy "
"simultaneously, will set it's type to "
"std::vector<int64_t>";
value = cast_op<std::vector<int64_t>>(caster_ints);
return true;
}
}
value = cast_op<T>(caster); value = cast_op<T>(caster);
return true; return true;
} }

Loading…
Cancel
Save