207 lines
4.4 KiB
C++
207 lines
4.4 KiB
C++
#pragma once
|
|
|
|
#include <c10/util/Exception.h>
|
|
|
|
namespace c10 {
|
|
|
|
template <typename T>
|
|
class IntrusiveList;
|
|
|
|
class IntrusiveListHook {
|
|
template <typename P, typename T>
|
|
friend class ListIterator;
|
|
|
|
template <typename T>
|
|
friend class IntrusiveList;
|
|
|
|
IntrusiveListHook* next_{nullptr};
|
|
IntrusiveListHook* prev_{nullptr};
|
|
|
|
void link_before(IntrusiveListHook* next_node) {
|
|
next_ = next_node;
|
|
prev_ = next_node->prev_;
|
|
next_node->prev_ = this;
|
|
prev_->next_ = this;
|
|
}
|
|
|
|
public:
|
|
IntrusiveListHook() : next_(this), prev_(this) {}
|
|
|
|
IntrusiveListHook(const IntrusiveListHook&) = delete;
|
|
IntrusiveListHook& operator=(const IntrusiveListHook&) = delete;
|
|
IntrusiveListHook(IntrusiveListHook&&) = delete;
|
|
IntrusiveListHook& operator=(IntrusiveListHook&&) = delete;
|
|
|
|
void unlink() {
|
|
TORCH_CHECK(is_linked());
|
|
next_->prev_ = prev_;
|
|
prev_->next_ = next_;
|
|
next_ = this;
|
|
prev_ = this;
|
|
}
|
|
|
|
~IntrusiveListHook() {
|
|
if (is_linked()) {
|
|
unlink();
|
|
}
|
|
}
|
|
|
|
bool is_linked() const {
|
|
return next_ != this;
|
|
}
|
|
};
|
|
|
|
template <typename P, typename T>
|
|
class ListIterator {
|
|
static_assert(std::is_same_v<std::remove_const_t<P>, IntrusiveListHook>);
|
|
static_assert(std::is_base_of_v<IntrusiveListHook, T>);
|
|
P* ptr_;
|
|
|
|
friend class IntrusiveList<T>;
|
|
|
|
public:
|
|
using iterator_category = std::bidirectional_iterator_tag;
|
|
using value_type = std::conditional_t<std::is_const_v<P>, const T, T>;
|
|
using difference_type = std::ptrdiff_t;
|
|
using pointer = value_type*;
|
|
using reference = value_type&;
|
|
|
|
explicit ListIterator(P* ptr) : ptr_(ptr) {}
|
|
~ListIterator() = default;
|
|
|
|
ListIterator(const ListIterator&) = default;
|
|
ListIterator& operator=(const ListIterator&) = default;
|
|
ListIterator(ListIterator&&) = default;
|
|
ListIterator& operator=(ListIterator&&) = default;
|
|
|
|
template <
|
|
typename Q,
|
|
class = std::enable_if_t<std::is_const_v<P> && !std::is_const_v<Q>>>
|
|
ListIterator(const ListIterator<Q, T>& rhs) : ptr_(rhs.ptr_) {}
|
|
|
|
template <
|
|
typename Q,
|
|
class = std::enable_if_t<std::is_const_v<P> && !std::is_const_v<Q>>>
|
|
ListIterator& operator=(const ListIterator<Q, T>& rhs) {
|
|
ptr_ = rhs.ptr_;
|
|
return *this;
|
|
}
|
|
|
|
template <typename Q>
|
|
bool operator==(const ListIterator<Q, T>& other) const {
|
|
return ptr_ == other.ptr_;
|
|
}
|
|
|
|
template <typename Q>
|
|
bool operator!=(const ListIterator<Q, T>& other) const {
|
|
return !(*this == other);
|
|
}
|
|
|
|
auto& operator*() const {
|
|
return static_cast<reference>(*ptr_);
|
|
}
|
|
|
|
ListIterator& operator++() {
|
|
TORCH_CHECK(ptr_);
|
|
ptr_ = ptr_->next_;
|
|
return *this;
|
|
}
|
|
|
|
ListIterator& operator--() {
|
|
TORCH_CHECK(ptr_);
|
|
ptr_ = ptr_->prev_;
|
|
return *this;
|
|
}
|
|
|
|
auto* operator->() const {
|
|
return static_cast<pointer>(ptr_);
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
class IntrusiveList {
|
|
static_assert(std::is_base_of_v<IntrusiveListHook, T>);
|
|
|
|
public:
|
|
IntrusiveList() = default;
|
|
IntrusiveList(const std::initializer_list<std::reference_wrapper<T>>& items) {
|
|
for (auto& item : items) {
|
|
insert(this->end(), item);
|
|
}
|
|
}
|
|
~IntrusiveList() {
|
|
while (head_.is_linked()) {
|
|
head_.next_->unlink();
|
|
}
|
|
}
|
|
IntrusiveList(const IntrusiveList&) = delete;
|
|
IntrusiveList& operator=(const IntrusiveList&) = delete;
|
|
IntrusiveList(IntrusiveList&&) = delete;
|
|
IntrusiveList& operator=(IntrusiveList&&) = delete;
|
|
|
|
using iterator = ListIterator<IntrusiveListHook, T>;
|
|
using const_iterator = ListIterator<const IntrusiveListHook, T>;
|
|
|
|
auto begin() const {
|
|
return ++const_iterator{&head_};
|
|
}
|
|
|
|
auto begin() {
|
|
return ++iterator{&head_};
|
|
}
|
|
|
|
auto end() const {
|
|
return const_iterator{&head_};
|
|
}
|
|
|
|
auto end() {
|
|
return iterator{&head_};
|
|
}
|
|
|
|
auto rbegin() const {
|
|
return std::reverse_iterator{end()};
|
|
}
|
|
|
|
auto rbegin() {
|
|
return std::reverse_iterator{end()};
|
|
}
|
|
|
|
auto rend() const {
|
|
return std::reverse_iterator{begin()};
|
|
}
|
|
|
|
auto rend() {
|
|
return std::reverse_iterator{begin()};
|
|
}
|
|
|
|
auto iterator_to(const T& n) const {
|
|
return const_iterator{&n};
|
|
}
|
|
|
|
auto iterator_to(T& n) {
|
|
return iterator{&n};
|
|
}
|
|
|
|
iterator insert(iterator pos, T& n) {
|
|
n.link_before(pos.ptr_);
|
|
return iterator{&n};
|
|
}
|
|
|
|
size_t size() const {
|
|
size_t ret = 0;
|
|
for ([[maybe_unused]] auto& _ : *this) {
|
|
ret++;
|
|
}
|
|
return ret;
|
|
}
|
|
|
|
bool empty() const {
|
|
return !head_.is_linked();
|
|
}
|
|
|
|
private:
|
|
IntrusiveListHook head_;
|
|
};
|
|
|
|
} // namespace c10
|