sglang_v0.5.2/pytorch_2.8.0/test/cpp_extensions/xpu_extension.sycl

64 lines
2.0 KiB
Plaintext

#include <c10/xpu/XPUStream.h>
#include <torch/extension.h>
#include <sycl/sycl.hpp>

#include <limits>
/// Per-work-item body of the sigmoid-add computation.
///
/// Derives a linear element index from dimension 2 of the ND-range
/// (group id * local range + local id) and, when that index is below `size`,
/// stores sigmoid(x[index]) + sigmoid(y[index]) into output[index].
/// Work-items whose index is at or beyond `size` write nothing.
///
/// @param x        First input array.
/// @param y        Second input array.
/// @param output   Destination array.
/// @param size     Number of elements to process.
/// @param item_ct1 The calling work-item's position in the 3-D ND-range.
void sigmoid_add_kernel(const float* x,
                        const float* y,
                        float* output,
                        const int size,
                        const sycl::nd_item<3> &item_ct1) {
  const auto group_start = item_ct1.get_group(2) * item_ct1.get_local_range(2);
  const int index = group_start + item_ct1.get_local_id(2);

  // Work-items past the end of the data have nothing to do.
  if (!(index < size)) {
    return;
  }

  // Logistic function evaluated with sycl::native::exp.
  const auto sigmoid = [](const float value) -> float {
    return 1.0f / (1.0f + sycl::native::exp(-value));
  };

  const float lhs = sigmoid(x[index]);
  const float rhs = sigmoid(y[index]);
  output[index] = lhs + rhs;
}
/// Kernel function object: stores the array pointers and element count, and
/// forwards each work-item invocation to sigmoid_add_kernel().
class SigmoidAddKernel {
 public:
  /// Captures the arguments that every work-item will pass on to
  /// sigmoid_add_kernel(). The pointers are stored as-is; nothing is copied
  /// or owned by this object.
  SigmoidAddKernel(const float* lhs, const float* rhs, float* result, int count)
      : x_(lhs), y_(rhs), output_(result), size_(count) {}

  /// Invoked once per work-item with that work-item's ND-range position.
  void operator()(const sycl::nd_item<3>& item) const {
    sigmoid_add_kernel(x_, y_, output_, size_, item);
  }

 private:
  const float* x_;  // first input array
  const float* y_;  // second input array
  float* output_;   // destination array
  int size_;        // number of elements to process
};
void sigmoid_add_xpu(const float* x, const float* y, float* output, int size) {
SigmoidAddKernel krn(x, y, output, size);
const int threads = 1024;
const int blocks = (size + threads - 1) / threads;
sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
queue.submit([&](sycl::handler &cgh) {
cgh.parallel_for<SigmoidAddKernel>(
sycl::nd_range<3>(
sycl::range<3>(1, 1, blocks) * sycl::range<3>(1, 1, threads),
sycl::range<3>(1, 1, threads)),
krn);
});
}
/// Computes sigmoid(x) + sigmoid(y) element-wise on XPU tensors.
///
/// @param x float32 XPU tensor.
/// @param y float32 XPU tensor with the same shape as `x`.
/// @return A new tensor with the shape of `x` holding the result.
/// @throws c10::Error (via TORCH_CHECK) if either tensor is not on an XPU
///         device, is not float32, if the shapes differ, or if the element
///         count does not fit in `int`.
torch::Tensor sigmoid_add(torch::Tensor x, torch::Tensor y) {
  TORCH_CHECK(x.device().is_xpu(), "x must be a XPU tensor");
  TORCH_CHECK(y.device().is_xpu(), "y must be a XPU tensor");
  TORCH_CHECK(x.scalar_type() == torch::kFloat, "x must be a float32 tensor");
  TORCH_CHECK(y.scalar_type() == torch::kFloat, "y must be a float32 tensor");
  // The kernel reads x[i] and y[i] for every i < numel(x); a smaller y would
  // be read out of bounds.
  TORCH_CHECK(x.sizes() == y.sizes(), "x and y must have the same shape");
  // The launcher takes the element count as int; reject counts that would be
  // truncated by the conversion.
  TORCH_CHECK(
      x.numel() <= std::numeric_limits<int>::max(),
      "sigmoid_add supports at most ", std::numeric_limits<int>::max(),
      " elements, got ", x.numel());

  // The kernel indexes the raw buffers linearly, so all three tensors must
  // share one dense layout. contiguous() returns the tensor itself when it is
  // already contiguous.
  const torch::Tensor x_contig = x.contiguous();
  const torch::Tensor y_contig = y.contiguous();
  auto output = torch::zeros_like(x_contig);

  sigmoid_add_xpu(
      x_contig.data_ptr<float>(),
      y_contig.data_ptr<float>(),
      output.data_ptr<float>(),
      static_cast<int>(output.numel()));
  return output;
}
// Python binding: defines the extension module (its name comes from
// TORCH_EXTENSION_NAME) and registers sigmoid_add as a module-level function
// with the docstring "sigmoid(x) + sigmoid(y)".
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("sigmoid_add", &sigmoid_add, "sigmoid(x) + sigmoid(y)");
}