18 lines
320 B
C++
18 lines
320 B
C++
#include <torch/extension.h>
|
|
|
|
struct Doubler {
|
|
Doubler(int A, int B) {
|
|
tensor_ =
|
|
torch::ones({A, B}, torch::dtype(torch::kFloat64).requires_grad(true));
|
|
}
|
|
torch::Tensor forward() {
|
|
return tensor_ * 2;
|
|
}
|
|
torch::Tensor get() const {
|
|
return tensor_;
|
|
}
|
|
|
|
private:
|
|
torch::Tensor tensor_;
|
|
};
|