53 lines
1.2 KiB
C++
53 lines
1.2 KiB
C++
#pragma once
|
|
|
|
#include <ATen/ATen.h>
|
|
#include "../macros.h"
|
|
|
|
namespace vision {
|
|
namespace ops {
|
|
|
|
VISION_API std::tuple<at::Tensor, at::Tensor> roi_pool(
|
|
const at::Tensor& input,
|
|
const at::Tensor& rois,
|
|
double spatial_scale,
|
|
int64_t pooled_height,
|
|
int64_t pooled_width);
|
|
|
|
VISION_API std::tuple<at::Tensor, at::Tensor> roi_pool_symint(
|
|
const at::Tensor& input,
|
|
const at::Tensor& rois,
|
|
double spatial_scale,
|
|
c10::SymInt pooled_height,
|
|
c10::SymInt pooled_width);
|
|
|
|
namespace detail {
|
|
|
|
at::Tensor _roi_pool_backward(
|
|
const at::Tensor& grad,
|
|
const at::Tensor& rois,
|
|
const at::Tensor& argmax,
|
|
double spatial_scale,
|
|
int64_t pooled_height,
|
|
int64_t pooled_width,
|
|
int64_t batch_size,
|
|
int64_t channels,
|
|
int64_t height,
|
|
int64_t width);
|
|
|
|
at::Tensor _roi_pool_backward_symint(
|
|
const at::Tensor& grad,
|
|
const at::Tensor& rois,
|
|
const at::Tensor& argmax,
|
|
double spatial_scale,
|
|
c10::SymInt pooled_height,
|
|
c10::SymInt pooled_width,
|
|
c10::SymInt batch_size,
|
|
c10::SymInt channels,
|
|
c10::SymInt height,
|
|
c10::SymInt width);
|
|
|
|
} // namespace detail
|
|
|
|
} // namespace ops
|
|
} // namespace vision
|