Example no. 1
// Backward entry point for embedding_bag: validates that the index tensors are
// int64 (and that offset2bag is contiguous), then dispatches to the sparse or
// dense kernel depending on `sparse`.
Tensor embedding_bag_backward(const Tensor &grad_, const Tensor &indices__,
                              const Tensor &offsets__,
                              const Tensor &offset2bag__,
                              const Tensor &bag_size_,
                              const Tensor &max_indices_,
                              int64_t num_weights,
                              bool scale_grad_by_freq, int64_t mode,
                              bool sparse) {
  auto indices_arg = TensorArg(indices__, "indices__", 1);
  checkScalarType("embedding_bag", indices_arg, kLong);
  auto offsets_arg = TensorArg(offsets__, "offsets__", 1);
  checkScalarType("embedding_bag", offsets_arg, kLong);
  auto offset2bag_arg = TensorArg(offset2bag__, "offset2bag__", 1);
  checkScalarType("embedding_bag", offset2bag_arg, kLong);
  checkContiguous("embedding_bag", offset2bag_arg);
  Tensor indices = indices__.contiguous();
  Tensor offsets = offsets__.contiguous();

  if (sparse) {
    return at::embedding_bag_sparse_backward(
        grad_, indices, offsets, offset2bag__, bag_size_, num_weights,
        scale_grad_by_freq, mode);
  } else {
    return at::embedding_bag_dense_backward(
        grad_, indices, offsets, offset2bag__, bag_size_, max_indices_, num_weights,
        scale_grad_by_freq, mode);
  }
}
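
Both branches consume the offset2bag__ tensor that the checks above validate. For reference, here is a minimal standalone sketch (plain std::vector, no ATen; the helper make_offset2bag_demo is made up for illustration) of the convention the backward kernels rely on, as read off from Example 3 below: offset2bag[i] names the bag that indices[i] belongs to, with offsets holding each bag's start position.

// Illustration only, not part of ATen: builds offset2bag from bag start
// positions, i.e. offset2bag[i] = bag containing position i of indices.
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> make_offset2bag_demo(const std::vector<int64_t> &offsets,
                                          int64_t num_indices) {
  std::vector<int64_t> offset2bag(num_indices, 0);
  for (int64_t bag = 1; bag < (int64_t)offsets.size(); bag++) {
    // Positions from this bag's start onward get this bag id; later bags
    // overwrite their own ranges, leaving each position with its final bag.
    for (int64_t i = offsets[bag]; i < num_indices; i++) {
      offset2bag[i] = bag;
    }
  }
  return offset2bag;
}

int main() {
  // Seven indices split into three bags starting at positions 0, 3 and 5.
  auto o2b = make_offset2bag_demo({0, 3, 5}, 7);
  for (int64_t b : o2b) std::cout << b << ' ';  // prints: 0 0 0 1 1 2 2
  std::cout << '\n';
}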
Example no. 2
// Runs checkContiguous over every defined TensorArg in the list; undefined
// (optional) tensors are skipped.
void checkAllContiguous(CheckedFrom c, at::ArrayRef<TensorArg> ts) {
  for (auto& t : ts) {
    if (!t->defined()) continue;
    checkContiguous(c, t);
  }
}
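
A hypothetical call site, only to show how TensorArg (as used in Examples 1 and 3) pairs with checkAllContiguous. The operator name, tensor names, and argument positions below are invented, and the snippet assumes the same ATen declarations as the code above.

// Hypothetical call site (names and positions are illustrative): wrap each
// tensor with its user-facing name and argument position, then validate the
// whole group in one call. Undefined (optional) tensors are skipped.
static void check_my_op_inputs(const Tensor &weight, const Tensor &indices,
                               const Tensor &offsets) {
  TensorArg weight_arg(weight, "weight", 1);
  TensorArg indices_arg(indices, "indices", 2);
  TensorArg offsets_arg(offsets, "offsets", 3);
  checkAllContiguous("my_op", {weight_arg, indices_arg, offsets_arg});
}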
Example no. 3
// Dense CPU backward for embedding_bag (sum/mean/max modes): scatters the
// per-bag gradient rows back into a [num_weights, embedding_dim] weight gradient.
Tensor embedding_bag_backward_cpu(const Tensor &grad_, const Tensor &indices__,
                                  const Tensor &offsets__,
                                  const Tensor &offset2bag__,
                                  const Tensor &bag_size_,
                                  const Tensor& max_indices_, int64_t num_weights,
                                  bool scale_grad_by_freq, int64_t mode) {
  auto grad = grad_.contiguous();
  auto grad_arg = TensorArg(grad, "grad_", 1);
  checkScalarTypes("embedding_bag", grad_arg, {kFloat, kDouble});
  auto indices_arg = TensorArg(indices__, "indices__", 1);
  checkScalarType("embedding_bag", indices_arg, kLong);
  auto offsets_arg = TensorArg(offsets__, "offsets__", 1);
  checkScalarType("embedding_bag", offsets_arg, kLong);
  auto offset2bag_arg = TensorArg(offset2bag__, "offset2bag__", 1);
  checkScalarType("embedding_bag", offset2bag_arg, kLong);
  checkContiguous("embedding_bag", offset2bag_arg);
  Tensor indices_ = indices__.contiguous();
  Tensor offsets_ = offsets__.contiguous();

  Tensor &offset2bag_ = const_cast<Tensor &>(offset2bag__);

  // Sort the indices so that equal values are adjacent, and permute offset2bag
  // accordingly, so each weight row's contributions form one contiguous run.
  auto ind_sort_ = indices_.sort();
  auto indices = std::get<0>(ind_sort_);
  auto ind_sort = std::get<1>(ind_sort_);
  auto offset2bag = offset2bag_.index_select(0, ind_sort);

  auto indices_data = indices.data<int64_t>();
  auto offsets_data = offsets_.data<int64_t>();
  auto offset2bag_data = offset2bag.data<int64_t>();
  int64_t numel = indices.numel();

  // Frequency of each index value over the whole mini-batch. The vector is
  // value-initialized to zero, so a single counting pass suffices.
  std::vector<int64_t> counts(num_weights);
  for (int64_t i = 0; i < numel; i++) {
    counts[indices_data[i]]++;
  }

  auto index_grad_weight =
      at::zeros({num_weights, grad.size(1)}, grad.type()).contiguous();

  // counts_uniq[k] is the (exclusive) end position of the k-th run of equal
  // values in the sorted indices, i.e. an inclusive prefix sum of the counts
  // of the unique indices actually present.
  std::vector<int64_t> counts_uniq;
  counts_uniq.reserve(num_weights);
  int64_t o = 0;
  for (int64_t i = 0; i < numel; i += counts[indices_data[i]]) {
    counts_uniq.push_back(counts[indices_data[i]]);
    if (o > 0) {
      counts_uniq[o] += counts_uniq[o - 1];
    }
    o++;
  }

  if (mode == MODE_MEAN || mode == MODE_SUM) {
    #pragma omp parallel for if (numel > 1000)
      for (int64_t i = 0; i < (int64_t)counts_uniq.size(); i++) {
        int64_t start = i == 0 ? 0 : counts_uniq[i - 1];
        int64_t index = indices_data[start];
        for (int64_t j = start; j < counts_uniq[i]; j++) {
          int64_t source = offset2bag_data[j];
          double scale = 1.0;
          if (scale_grad_by_freq) {
            // Scale by the inverse frequency of the embedding row `index`
            // that this run accumulates into.
            scale /= counts[index];
          }
          if (mode == MODE_MEAN) {
            if (offsets_.size(0) == 1) {
              auto bag_size = indices.size(0);
              scale /= bag_size;
            } else {
              if (source == offsets_.size(0) - 1) {
                scale /= indices.size(0) - offsets_data[offsets_.size(0) - 1];
              } else {
                scale /= offsets_data[source + 1] - offsets_data[source];
              }
            }
          }
          // Accumulate scale * grad[source, :] into index_grad_weight[index, :]
          // with a BLAS axpy over the embedding dimension.
          int64_t ddim = grad.size(1);
          if (grad.type().scalarType() == kFloat) {
            auto igwd = index_grad_weight.data<float>();
            auto gd = grad.data<float>();
            THBlas_axpy<float>(ddim, (float)scale, gd + ddim * source, 1,
                        igwd + ddim * index, 1);
          } else if (grad.type().scalarType() == kDouble) {
            auto igwd = index_grad_weight.data<double>();
            auto gd = grad.data<double>();
            THBlas_axpy<double>(ddim, (double)scale, gd + ddim * source, 1,
                         igwd + ddim * index, 1);
          }
        }
      }
  } else if (mode == MODE_MAX) {
    // In max mode each output element's gradient flows only to the input row
    // that produced the maximum; empty bags (bag_size == 0) are skipped.
    auto nonempty_max_indices = max_indices_.index_select(0, bag_size_.nonzero().view(-1));
    auto nonempty_grad = grad_.index_select(0, bag_size_.nonzero().view(-1));

    for (int64_t dim = 0; dim < grad.size(1); dim++) {
      index_grad_weight.select(1, dim).index_add_(
        0, nonempty_max_indices.select(1, dim), nonempty_grad.select(1, dim));
    }
  }

  return index_grad_weight;
}
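
The counts / counts_uniq bookkeeping is the least obvious part of the sum/mean path, so here is a minimal standalone sketch of it (plain std::vector, no ATen) that mirrors the loops above: counts holds each index value's frequency, and counts_uniq ends up holding the exclusive end position of each run of equal values in the sorted indices.

// Illustration only: the counts / counts_uniq construction from the kernel
// above, applied to a small sorted index list.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> indices = {0, 0, 2, 2, 2, 5};  // already sorted
  int64_t num_weights = 6;
  int64_t numel = (int64_t)indices.size();

  // Frequency of each index value.
  std::vector<int64_t> counts(num_weights);
  for (int64_t i = 0; i < numel; i++) {
    counts[indices[i]]++;
  }

  // Exclusive end position of each run of equal values (inclusive prefix sum
  // over the counts of the unique values present).
  std::vector<int64_t> counts_uniq;
  int64_t o = 0;
  for (int64_t i = 0; i < numel; i += counts[indices[i]]) {
    counts_uniq.push_back(counts[indices[i]]);
    if (o > 0) {
      counts_uniq[o] += counts_uniq[o - 1];
    }
    o++;
  }

  for (int64_t end : counts_uniq) std::cout << end << ' ';  // prints: 2 5 6
  std::cout << '\n';
}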