@@ -44,25 +44,76 @@ class CtcLossGpuKernel : public GpuKernel {
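  // Launch appears to run the CTC loss computation in three stages: LaunchInit caches device
  // pointers and dimensions, LaunchFirstHalf does the softmax and label preprocessing, and
  // LaunchSecondHalf runs the forward-backward loss/gradient kernels.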
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    LaunchInit(inputs, workspace, outputs);
    LaunchFirstHalf(inputs, workspace, outputs, stream_ptr);
    LaunchSecondHalf(inputs, workspace, outputs, stream_ptr);
    return true;
  }
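
  // Init appears to cache the shapes of the SparseTensor-style inputs: probs is expected to be
  // (max_time, batch, num_classes), label_indices a 2-D int64 index matrix, label_values a 1-D
  // int vector, and sequence_length one entry per batch element.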
  bool Init(const CNodePtr &kernel_node) override {
    InitResource();
    auto probs_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    if (probs_shape.size() != 3) {
      MS_LOG(EXCEPTION) << "probs dims: " << probs_shape.size() << " not support.";
    }
    probs_dims_[0] = probs_shape[0];
    probs_dims_[1] = probs_shape[1];
    probs_dims_[2] = probs_shape[2];
    auto indice_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
    auto labels_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
    if (labels_dims.size() != 1) {
      MS_LOG(EXCEPTION) << "labels dims: " << labels_dims.size() << " not support.";
    }
    if (indice_dims.size() != 2) {
      MS_LOG(EXCEPTION) << "labels indice dims: " << indice_dims.size() << " not support.";
    }
    label_size_ = sizeof(int);
    for (auto i : labels_dims) {
      label_size_ *= i;
    }
    label_indice_size_ = sizeof(int64_t);
    for (auto i : indice_dims) {
      label_indice_size_ *= i;
    }
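    // label_size_ and label_indice_size_ appear to be byte sizes (element count times element
    // width) that InitSizeLists() later reports to the framework for buffer allocation.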
    auto squence_length_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3);
    squence_lengths_size_ = squence_length_dims[0] * sizeof(int);
    preprocess_collapse_repeated_ = GetAttr<bool>(kernel_node, "preprocess_collapse_repeated");
    ctc_merge_repeated_ = GetAttr<bool>(kernel_node, "ctc_merge_repeated");
    ignore_longer_outputs_than_inputs_ = GetAttr<bool>(kernel_node, "ignore_longer_outputs_than_inputs");
    InitSizeLists();
    return true;
  }

 protected:
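  // LaunchInit appears to stash the raw device addresses and the (max_time, batch, num_classes)
  // dimensions in members so the two launch halves can share them.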
  void LaunchInit(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                  const std::vector<AddressPtr> &outputs) {
    probs = GetDeviceAddress<T>(inputs, 0);
    label_indices = GetDeviceAddress<int64_t>(inputs, 1);
    label_values = GetDeviceAddress<int>(inputs, 2);
    sequence_length = GetDeviceAddress<int>(inputs, 3);
    costs = GetDeviceAddress<T>(outputs, 0);
    grads = GetDeviceAddress<T>(outputs, 1);
    softmax_probs = GetDeviceAddress<T>(workspace, 0);
    cum_labels_length = GetDeviceAddress<int>(workspace, 1);
    label_squence_length = GetDeviceAddress<int>(workspace, 2);
    label_value_sp = GetDeviceAddress<int>(workspace, 3);
    label_value_pcr = GetDeviceAddress<int>(workspace, 4);
    prob_num = GetDeviceAddress<T>(workspace, 5);
    precum_labels_length = GetDeviceAddress<int>(workspace, 6);
    max_labels_length = GetDeviceAddress<int>(workspace, 7);
    numclass = SizeToInt(probs_dims_[2]);
    batch = SizeToInt(probs_dims_[1]);
    max_time = SizeToInt(probs_dims_[0]);
    max_sequence = 0;
    max_labels_length_host = 0;
    batch_label = 0;
    label_value_with_blank = nullptr;
    log_alpha_b = nullptr;
    log_beta_b = nullptr;
  }
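
  // LaunchFirstHalf appears to normalize the logits with a per-timestep softmax and gather the
  // per-batch label lengths/offsets; the device-to-host copies of max_labels_length force a stream
  // sync so the host knows the longest label before LaunchSecondHalf sizes its temporaries.
  // (The locals below shadow the members cached in LaunchInit; the values are identical.)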
  void LaunchFirstHalf(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                       const std::vector<AddressPtr> &outputs, void *stream_ptr) {
    cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr);
    const T *probs = GetDeviceAddress<T>(inputs, 0);
    const int64_t *label_indices = GetDeviceAddress<int64_t>(inputs, 1);
    const int *label_values = GetDeviceAddress<int>(inputs, 2);
    const int *sequence_length = GetDeviceAddress<int>(inputs, 3);
    T *costs = GetDeviceAddress<T>(outputs, 0);
    T *grads = GetDeviceAddress<T>(outputs, 1);
    T *softmax_probs = GetDeviceAddress<T>(workspace, 0);
    int *cum_labels_length = GetDeviceAddress<int>(workspace, 1);
    int *label_squence_length = GetDeviceAddress<int>(workspace, 2);
    int *label_value_sp = GetDeviceAddress<int>(workspace, 3);
    int *label_value_pcr = GetDeviceAddress<int>(workspace, 4);
    T *prob_num = GetDeviceAddress<T>(workspace, 5);
    int *precum_labels_length = GetDeviceAddress<int>(workspace, 6);
    int *max_labels_length = GetDeviceAddress<int>(workspace, 7);
    int numclass = SizeToInt(probs_dims_[2]);
    int batch = SizeToInt(probs_dims_[1]);
    int max_time = SizeToInt(probs_dims_[0]);
    int max_sequence = 0;
    CalculateMaxSequence(sequence_length, max_labels_length, batch, stream);
    CHECK_CUDA_RET_WITH_EXCEPT(
      cudaMemcpyAsync(&max_sequence, max_labels_length, sizeof(int), cudaMemcpyDeviceToHost, stream),
@@ -73,11 +124,7 @@ class CtcLossGpuKernel : public GpuKernel {
    }
    InnerSoftMax(probs, softmax_probs, sequence_length, max_time, batch, numclass, stream);
    MemsetForWS(label_value_pcr, cum_labels_length, label_squence_length, costs, grads, stream);
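    // InnerSoftMax presumably normalizes the logits per (timestep, batch) slice, and MemsetForWS
    // clears the workspace buffers and the costs/grads outputs before the CTC kernels accumulate
    // into them; both are CUDA helpers implemented outside this excerpt.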
    max_labels_length_host = 0;
    batch_label = 0;
    label_value_with_blank = nullptr;
    log_alpha_b = nullptr;
    log_beta_b = nullptr;

    CalculatePreLength(label_squence_length, precum_labels_length, cum_labels_length, max_labels_length, label_indices,
                       batch, label_size_ / sizeof(int), stream);
    CHECK_CUDA_RET_WITH_EXCEPT(
@@ -97,8 +144,14 @@ class CtcLossGpuKernel : public GpuKernel {
      cudaMemcpyAsync(&max_labels_length_host, max_labels_length, sizeof(int), cudaMemcpyDeviceToHost, stream),
      "cudaMemcpyAsync failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");
  }
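
  // LaunchSecondHalf appears to size the blank-augmented label layout from the longest label found
  // above, run the CTC forward-backward loss/gradient kernels (the allocations and kernel calls sit
  // in the elided hunk), and release the temporaries with FreeMem.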
  void LaunchSecondHalf(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                        const std::vector<AddressPtr> &outputs, void *stream_ptr) {
    cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr);
    int SOffSet = 2 * max_labels_length_host + 1;
    int log_prob_size = batch * SOffSet * max_time;
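    // SOffSet is presumably the length of the blank-augmented label sequence used by CTC
    // (blank, l1, blank, ..., lL, blank), i.e. 2 * L + 1 for the longest label L, so log_prob_size
    // covers one SOffSet x max_time trellis per batch element. The check below rejects labels
    // longer than the available timesteps, since such labels have no valid CTC alignment unless
    // ignore_longer_outputs_than_inputs_ is set.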

    if (!ignore_longer_outputs_than_inputs_ && max_labels_length_host > max_time) {
      MS_LOG(EXCEPTION) << "output size is greater than input size.";
    }
@@ -124,43 +177,8 @@ class CtcLossGpuKernel : public GpuKernel {
                    ignore_longer_outputs_than_inputs_, stream);
    CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");
    FreeMem(label_value_with_blank, log_alpha_b, log_beta_b);
  }

 protected:
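  // The size lists appear to tell the GPU kernel framework how many bytes to allocate for each
  // input, output and workspace buffer; the remaining entries fall in the elided hunk below.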
  void InitSizeLists() override {
    input_size_list_.push_back(probs_dims_[0] * probs_dims_[1] * probs_dims_[2] * sizeof(T));
    input_size_list_.push_back(label_indice_size_);
@@ -226,6 +244,31 @@ class CtcLossGpuKernel : public GpuKernel {
  bool ctc_merge_repeated_;
  bool ignore_longer_outputs_than_inputs_;
  T kLogZero_ = -std::numeric_limits<T>::infinity();
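  // kLogZero_ stands in for log(0) in the log-space computation; -infinity keeps sums with it
  // saturated at log(0).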

  // Heap parameters: device pointers and launch-time scalars cached by LaunchInit and shared
  // across the launch stages.
  T *probs;
  int64_t *label_indices;
  int *label_values;
  int *sequence_length;
  T *costs;
  T *grads;
  T *softmax_probs;
  int *cum_labels_length;
  int *label_squence_length;
  int *label_value_sp;
  int *label_value_pcr;
  T *prob_num;
  int *precum_labels_length;
  int *max_labels_length;
  int numclass;
  int batch;
  int max_time;
  int max_sequence;
  int max_labels_length_host;
  int batch_label;
  int *label_value_with_blank;
  T *log_alpha_b;
  T *log_beta_b;
};  // class CtcLossGpuKernel
}  // namespace kernel
}  // namespace mindspore