Get BeamSearch Prob on C-API

guochaorong-patch-1
yuyang18 7 years ago
parent f92024470b
commit 08fa3983b0
No known key found for this signature in database
GPG Key ID: 6DFF29878217BE5F

@ -66,6 +66,17 @@ paddle_error paddle_arguments_get_value(paddle_arguments args,
return kPD_NO_ERROR;
}
PD_API paddle_error paddle_arguments_get_prob(paddle_arguments args,
uint64_t ID,
paddle_matrix mat) {
if (args == nullptr || mat == nullptr) return kPD_NULLPTR;
auto m = paddle::capi::cast<paddle::capi::CMatrix>(mat);
auto a = castArg(args);
if (ID >= a->args.size()) return kPD_OUT_OF_RANGE;
m->mat = a->args[ID].in;
return kPD_NO_ERROR;
}
paddle_error paddle_arguments_get_ids(paddle_arguments args,
uint64_t ID,
paddle_ivector ids) {

@ -87,6 +87,18 @@ PD_API paddle_error paddle_arguments_get_value(paddle_arguments args,
uint64_t ID,
paddle_matrix mat);
/**
* @brief paddle_arguments_get_prob Get the prob matrix of beam search, which
* slot ID is `ID`
* @param [in] args arguments array
* @param [in] ID array index
* @param [out] mat matrix pointer
* @return paddle_error
*/
PD_API paddle_error paddle_arguments_get_prob(paddle_arguments args,
uint64_t ID,
paddle_matrix mat);
/**
* @brief PDArgsGetIds Get the integer vector of one argument in array, which
* index is `ID`.

Loading…
Cancel
Save