Example #1
// Helper for the Python binding: turn a failed warp-ctc status code into a
// Python RuntimeError and report the failure to the caller.
int ctc_check_result(ctcStatus_t retcode, const char * msg)
{
    if( CTC_STATUS_SUCCESS != retcode )
    {
        // Get error message from underlying library
        const char * ctc_msg = ctcGetStatusString( retcode );

        PyErr_Format( PyExc_RuntimeError,
                      "ConnectionistTemporalClassification: %s CTC error: %s",
                      msg,
                      ctc_msg );
        return 1;
    }
    return 0;
}
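
A hypothetical call site for the helper, shown as a sketch only: it assumes a CPython binding whose enclosing function returns NULL to signal the exception, and the variable names are placeholders mirroring warp-ctc's compute_ctc_loss arguments.

    ctcStatus_t status = compute_ctc_loss(activations, gradients,
                                          flat_labels, label_lengths,
                                          input_lengths, alphabet_size,
                                          minibatch, costs,
                                          workspace, options);
    if (ctc_check_result(status, "forward pass"))
        return NULL;  // the RuntimeError is already set for the interpreter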

template <typename Dtype>
void WarpCTCLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
                                          const vector<Blob<Dtype>*>& top) {
    const Dtype* const activations = bottom[0]->cpu_data();
    Dtype* gradients = bottom[0]->mutable_cpu_diff();
    const int alphabet_size = C_;
    const int minibatch = N_;
    vector<Dtype> costs(N_);

    flat_labels_.clear();
    if (bottom.size() == 2) {
      // bottom[0] = activations; bottom[1] = label sequences, shaped
      // (batch size) x (max label length) and padded with blanks.
      const Blob<Dtype>* label_seq_blob = bottom[1];
      const Dtype* label_seq_d = label_seq_blob->cpu_data();
      int label_len_per_batch = label_seq_blob->channels();

      for (int n = 0; n < N_; ++n) {
        int curlen = 0;
        for (int l = 0; l < label_len_per_batch; ++l) {
          int label = label_seq_d[n * label_len_per_batch + l];
          // Skip blank padding: only real labels go into the flat buffer.
          if (label == blank_index_)
            continue;
          flat_labels_.push_back(label);
          curlen++;
        }
        label_lengths_[n] = curlen;
        input_lengths_[n] = T_;  // every sequence uses all T_ time steps
      }
    }
    else if (bottom.size() == 3) {
      ExtractInputData(bottom[1], bottom[2],
          &flat_labels_, &label_lengths_, &input_lengths_);
    } else if (bottom.size() == 4) {
      // bottom[1] = per-sequence input lengths, bottom[2] = label lengths,
      // bottom[3] = label sequences stored time-major (indexed as (t, n)).
      const Blob<Dtype>* seq_len_blob = bottom[1];
      const Blob<Dtype>* lab_len_blob = bottom[2];
      const Blob<Dtype>* label_seq_blob = bottom[3];

      const Dtype *seq_len_d = seq_len_blob->cpu_data();
      const Dtype *lab_len_d = lab_len_blob->cpu_data();
      const Dtype *label_seq_d = label_seq_blob->cpu_data();

      int accumulated = 0;
      CHECK_EQ(seq_len_blob->count(), lab_len_blob->count());
      for (int i = 0; i < seq_len_blob->count(); ++i) {
        label_lengths_[i] = lab_len_d[i];
        input_lengths_[i] = seq_len_d[i];
        accumulated += lab_len_d[i];
      }

      flat_labels_.clear();
      flat_labels_.reserve(accumulated);
      for (int n = 0; n < N_; ++n) {
        for (int t = 0; t < label_lengths_[n]; ++t) {
          flat_labels_.push_back(label_seq_d[label_seq_blob->offset(t, n)]);
        }
      }
    } else {
      LOG(FATAL) << "Unsupported number of bottom blobs: " << bottom.size();
    }
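
    // Illustrative example (not executed): with N_ = 2, blank_index_ = 0 and
    // padded label rows {3, 0, 4} and {2, 0, 0}, the flattening above yields
    //   flat_labels_   = {3, 4, 2}
    //   label_lengths_ = {2, 1}
    // which is the packed label format warp-ctc consumes.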

    size_t workspace_alloc_bytes;

    ctcOptions options;
    options.loc = CTC_CPU;
    options.num_threads = 8;  // CPU worker threads used by warp-ctc
    options.blank_label = blank_index_;

    // Ask warp-ctc how much scratch memory this batch requires.
    ctcStatus_t status = get_workspace_size<Dtype>(label_lengths_.data(),
                                                   input_lengths_.data(),
                                                   alphabet_size,
                                                   minibatch,
                                                   options,
                                                   &workspace_alloc_bytes);
    CHECK_EQ(status, CTC_STATUS_SUCCESS)
        << "CTC Error: " << ctcGetStatusString(status);

    // Grow the workspace lazily and reuse it across forward passes.
    if (!workspace_ || workspace_->size() < workspace_alloc_bytes) {
      workspace_.reset(new SyncedMemory(workspace_alloc_bytes));
    }

    // Run the CPU CTC forward/backward pass: per-sequence losses land in
    // costs, and gradients are written directly into bottom[0]'s diff.
    status = compute_ctc_loss_cpu(activations,
                                  gradients,
                                  flat_labels_.data(),
                                  label_lengths_.data(),
                                  input_lengths_.data(),
                                  alphabet_size,
                                  minibatch,
                                  costs.data(),
                                  workspace_->mutable_cpu_data(),
                                  options);
    CHECK_EQ(status, CTC_STATUS_SUCCESS)
        << "CTC Error: " << ctcGetStatusString(status);

    // Output the batch loss: average the per-sequence costs, skipping
    // sequences whose cost came back infinite (no feasible alignment).
    Dtype& loss = top[0]->mutable_cpu_data()[0];
    loss = 0;
    int num = 0;
    for (int n = 0; n < N_; ++n) {
      if (costs[n] < std::numeric_limits<Dtype>::infinity()) {
        loss += costs[n];
        ++num;
      }
    }
    if (num > 0) {
      loss /= num;  // guard: avoid 0/0 when every cost is infinite
    }

    // Debug instrumentation (disabled): report the mean loss and the mean
    // absolute gradient for the batch.
    // int gcnt = bottom[0]->count();
    // Dtype sumg = 0;
    // for (int i = 0; i < gcnt; ++i) {
    //   sumg += std::fabs(gradients[i]);
    // }
    // LOG(INFO) << "mean ctc loss=" << loss << ", N_=" << N_ << ", num=" << num
    //           << ", mean gradients=" << sumg / gcnt;
}
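
Since warp-ctc computes the gradients during the forward call (they are written into bottom[0]'s diff above), the matching backward pass typically only rescales them by the loss weight. A minimal sketch under that assumption, using Caffe's standard Backward_cpu signature and its caffe_scal BLAS helper; whether the diff should also be divided by num, as the loss is, depends on the actual layer and is not shown here.

template <typename Dtype>
void WarpCTCLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
                                           const vector<bool>& propagate_down,
                                           const vector<Blob<Dtype>*>& bottom) {
  if (propagate_down[0]) {
    // The CTC gradients already sit in bottom[0]'s diff from Forward_cpu;
    // scale them by the loss weight the solver placed in top[0]'s diff.
    const Dtype loss_weight = top[0]->cpu_diff()[0];
    caffe_scal(bottom[0]->count(), loss_weight,
               bottom[0]->mutable_cpu_diff());
  }
}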