// Translate a warp-ctc status code into a Python exception.
// On failure, sets a Python RuntimeError combining the caller-supplied
// context `msg` with the library's own description of `retcode`.
// Returns 1 if an error was raised, 0 on success.
int ctc_check_result(ctcStatus_t retcode, const char * msg) {
  if (retcode == CTC_STATUS_SUCCESS) {
    return 0;
  }
  // Get error message from underlying library
  const char * ctc_msg = ctcGetStatusString(retcode);
  PyErr_Format(PyExc_RuntimeError,
               "ConnectionistTemporalClassification: %s CTC error: %s",
               msg, ctc_msg);
  return 1;
}
// CPU forward pass: flattens the label blobs into warp-ctc's expected
// layout, computes the CTC loss and per-timestep gradients, and writes the
// mean loss (over sequences with finite cost) into top[0].
//
// Supported bottom shapes:
//   2 blobs: bottom[0] = activations (T x N x C), bottom[1] = label
//            sequences (N x L); blank-valued entries are treated as padding
//            and stripped, and every sequence is assumed to span T_ steps.
//   3 blobs: activations + indicator/label blobs, decoded by
//            ExtractInputData.
//   4 blobs: activations + explicit sequence lengths + label lengths +
//            label sequences (time-major, indexed via offset(t, n)).
template <typename Dtype>
void WarpCTCLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
                                          const vector<Blob<Dtype>*>& top) {
  const Dtype* const activations = bottom[0]->cpu_data();
  Dtype* gradients = bottom[0]->mutable_cpu_diff();
  const int alphabet_size = C_;
  const int minibatch = N_;
  vector<Dtype> costs(N_);

  flat_labels_.clear();
  if (bottom.size() == 2) {
    // bottom[0] = activations, bottom[1] = labels, shape: batch x label_len.
    const Blob<Dtype>* label_seq_blob = bottom[1];
    const Dtype* label_seq_d = label_seq_blob->cpu_data();
    const int label_len_per_batch = label_seq_blob->channels();
    for (int n = 0; n < N_; ++n) {
      int curlen = 0;
      for (int l = 0; l < label_len_per_batch; ++l) {
        const int label =
            static_cast<int>(label_seq_d[n * label_len_per_batch + l]);
        // Blank entries are padding, not real labels — skip them.
        if (label == blank_index_) continue;
        flat_labels_.push_back(label);
        ++curlen;
      }
      label_lengths_[n] = curlen;
      // All sequences are assumed to use the full T_ timesteps here.
      input_lengths_[n] = T_;
    }
  } else if (bottom.size() == 3) {
    ExtractInputData(bottom[1], bottom[2],
                     &flat_labels_, &label_lengths_, &input_lengths_);
  } else if (bottom.size() == 4) {
    const Blob<Dtype>* seq_len_blob = bottom[1];
    const Blob<Dtype>* lab_len_blob = bottom[2];
    const Blob<Dtype>* label_seq_blob = bottom[3];
    const Dtype* seq_len_d = seq_len_blob->cpu_data();
    const Dtype* lab_len_d = lab_len_blob->cpu_data();
    const Dtype* label_seq_d = label_seq_blob->cpu_data();
    int accumulated = 0;
    CHECK_EQ(seq_len_blob->count(), lab_len_blob->count());
    for (int i = 0; i < seq_len_blob->count(); ++i) {
      label_lengths_[i] = lab_len_d[i];
      input_lengths_[i] = seq_len_d[i];
      accumulated += lab_len_d[i];
    }
    flat_labels_.clear();
    flat_labels_.reserve(accumulated);
    // Label sequences are stored time-major: offset(t, n).
    for (int n = 0; n < N_; ++n) {
      for (int t = 0; t < label_lengths_[n]; ++t) {
        flat_labels_.push_back(label_seq_d[label_seq_blob->offset(t, n)]);
      }
    }
  } else {
    LOG(FATAL) << "Unsupported blobs shape";
  }

  // Ask warp-ctc how much scratch memory this batch needs, then grow the
  // cached workspace only when the current one is too small.
  size_t workspace_alloc_bytes;  // local, not a member — no trailing '_'
  ctcOptions options;
  options.loc = CTC_CPU;
  options.num_threads = 8;  // NOTE(review): hard-coded thread count
  options.blank_label = blank_index_;
  ctcStatus_t status = get_workspace_size<Dtype>(label_lengths_.data(),
                                                 input_lengths_.data(),
                                                 alphabet_size,
                                                 minibatch,
                                                 options,
                                                 &workspace_alloc_bytes);
  CHECK_EQ(status, CTC_STATUS_SUCCESS)
      << "CTC Error: " << ctcGetStatusString(status);
  if (!workspace_ || workspace_->size() < workspace_alloc_bytes) {
    workspace_.reset(new SyncedMemory(workspace_alloc_bytes));
  }

  status = compute_ctc_loss_cpu(activations,
                                gradients,
                                flat_labels_.data(),
                                label_lengths_.data(),
                                input_lengths_.data(),
                                alphabet_size,
                                minibatch,
                                costs.data(),
                                workspace_->mutable_cpu_data(),
                                options);
  CHECK_EQ(status, CTC_STATUS_SUCCESS)
      << "CTC Error: " << ctcGetStatusString(status);

  // Output the mean loss over sequences with a finite cost (warp-ctc
  // reports infinity for sequences it cannot align).
  Dtype& loss = top[0]->mutable_cpu_data()[0];
  loss = 0;
  int num = 0;
  for (int n = 0; n < N_; ++n) {
    if (costs[n] < std::numeric_limits<Dtype>::infinity()) {
      loss += costs[n];
      ++num;
    }
  }
  // Guard the degenerate case where every cost is infinite: the original
  // code divided by zero here, producing a NaN loss.
  if (num > 0) {
    loss /= num;
  }

  // Debug aid — mean |gradient| over the batch. Kept fully commented out so
  // the O(count) pass is not paid on every forward call.
  // int gcnt = bottom[0]->count();
  // Dtype sumg = 0;
  // for (int i = 0; i < gcnt; ++i) sumg += std::abs(gradients[i]);
  // LOG(INFO) << "mean ctc loss=" << loss << ",N_=" << N_ << ",num=" << num
  //           << ", mean gradients=" << sumg / gcnt;
}