Example No. 1
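// Test helper: extracts row `pos` from a sparse-matrix Argument, returning the
// number of entries in the row (`colNum`), their column indices (`rowCols`)
// and, for VECTOR_SPARSE_VALUE slots, their values (`rowValues`). Dispatches
// to the GPU or CPU sparse matrix depending on `useGpu`.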
void getColRow(const Argument& arg, int64_t pos, bool useGpu, int* colNum,
               const int** rowCols, const real** rowValues) {
  SlotDef::SlotType type = getSlotType(arg);
  GpuSparseMatrixPtr matGpu;
  CpuSparseMatrixPtr matCpu;
  if (useGpu) {
    matGpu = dynamic_pointer_cast<GpuSparseMatrix>(arg.value);
    ASSERT_TRUE(matGpu != NULL);
  } else {
    matCpu = dynamic_pointer_cast<CpuSparseMatrix>(arg.value);
    ASSERT_TRUE(matCpu != NULL);
  }
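  // The column count and indices are available for both sparse formats;
  // explicit values only exist for VECTOR_SPARSE_VALUE slots.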
  *colNum = useGpu ? matGpu->getColNum(pos) : matCpu->getColNum(pos);
  *rowCols = useGpu ? matGpu->getRowCols(pos) : matCpu->getRowCols(pos);
  if (type == SlotDef::VECTOR_SPARSE_VALUE) {
    *rowValues = useGpu ? matGpu->getRowValues(pos) : matCpu->getRowValues(pos);
  } else {
    *rowValues = NULL;
  }
}
Example No. 2
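  // Builds the per-batch sample list on the CPU: one positive sample per label
  // of each instance, followed by config_.num_neg_samples() randomly drawn
  // negatives, all weighted by the instance's optional weight.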
  void prepareSamples() {
    CHECK(!useGpu_) << "GPU is not supported";

    int batchSize = getInput(*labelLayer_).getBatchSize();
    IVectorPtr label = getInput(*labelLayer_).ids;

    CpuSparseMatrixPtr multiLabel = std::dynamic_pointer_cast<CpuSparseMatrix>(
        getInput(*labelLayer_).value);

    CHECK(label || multiLabel)
        << "The label layer must have ids or NonValueSparseMatrix value";

    auto& randEngine = ThreadLocalRandomEngine::get();

    samples_.clear();
    samples_.reserve(batchSize * (1 + config_.num_neg_samples()));

    real* weight =
        weightLayer_ ? getInputValue(*weightLayer_)->getData() : nullptr;

    for (int i = 0; i < batchSize; ++i) {
      real w = weight ? weight[i] : 1;
      if (label) {
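        // Single-label input: one positive sample holding the instance's id.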
        int* ids = label->getData();
        samples_.push_back({i, ids[i], true, w});
      } else {
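        // Multi-label input: one positive sample per non-zero column of the
        // sparse label row.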
        const int* cols = multiLabel->getRowCols(i);
        int n = multiLabel->getColNum(i);
        for (int j = 0; j < n; ++j) {
          samples_.push_back({i, cols[j], true, w});
        }
      }
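      // Negative samples: draw ids from the custom sampler if one is
      // configured, otherwise fall back to rand_.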
      for (int j = 0; j < config_.num_neg_samples(); ++j) {
        int id = sampler_ ? sampler_->gen(randEngine) : rand_(randEngine);
        samples_.push_back({i, id, false, w});
      }
    }
    prepared_ = true;
  }