|
|
@ -35,7 +35,7 @@ protected:
|
|
|
|
PADDLE_ENFORCE(inputs[0]->dims().size() == 2, "X's dimension must be 2.");
|
|
|
|
PADDLE_ENFORCE(inputs[0]->dims().size() == 2, "X's dimension must be 2.");
|
|
|
|
PADDLE_ENFORCE(outputs[0]->dims().size() == 1,
|
|
|
|
PADDLE_ENFORCE(outputs[0]->dims().size() == 1,
|
|
|
|
"label's dimension must be 1.");
|
|
|
|
"label's dimension must be 1.");
|
|
|
|
outputs[0]->set_dims(framework::make_ddim({inputs[0]->dims()[0]}));
|
|
|
|
outputs[0]->Resize(framework::make_ddim({inputs[0]->dims()[0]}));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|