64 lines
2.0 KiB
Plaintext
64 lines
2.0 KiB
Plaintext
#include <c10/xpu/XPUStream.h>
|
|
#include <torch/extension.h>
|
|
#include <sycl/sycl.hpp>
|
|
|
|
void sigmoid_add_kernel(const float* x,
|
|
const float* y,
|
|
float* output,
|
|
const int size,
|
|
const sycl::nd_item<3> &item_ct1) {
|
|
const int index = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
|
|
item_ct1.get_local_id(2);
|
|
if (index < size) {
|
|
const float sigmoid_x = 1.0f / (1.0f + sycl::native::exp(-x[index]));
|
|
const float sigmoid_y = 1.0f / (1.0f + sycl::native::exp(-y[index]));
|
|
output[index] = sigmoid_x + sigmoid_y;
|
|
}
|
|
}
|
|
|
|
class SigmoidAddKernel {
|
|
public:
|
|
void operator()(const sycl::nd_item<3> &item_ct1) const {
|
|
sigmoid_add_kernel(x, y, output, size, item_ct1);
|
|
}
|
|
SigmoidAddKernel(const float* _x, const float* _y, float* _output, int _size):
|
|
x(_x),
|
|
y(_y),
|
|
output(_output),
|
|
size(_size)
|
|
{}
|
|
private:
|
|
const float* x;
|
|
const float* y;
|
|
float* output;
|
|
int size;
|
|
};
|
|
|
|
void sigmoid_add_xpu(const float* x, const float* y, float* output, int size) {
|
|
SigmoidAddKernel krn(x, y, output, size);
|
|
const int threads = 1024;
|
|
const int blocks = (size + threads - 1) / threads;
|
|
|
|
sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
|
|
queue.submit([&](sycl::handler &cgh) {
|
|
cgh.parallel_for<SigmoidAddKernel>(
|
|
sycl::nd_range<3>(
|
|
sycl::range<3>(1, 1, blocks) * sycl::range<3>(1, 1, threads),
|
|
sycl::range<3>(1, 1, threads)),
|
|
krn);
|
|
});
|
|
}
|
|
|
|
/// Python-facing entry point: returns sigmoid(x) + sigmoid(y), elementwise.
///
/// The kernel indexes all three buffers linearly with the same index, so the
/// inputs are validated (XPU device, same device, float32, identical shape)
/// and made contiguous before their raw pointers are handed over.
///
/// @param x first input tensor (XPU, float32)
/// @param y second input tensor (XPU, float32, same shape and device as x)
/// @return a new contiguous tensor with x's shape and options
/// @throws c10::Error (via TORCH_CHECK) if any precondition is violated
torch::Tensor sigmoid_add(torch::Tensor x, torch::Tensor y) {
  TORCH_CHECK(x.device().is_xpu(), "x must be a XPU tensor");
  TORCH_CHECK(y.device().is_xpu(), "y must be a XPU tensor");
  TORCH_CHECK(x.device() == y.device(),
              "x and y must be on the same XPU device");
  TORCH_CHECK(x.scalar_type() == torch::kFloat,
              "x must be a float32 tensor");
  TORCH_CHECK(y.scalar_type() == torch::kFloat,
              "y must be a float32 tensor");
  // Without this, a smaller y would be read out of bounds by the kernel.
  TORCH_CHECK(x.sizes() == y.sizes(), "x and y must have the same shape");

  // The launcher takes an int element count; reject sizes that would be
  // silently truncated by the narrowing conversion.
  const auto numel = x.numel();
  const int size = static_cast<int>(numel);
  TORCH_CHECK(size == numel, "tensor has too many elements for sigmoid_add");

  // The kernel assumes densely packed, identically ordered storage.
  const torch::Tensor x_contig = x.contiguous();
  const torch::Tensor y_contig = y.contiguous();

  // Every element is overwritten by the kernel, so no zero-fill is needed.
  auto output = torch::empty_like(x_contig);

  sigmoid_add_xpu(x_contig.data_ptr<float>(),
                  y_contig.data_ptr<float>(),
                  output.data_ptr<float>(),
                  size);
  return output;
}
|
|
|
|
// Python bindings: exposes sigmoid_add as a function of the extension module
// whose name is supplied by the TORCH_EXTENSION_NAME macro.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
  mod.def(
      "sigmoid_add",
      &sigmoid_add,
      "sigmoid(x) + sigmoid(y)");
}
|