Example #1
0
// Round-trip test for the "proto" DataProvider.
//
// Generates reference data with prepareData(), writes it under kTestDir with
// writeData(), reads it back through a DataProvider created from a "proto"
// DataConfig, and compares the two sequence by sequence (sample by sample via
// checkSample()). kTestDir is created on entry and removed on exit.
//
// @param numPerSlotType   forwarded to prepareData(); presumably the number of
//                         slots to generate per slot type -- TODO confirm.
// @param iid              when true, every stream read back must have no
//                         sequenceStartPositions; when false, every stream
//                         must have them.
// @param async            value for DataConfig::async_load_data.
// @param useGpu           forwarded to prepareData(), writeData(),
//                         DataProvider::create() and checkSample().
// @param dataCompression  forwarded to writeData(); also selects
//                         kProtoFileListCompressed instead of kProtoFileList.
// @param numConstantSlots number of constant slots added to the config. Slot i
//                         has value i + 11, and a matching data.getSize() x 1
//                         CPU matrix filled with that value is appended to the
//                         reference data so both sides have the same streams.
void testProtoDataProvider(int* numPerSlotType, bool iid, bool async,
                           bool useGpu, bool dataCompression,
                           int numConstantSlots = 0) {
  mkDir(kTestDir);
  DataBatch data;

  prepareData(&data, numPerSlotType, iid, useGpu);
  writeData(data, useGpu, dataCompression);

  DataConfig config;
  config.set_type("proto");
  config.set_files(dataCompression ? kProtoFileListCompressed : kProtoFileList);
  config.set_async_load_data(async);

  for (int i = 0; i < numConstantSlots; ++i) {
    config.add_constant_slots(i + 11);
    MatrixPtr w = Matrix::create(data.getSize(), 1, /* trans= */ false,
                                 /* useGpu= */ false);
    w->assign(config.constant_slots(i));
    data.appendData(w);
  }

  unique_ptr<DataProvider> dataProvider(DataProvider::create(config, useGpu));
  // Presumably disables shuffling so batches come back in write order, which
  // the sequence-by-sequence comparison below relies on.
  dataProvider->setSkipShuffle();

  EXPECT_EQ(data.getSize(), dataProvider->getSize());

  const int64_t batchSize = 10;
  DataBatch batch;

  // Reference side: seq1 is a running sequence index across all batches and
  // must end exactly at numSeqs1 (checked after the loop).
  const size_t numSeqs1 = static_cast<size_t>(data.getNumSequences());
  size_t seq1 = 0;
  vector<Argument>& args1 = data.getStreams();
  ICpuGpuVectorPtr sequenceStartPositions1 =
      args1[0].sequenceStartPositions;

  dataProvider->reset();

  // Set when the provider yields more sequences than the reference holds;
  // comparing further would index past the end of args1.
  bool overran = false;
  while (!overran && dataProvider->getNextBatch(batchSize, &batch) > 0) {
    CHECK_EQ(data.getNumStreams(), batch.getNumStreams());
    vector<Argument>& args2 = batch.getStreams();
    ICpuGpuVectorPtr sequenceStartPositions2 =
        args2[0].sequenceStartPositions;
    for (auto& arg : args2) {
      EXPECT_EQ(iid, !arg.sequenceStartPositions);
    }
    size_t numSeqs = batch.getNumSequences();
    VLOG(1) << "numSeqs=" << numSeqs;
    for (size_t seq2 = 0; seq2 < numSeqs; ++seq1, ++seq2) {
      // Bound-check BEFORE touching the reference data (covers the iid path,
      // where seq1 is used directly as a sample index).
      EXPECT_LT(seq1, numSeqs1);
      if (seq1 >= numSeqs1) {
        overran = true;
        break;
      }

      int64_t begin1 = seq1;
      int64_t end1 = seq1 + 1;
      if (sequenceStartPositions1) {
        // Check before reading element seq1 + 1. Written as
        // `seq1 + 1 < size` so an empty vector cannot wrap `size - 1`.
        EXPECT_LT(seq1 + 1, sequenceStartPositions1->getSize());
        if (seq1 + 1 >= sequenceStartPositions1->getSize()) {
          overran = true;
          break;
        }
        begin1 = sequenceStartPositions1->getElement(seq1);
        end1 = sequenceStartPositions1->getElement(seq1 + 1);
      }

      int64_t begin2 = seq2;
      int64_t end2 = seq2 + 1;
      if (sequenceStartPositions2) {
        begin2 = sequenceStartPositions2->getElement(seq2);
        end2 = sequenceStartPositions2->getElement(seq2 + 1);
      }
      VLOG(1) << " begin1=" << begin1 << " end1=" << end1
              << " begin2=" << begin2 << " end2=" << end2;

      const int64_t len1 = end1 - begin1;
      const int64_t len2 = end2 - begin2;
      EXPECT_EQ(len1, len2);
      if (len1 != len2) {
        // Already reported; comparing samples would run past one side.
        continue;
      }
      for (int64_t i = 0; i < len1; ++i) {
        checkSample(args1, begin1 + i, args2, begin2 + i, useGpu);
      }
    }
  }

  EXPECT_EQ(seq1, numSeqs1);
  rmDir(kTestDir);
}