sglang_v0.5.2/pytorch_2.8.0/aten/src/ATen/native/Embedding.cpp

#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/core/List.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIterator.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/BinaryOps.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_sparse_coo_tensor_unsafe.h>
#include <ATen/ops/embedding_backward_native.h>
#include <ATen/ops/embedding_dense_backward.h>
#include <ATen/ops/embedding_dense_backward_native.h>
#include <ATen/ops/embedding_native.h>
#include <ATen/ops/embedding_renorm_native.h>
#include <ATen/ops/embedding_sparse_backward.h>
#include <ATen/ops/embedding_sparse_backward_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/zeros.h>
#endif
#include <c10/util/irange.h>
#include <cstring>
#include <memory>
#include <utility>
#include <vector>
namespace at::native {
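
// Forward pass of the embedding lookup. The whole operation reduces to an
// index_select over the rows of `weight`; multi-dimensional `indices` are
// flattened first and the result is reshaped back. Schematically (shape names
// below are illustrative, not taken from this file):
//
//   weight  : [num_embeddings, embedding_dim]
//   indices : [B, S]                       (kLong or kInt)
//   output  : [B, S, embedding_dim]
//           = weight.index_select(0, indices.reshape(-1))
//                   .view({B, S, embedding_dim})
//
// padding_idx, scale_grad_by_freq and sparse are unused in the forward
// computation; they only influence the backward pass.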
Tensor embedding_symint(const Tensor & weight, const Tensor & indices,
                        c10::SymInt padding_idx, bool scale_grad_by_freq, bool sparse) {
  TORCH_CHECK(weight.dim() == 2, "'weight' must be 2-D");
  auto indices_arg = TensorArg(indices, "indices", 1);
  checkScalarTypes("embedding", indices_arg, {kLong, kInt});

  // TODO: use tensor.index() after improving perf
  if (indices.dim() == 1) {
    return weight.index_select(0, indices);
  }

  auto size = indices.sym_sizes().vec();
  for (const auto& d : weight.sym_sizes().slice(1)) {
    size.push_back(d);
  }

  return weight.index_select(0, indices.reshape(-1)).view_symint(size);
}
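
// Backward entry point: routes to either the sparse or the dense gradient
// kernel depending on the `sparse` flag. Both paths produce a gradient shaped
// like the weight matrix, [num_weights, embedding_dim]; the sparse path
// returns it as an (uncoalesced) sparse COO tensor, the dense path as a dense
// tensor. In the sparse case the symbolic num_weights / padding_idx are
// guarded to concrete ints, which is acceptable because num_weights is
// derived from the embedding table's size and typically does not vary.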
Tensor embedding_backward_symint(
    const Tensor & grad, const Tensor & indices, c10::SymInt num_weights,
    c10::SymInt padding_idx, bool scale_grad_by_freq, bool sparse) {
  if (sparse) {
    // TODO: if we teach sparse tensor how to propagate symints, the guard
    // here is not strictly necessary. However, we think it is fine as is
    // because num weights is derived from a parameter and therefore
    // typically not varying.
    return at::embedding_sparse_backward(
        grad, indices,
        num_weights.guard_int(__FILE__, __LINE__),
        padding_idx.guard_int(__FILE__, __LINE__),
        scale_grad_by_freq);
  } else {
    return at::embedding_dense_backward_symint(
        grad, indices, std::move(num_weights), padding_idx, scale_grad_by_freq);
  }
}
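
// Sparse gradient for the embedding weight. Entries hit by padding_idx are
// filtered out, the remaining indices become the COO index matrix and the
// matching grad slices become the values. Roughly (illustrative shapes,
// nnz = number of surviving indices):
//
//   index  : [1, nnz]               <- indices.reshape({1, -1}), cast to kLong
//   values : [nnz, num_features]    <- grad.reshape({-1, num_features})
//   result : sparse COO tensor of size [num_weights, num_features]
//
// Duplicate indices are left as-is (the result is not coalesced here), and
// scale_grad_by_freq is rejected up front because it is not implemented for
// the sparse path.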
Tensor embedding_sparse_backward(
    const Tensor & grad_, const Tensor & indices_, int64_t num_weights,
    int64_t padding_idx, bool scale_grad_by_freq) {
  auto indices_arg = TensorArg(indices_, "indices", 2);
  checkScalarTypes("embedding_backward", indices_arg, {kLong, kInt});

  // TODO: implement scale_grad_by_freq
  if (scale_grad_by_freq) {
    TORCH_CHECK(false,
        "embedding_backward: scale_grad_by_freq not supported with sparse gradients");
  }

  Tensor indices = indices_;
  Tensor grad = grad_;
  if (padding_idx != -1) {
    c10::List<std::optional<Tensor>> c({indices != padding_idx});
    indices = indices.index(c);
    grad = grad.index(c);
  }

  auto num_features = grad_.sym_size(-1);
  auto weight_size = std::array<c10::SymInt, 2>{{ num_weights, num_features }};
  auto dense_options = grad.options();

  // check if all of our grad comes from padding_idx
  if (grad.sym_numel() == 0) {
    return at::_sparse_coo_tensor_unsafe_symint(at::empty({1, 0}, indices_.options().dtype(kLong)),
                                                at::empty_symint({c10::SymInt(0), std::move(num_features)}, dense_options),
                                                weight_size);
  }

  auto index = indices.reshape({1, -1});
  auto values = grad.reshape_symint({c10::SymInt(-1), std::move(num_features)});
  return at::_sparse_coo_tensor_unsafe_symint(index.to(kLong), values, weight_size);
}
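
// Dense CPU gradient for the embedding weight: effectively a scatter-add of
// the incoming grad rows into grad_weight, i.e. for every i with
// indices[i] != padding_idx it performs grad_weight[indices[i]] += grad[i],
// scaled by 1 / count(indices[i]) when scale_grad_by_freq is set.
//
// Two implementation details worth noting:
//   * A single TensorIterator is configured once and its operand pointers are
//     swapped per row (unsafe_replace_operand + add_stub), which avoids
//     re-dispatching an add kernel for every index.
//   * parallel_for partitions the rows of grad_weight into disjoint
//     [start, end) ranges; each worker scans all indices but only writes rows
//     in its own range, so no two threads ever write the same row and no
//     atomics are needed.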
Tensor embedding_dense_backward_cpu(
    const Tensor & grad_, const Tensor & indices, int64_t num_weights,
    int64_t padding_idx, bool scale_grad_by_freq) {
  auto indices_arg = TensorArg(indices, "indices", 2);
  checkScalarTypes("embedding_backward", indices_arg, {kLong, kInt});

  auto grad_weight = at::zeros({num_weights, grad_.size(-1)}, grad_.options());
  auto indices_contig = indices.contiguous();
  int64_t numel = indices.numel();
  auto grad = grad_.contiguous().view({numel, grad_.size(-1)});

  auto add_iter = TensorIteratorConfig()
    .add_output(grad_weight)
    .add_input(grad_weight)
    .add_const_input(grad)
    .resize_outputs(false)
    .declare_static_shape(grad.sizes(), /*squash_dims=*/0)
    .build();

  const auto gW_data = reinterpret_cast<char*>(grad_weight.data_ptr());
  const auto gO_data = reinterpret_cast<const char*>(grad.const_data_ptr());
  const auto gW_stride = grad_weight.strides()[0] * grad_weight.element_size();
  const auto gO_stride = grad.strides()[0] * grad.element_size();

  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cpu", [&] () {
    auto indices_data = indices_contig.const_data_ptr<index_t>();

    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    std::unique_ptr<index_t[]> counts;
    if (scale_grad_by_freq) {
      counts.reset(new index_t[num_weights]);
      for (const auto i : c10::irange(numel)) {
        counts[indices_data[i]] = 0;
      }
      for (const auto i : c10::irange(numel)) {
        counts[indices_data[i]]++;
      }
    }

    auto parallel_section = [&](index_t start, index_t end) {
      TensorIterator iter(add_iter);
      for (const auto i : c10::irange(numel)) {
        if (indices_data[i] != padding_idx) {
          index_t k = indices_data[i];
          if (k >= start && k < end) {
            double scale = 1.0;
            if (scale_grad_by_freq) {
              // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
              scale /= counts[k];
            }

            // grad_weight[k].add_(grad[i], scale);
            iter.unsafe_replace_operand(0, gW_data + k * gW_stride);
            iter.unsafe_replace_operand(1, gW_data + k * gW_stride);
            iter.unsafe_replace_operand(2, const_cast<char*>(gO_data + i * gO_stride));
            add_stub(kCPU, iter, scale);
          }
        }
      }
    };

    at::parallel_for(0, num_weights, 1000, parallel_section);
  });

  return grad_weight;
}
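
// In-place renormalization (this is what backs the max_norm option of
// nn.Embedding): every row of `self` referenced by `indices` whose
// norm_type-norm exceeds max_norm is rescaled to approximately max_norm:
//
//   scale = max_norm / (norm + 1e-7);   // small epsilon for numerical safety
//   row  *= scale;
//
// The indices are sorted and deduplicated first so each row is renormalized
// at most once, and the loop stays serial because it operates on Tensors
// (see the issue referenced in the comment below).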
Tensor & embedding_renorm_cpu_(
    Tensor & self, const Tensor & indices, double max_norm, double norm_type) {
  auto self_arg = TensorArg(self, "self", 1);
  auto indices_arg = TensorArg(indices, "indices", 2);
  checkDim("embedding_renorm_", self_arg, 2);
  checkScalarTypes("embedding_renorm_", indices_arg, {kLong, kInt});

  auto indices_contig = indices.contiguous();
  auto num_indices = indices.numel();

  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_renorm_cpu_", [&]() {
    auto data_ptr = indices_contig.const_data_ptr<index_t>();
    auto sorted_indices = std::vector<index_t>(data_ptr, data_ptr + num_indices);
    std::sort(sorted_indices.begin(), sorted_indices.end());

    // Note that we cannot use at::parallel_for here because we perform operations on
    // Tensor inside the loop. See github.com/pytorch/pytorch/issues/28370 for more details.
    for (const auto i : c10::irange(num_indices)) {
      if (i > 0 && sorted_indices[i] == sorted_indices[i - 1]) {
        continue;
      }
      auto row = self[sorted_indices[i]];
      auto norm = row.norm(norm_type).item<double>();
      if (norm > max_norm) {
        auto scale = max_norm / (norm + 1e-7);
        row *= scale;
      }
    }
  });

  return self;
}
} // namespace at::native