sglang_v0.5.2/pytorch_2.8.0/test/cpp_extensions/doubler.h

18 lines
320 B
C++

#include <torch/extension.h>
struct Doubler {
Doubler(int A, int B) {
tensor_ =
torch::ones({A, B}, torch::dtype(torch::kFloat64).requires_grad(true));
}
torch::Tensor forward() {
return tensor_ * 2;
}
torch::Tensor get() const {
return tensor_;
}
private:
torch::Tensor tensor_;
};