// Dispatches the embedding_bag backward pass to the sparse or dense kernel
// after validating dtypes and normalizing the index tensors to be contiguous.
Tensor embedding_bag_backward(const Tensor &grad_, const Tensor &indices__,
                              const Tensor &offsets__,
                              const Tensor &offset2bag__,
                              const Tensor &bag_size_,
                              const Tensor &max_indices_, int64_t num_weights,
                              bool scale_grad_by_freq, int64_t mode,
                              bool sparse) {
  auto indices_arg = TensorArg(indices__, "indices__", 1);
  checkScalarType("embedding_bag", indices_arg, kLong);
  auto offsets_arg = TensorArg(offsets__, "offsets__", 2);
  checkScalarType("embedding_bag", offsets_arg, kLong);
  auto offset2bag_arg = TensorArg(offset2bag__, "offset2bag__", 3);
  checkScalarType("embedding_bag", offset2bag_arg, kLong);
  checkContiguous("embedding_bag", offset2bag_arg);

  Tensor indices = indices__.contiguous();
  Tensor offsets = offsets__.contiguous();

  if (sparse) {
    return at::embedding_bag_sparse_backward(
        grad_, indices, offsets, offset2bag__, bag_size_, num_weights,
        scale_grad_by_freq, mode);
  } else {
    return at::embedding_bag_dense_backward(
        grad_, indices, offsets, offset2bag__, bag_size_, max_indices_,
        num_weights, scale_grad_by_freq, mode);
  }
}
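// Usage sketch (illustrative only; this helper is hypothetical and not part
// of ATen). It shows how the forward outputs feed the dispatcher above,
// assuming at::embedding_bag in this version returns the four-tuple
// (output, offset2bag, bag_size, max_indices).
static Tensor example_embedding_bag_backward(const Tensor &weight,
                                             const Tensor &indices,
                                             const Tensor &offsets,
                                             const Tensor &grad_output) {
  // Run the forward pass to obtain the bookkeeping tensors the backward
  // pass needs (offset2bag, bag_size, max_indices).
  auto fwd = at::embedding_bag(weight, indices, offsets,
                               /*scale_grad_by_freq=*/false,
                               /*mode=*/MODE_SUM, /*sparse=*/false);
  return at::embedding_bag_backward(
      grad_output, indices, offsets, std::get<1>(fwd), std::get<2>(fwd),
      std::get<3>(fwd), /*num_weights=*/weight.size(0),
      /*scale_grad_by_freq=*/false, /*mode=*/MODE_SUM, /*sparse=*/false);
}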
// Checks that every defined tensor in `ts` is contiguous; undefined
// (optional) tensors are skipped.
void checkAllContiguous(CheckedFrom c, at::ArrayRef<TensorArg> ts) {
  for (auto& t : ts) {
    if (!t->defined()) continue;
    checkContiguous(c, t);
  }
}
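// Illustrative sketch (hypothetical function, not used elsewhere in this
// file): how a kernel entry point might validate several inputs at once with
// checkAllContiguous instead of one checkContiguous call per tensor.
static void check_example_inputs(const Tensor &a, const Tensor &b) {
  auto a_arg = TensorArg(a, "a", 1);
  auto b_arg = TensorArg(b, "b", 2);
  // Throws with a descriptive error naming the offending argument if either
  // defined tensor is non-contiguous.
  checkAllContiguous("check_example_inputs", {a_arg, b_arg});
}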
// Dense CPU backward: accumulates per-bag output gradients back into a
// (num_weights, embedding_dim) gradient for the weight matrix.
Tensor embedding_bag_backward_cpu(const Tensor &grad_, const Tensor &indices__,
                                  const Tensor &offsets__,
                                  const Tensor &offset2bag__,
                                  const Tensor &bag_size_,
                                  const Tensor &max_indices_,
                                  int64_t num_weights,
                                  bool scale_grad_by_freq, int64_t mode) {
  auto grad = grad_.contiguous();
  auto grad_arg = TensorArg(grad, "grad_", 1);
  checkScalarTypes("embedding_bag", grad_arg, {kFloat, kDouble});
  auto indices_arg = TensorArg(indices__, "indices__", 2);
  checkScalarType("embedding_bag", indices_arg, kLong);
  auto offsets_arg = TensorArg(offsets__, "offsets__", 3);
  checkScalarType("embedding_bag", offsets_arg, kLong);
  auto offset2bag_arg = TensorArg(offset2bag__, "offset2bag__", 4);
  checkScalarType("embedding_bag", offset2bag_arg, kLong);
  checkContiguous("embedding_bag", offset2bag_arg);

  Tensor indices_ = indices__.contiguous();
  Tensor offsets_ = offsets__.contiguous();

  // Sort the indices so equal values form contiguous runs, and permute
  // offset2bag the same way so each sorted position still knows its bag.
  auto ind_sort_ = indices_.sort();
  auto indices = std::get<0>(ind_sort_);
  auto ind_sort = std::get<1>(ind_sort_);
  auto offset2bag = offset2bag__.index_select(0, ind_sort);

  auto indices_data = indices.data<int64_t>();
  auto offsets_data = offsets_.data<int64_t>();
  auto offset2bag_data = offset2bag.data<int64_t>();
  int64_t numel = indices.numel();

  // Count how many times each embedding row is referenced. The vector is
  // value-initialized, so every count already starts at zero.
  std::vector<int64_t> counts(num_weights);
  for (int64_t i = 0; i < numel; i++) {
    counts[indices_data[i]]++;
  }

  auto index_grad_weight =
      at::zeros({num_weights, grad.size(1)}, grad.type()).contiguous();

  // counts_uniq[i] is the cumulative end offset (in sorted order) of the
  // i-th run of equal indices, so run i spans
  // [counts_uniq[i - 1], counts_uniq[i]).
  std::vector<int64_t> counts_uniq;
  counts_uniq.reserve(num_weights);
  int64_t o = 0;
  for (int64_t i = 0; i < numel; i += counts[indices_data[i]]) {
    counts_uniq.push_back(counts[indices_data[i]]);
    if (o > 0) {
      counts_uniq[o] += counts_uniq[o - 1];
    }
    o++;
  }

  if (mode == MODE_MEAN || mode == MODE_SUM) {
    #pragma omp parallel for if (numel > 1000)
    for (int64_t i = 0; i < (int64_t)counts_uniq.size(); i++) {
      int64_t start = i == 0 ? 0 : counts_uniq[i - 1];
      int64_t index = indices_data[start];
      for (int64_t j = start; j < counts_uniq[i]; j++) {
        int64_t source = offset2bag_data[j];
        double scale = 1.0;
        if (scale_grad_by_freq) {
          // Divide by the frequency of the row this run accumulates into.
          // (The original indexed counts with the run counter i, which picks
          // up the wrong row whenever any earlier run has length > 1.)
          scale /= counts[index];
        }
        if (mode == MODE_MEAN) {
          if (offsets_.size(0) == 1) {
            auto bag_size = indices.size(0);
            scale /= bag_size;
          } else {
            if (source == offsets_.size(0) - 1) {
              scale /= indices.size(0) - offsets_data[offsets_.size(0) - 1];
            } else {
              scale /= offsets_data[source + 1] - offsets_data[source];
            }
          }
        }
        int64_t ddim = grad.size(1);
        // axpy: index_grad_weight[index, :] += scale * grad[source, :].
        if (grad.type().scalarType() == kFloat) {
          auto igwd = index_grad_weight.data<float>();
          auto gd = grad.data<float>();
          THBlas_axpy<float>(ddim, (float)scale, gd + ddim * source, 1,
                             igwd + ddim * index, 1);
        } else if (grad.type().scalarType() == kDouble) {
          auto igwd = index_grad_weight.data<double>();
          auto gd = grad.data<double>();
          THBlas_axpy<double>(ddim, (double)scale, gd + ddim * source, 1,
                              igwd + ddim * index, 1);
        }
      }
    }
  } else if (mode == MODE_MAX) {
    // Route each output gradient back to the row that won the max in the
    // forward pass, skipping empty bags.
    auto nonempty_bags = bag_size_.nonzero().view(-1);
    auto nonempty_max_indices = max_indices_.index_select(0, nonempty_bags);
    auto nonempty_grad = grad_.index_select(0, nonempty_bags);

    for (int64_t dim = 0; dim < grad.size(1); dim++) {
      index_grad_weight.select(1, dim).index_add_(
          0, nonempty_max_indices.select(1, dim),
          nonempty_grad.select(1, dim));
    }
  }

  return index_grad_weight;
}
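// Worked example of the run-length bookkeeping in the dense backward above
// (illustrative values only): for sorted indices = [1, 1, 3, 4, 4, 4],
//   counts[1] = 2, counts[3] = 1, counts[4] = 3, and
//   counts_uniq = [2, 3, 6]   (cumulative end offset of each run),
// so run i covers sorted positions [counts_uniq[i - 1], counts_uniq[i])
// (implicitly starting at 0 for i == 0), and every position in a run
// accumulates into the same row of index_grad_weight.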