// File-viewer banner (extraction residue): 59 lines, 1.2 KiB, C++
#pragma once
|
|
|
|
#include <ATen/ATen.h>
|
|
#include "../macros.h"
|
|
|
|
namespace vision {
|
|
namespace ops {
|
|
|
|
/// Declares the RoI Align forward operator (plain-integer size variant).
///
/// Only the declaration is visible in this header; the definition is provided
/// elsewhere. The function is annotated with VISION_API (from "../macros.h").
///
/// NOTE(review): the parameter semantics below are inferred from the parameter
/// names and the public torchvision.ops.roi_align documentation, not from this
/// header -- confirm against the implementation.
///
/// @param input          Source feature tensor to pool from.
/// @param rois           Regions of interest to pool over (presumably one row
///                       per box, including a batch index -- TODO confirm).
/// @param spatial_scale  Factor presumably used to map box coordinates onto
///                       the scale of `input`.
/// @param pooled_height  Requested height of each pooled output.
/// @param pooled_width   Requested width of each pooled output.
/// @param sampling_ratio Presumably the number of sampling points per output
///                       bin; the meaning of non-positive values is not
///                       visible here.
/// @param aligned        Flag presumably selecting the pixel-aligned
///                       (half-pixel shifted) box convention.
/// @return The pooled result as an at::Tensor.
VISION_API at::Tensor roi_align(
    const at::Tensor& input,
    const at::Tensor& rois,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t sampling_ratio,
    bool aligned);
|
|
|
|
/// Declares the RoI Align forward operator, symbolic-integer size variant.
///
/// Same parameter list as the int64_t overload except that `pooled_height`
/// and `pooled_width` are passed as c10::SymInt (by value); `sampling_ratio`
/// remains a plain int64_t. Only the declaration is visible in this header.
/// The function is annotated with VISION_API (from "../macros.h").
///
/// NOTE(review): the parameter semantics below are inferred from the parameter
/// names and the public torchvision.ops.roi_align documentation, not from this
/// header -- confirm against the implementation.
///
/// @param input          Source feature tensor to pool from.
/// @param rois           Regions of interest to pool over (presumably one row
///                       per box, including a batch index -- TODO confirm).
/// @param spatial_scale  Factor presumably used to map box coordinates onto
///                       the scale of `input`.
/// @param pooled_height  Requested height of each pooled output (SymInt).
/// @param pooled_width   Requested width of each pooled output (SymInt).
/// @param sampling_ratio Presumably the number of sampling points per output
///                       bin.
/// @param aligned        Flag presumably selecting the pixel-aligned
///                       (half-pixel shifted) box convention.
/// @return The pooled result as an at::Tensor.
VISION_API at::Tensor roi_align_symint(
    const at::Tensor& input,
    const at::Tensor& rois,
    double spatial_scale,
    c10::SymInt pooled_height,
    c10::SymInt pooled_width,
    int64_t sampling_ratio,
    bool aligned);
|
|
|
|
namespace detail {
|
|
|
|
/// Declares the backward counterpart of RoI Align (plain-integer variant).
///
/// Lives in the `detail` namespace and is declared without VISION_API. Only
/// the declaration is visible in this header.
///
/// NOTE(review): the parameter semantics below are inferred from the parameter
/// names, not from this header -- confirm against the implementation.
///
/// @param grad           Incoming gradient tensor (presumably w.r.t. the
///                       pooled forward output).
/// @param rois           Regions of interest (presumably the ones used in the
///                       forward pass).
/// @param spatial_scale  Factor presumably used to map box coordinates onto
///                       the scale of the forward input.
/// @param pooled_height  Height of each pooled output.
/// @param pooled_width   Width of each pooled output.
/// @param batch_size     Presumably the batch dimension of the forward input.
/// @param channels       Presumably the channel dimension of the forward input.
/// @param height         Presumably the height of the forward input.
/// @param width          Presumably the width of the forward input.
/// @param sampling_ratio Presumably the number of sampling points per output
///                       bin.
/// @param aligned        Flag presumably selecting the pixel-aligned
///                       (half-pixel shifted) box convention.
/// @return An at::Tensor (presumably the gradient w.r.t. the forward input --
///         TODO confirm shape).
at::Tensor _roi_align_backward(
    const at::Tensor& grad,
    const at::Tensor& rois,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t batch_size,
    int64_t channels,
    int64_t height,
    int64_t width,
    int64_t sampling_ratio,
    bool aligned);
|
|
|
|
/// Declares the backward counterpart of RoI Align, symbolic-integer variant.
///
/// Same parameter list as the int64_t overload except that every size
/// argument (`pooled_height`, `pooled_width`, `batch_size`, `channels`,
/// `height`, `width`) is a c10::SymInt passed by value; `sampling_ratio`
/// remains a plain int64_t. Lives in the `detail` namespace and is declared
/// without VISION_API. Only the declaration is visible in this header.
///
/// NOTE(review): the parameter semantics below are inferred from the parameter
/// names, not from this header -- confirm against the implementation.
///
/// @param grad           Incoming gradient tensor (presumably w.r.t. the
///                       pooled forward output).
/// @param rois           Regions of interest (presumably the ones used in the
///                       forward pass).
/// @param spatial_scale  Factor presumably used to map box coordinates onto
///                       the scale of the forward input.
/// @param pooled_height  Height of each pooled output (SymInt).
/// @param pooled_width   Width of each pooled output (SymInt).
/// @param batch_size     Presumably the batch dimension of the forward input.
/// @param channels       Presumably the channel dimension of the forward input.
/// @param height         Presumably the height of the forward input.
/// @param width          Presumably the width of the forward input.
/// @param sampling_ratio Presumably the number of sampling points per output
///                       bin.
/// @param aligned        Flag presumably selecting the pixel-aligned
///                       (half-pixel shifted) box convention.
/// @return An at::Tensor (presumably the gradient w.r.t. the forward input --
///         TODO confirm shape).
at::Tensor _roi_align_backward_symint(
    const at::Tensor& grad,
    const at::Tensor& rois,
    double spatial_scale,
    c10::SymInt pooled_height,
    c10::SymInt pooled_width,
    c10::SymInt batch_size,
    c10::SymInt channels,
    c10::SymInt height,
    c10::SymInt width,
    int64_t sampling_ratio,
    bool aligned);
|
|
|
|
} // namespace detail
|
|
|
|
} // namespace ops
|
|
} // namespace vision
|