sglang_v0.5.2/pytorch_2.8.0/torch/nativert/detail/ITree.cpp

486 lines
16 KiB
C++

#include <ATen/record_function.h>
#include <torch/nativert/detail/ITree.h>
#include <iterator>
#include <string_view>
#include <ATen/core/ivalue.h>
#include <c10/util/Synchronized.h>
#include <nlohmann/json.hpp>
namespace torch::nativert::detail {
namespace {
inline constexpr int kDefaultTreeSpecSerializationProtocol = 1;
c10::IValue dynamicToIValue(const nlohmann::json& obj) {
if (obj.is_string()) {
return obj.get<std::string>();
} else if (obj.is_number_integer()) {
return obj.get<int64_t>();
} else {
TORCH_CHECK(false, "Unsupported dynamic type: ", obj);
}
}
void itreeFlatten(
const c10::IValue& nested,
const ITreeSpec& spec,
std::vector<c10::IValue>& ivalues) {
if (spec.isIValue()) {
ivalues.push_back(nested);
return;
}
auto flattenFn = spec.nodeDefCache().flattenFn;
flattenFn(nested, spec, ivalues);
}
class PytreeNodeRegistry {
public:
PytreeNodeRegistry() {
// Add some law of physics here.
registerNode(
"builtins.tuple",
NodeDef{
[](const c10::IValue& nested,
const ITreeSpec& spec,
std::vector<c10::IValue>& ivalues) {
const auto& tuple = nested.toTupleRef().elements();
TORCH_CHECK_EQ(tuple.size(), spec.children().size());
for (size_t i = 0; i < tuple.size(); i++) {
itreeFlatten(tuple[i], spec.children(i), ivalues);
}
},
[](std::vector<c10::IValue> flats,
const nlohmann::json& obj) -> c10::IValue {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(obj.is_null());
return c10::ivalue::Tuple::create(std::move(flats));
},
[](ITreeMapNoReturnFn fn,
const c10::IValue& nested,
const ITreeSpec& spec) {
const auto& tuple = nested.toTupleRef().elements();
TORCH_CHECK_EQ(tuple.size(), spec.children().size());
for (size_t i = 0; i < tuple.size(); i++) {
ivalueApply(fn, tuple[i], spec.children(i));
}
}});
const auto& tupleNodeDef = getNodeDef("builtins.tuple");
registerNode(
"collections.namedtuple",
NodeDef{
tupleNodeDef.flattenFn,
[](std::vector<c10::IValue> flats,
const nlohmann::json& obj) -> c10::IValue {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!obj.is_null());
return c10::ivalue::Tuple::create(std::move(flats));
},
tupleNodeDef.ivalueApplyFn,
[](std::string_view context) { return nlohmann::json{context}; }});
registerNode(
"builtins.list",
NodeDef{
[](const c10::IValue& nested,
const ITreeSpec& spec,
std::vector<c10::IValue>& ivalues) {
auto list = nested.toListRef();
for (size_t i = 0; i < list.size(); i++) {
itreeFlatten(list[i], spec.children(i), ivalues);
}
},
[](std::vector<c10::IValue> flats,
const nlohmann::json& obj) -> c10::IValue {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(obj.is_null());
c10::List<c10::IValue> list(c10::AnyType::get());
list.reserve(flats.size());
for (auto& flat : flats) {
list.push_back(std::move(flat));
}
return list;
},
[](ITreeMapNoReturnFn fn,
const c10::IValue& nested,
const ITreeSpec& spec) {
auto list = nested.toListRef();
for (size_t i = 0; i < list.size(); i++) {
ivalueApply(fn, list[i], spec.children(i));
}
}});
registerNode(
"torch.fx.immutable_collections.immutable_list",
getNodeDef("builtins.list"));
registerNode(
"builtins.dict",
NodeDef{
[](const c10::IValue& nested,
const ITreeSpec& spec,
std::vector<c10::IValue>& ivalues) {
auto dict = nested.toGenericDict();
const auto& contextKeys = spec.contextKeys();
// allow the dict size less than the spec, missing key will be
// filled with empty tensor
TORCH_CHECK_LE(dict.size(), contextKeys.size());
size_t i = 0;
for (const auto& key : contextKeys) {
auto it = dict.find(key);
if (it != dict.end()) {
itreeFlatten(it->value(), spec.children(i), ivalues);
} else {
// when we have a dict with missing keys, we fill the missing
// ivalues with an empty tensor which is required for
// validation
for (size_t j = 0; j < spec.children(i).numIValues(); ++j) {
at::Tensor empty_tensor;
ivalues.emplace_back(std::move(empty_tensor));
}
}
i++;
}
},
[](std::vector<c10::IValue> flats,
const nlohmann::json& obj) -> c10::IValue {
c10::Dict<c10::IValue, c10::IValue> dict(
c10::AnyType::get(), c10::AnyType::get());
TORCH_CHECK(obj.is_array());
TORCH_CHECK_EQ(obj.size(), flats.size());
dict.reserve(flats.size());
for (size_t i = 0; i < flats.size(); i++) {
dict.insert(dynamicToIValue(obj[i]), std::move(flats[i]));
}
return dict;
},
[](ITreeMapNoReturnFn fn,
const c10::IValue& nested,
const ITreeSpec& spec) {
auto dict = nested.toGenericDict();
const auto& contextKeys = spec.contextKeys();
size_t i = 0;
for (const auto& key : contextKeys) {
if (spec.children(i).isUsed()) {
auto it = dict.find(key);
if (it != dict.end()) {
ivalueApply(fn, it->value(), spec.children(i));
} else {
TORCH_CHECK(false, "input arg is missing key ", key);
}
}
i++;
}
}});
registerNode(
"torch.fx.immutable_collections.immutable_dict",
getNodeDef("builtins.dict"));
}
bool hasNodeDef(std::string_view typeName) const {
return registry_.find(std::string{typeName}) != registry_.end();
}
const NodeDef& getNodeDef(std::string_view typeName) const {
return registry_.at(std::string{typeName});
}
void registerNode(std::string_view typeName, NodeDef nodeDef) {
TORCH_CHECK(!hasNodeDef(typeName));
registry_.emplace(typeName, nodeDef);
}
private:
std::unordered_map<std::string, NodeDef> registry_;
};
c10::Synchronized<PytreeNodeRegistry>& getPytreeNodeRegistry() {
static auto* registry = new c10::Synchronized<PytreeNodeRegistry>();
return *registry;
}
ITreeSpec makeITreeSpec(
const nlohmann::json& obj,
const std::vector<const Value*>& values,
int start) {
TORCH_CHECK(obj.is_object());
TORCH_CHECK(obj.find("type") != obj.end());
if (obj["type"].is_null()) {
TORCH_CHECK_EQ(obj["children_spec"].size(), 0);
TORCH_CHECK(obj["context"].is_null());
const Value* value = values[start];
if (value) {
bool isUsed = !value->users().empty();
return ITreeSpec(value, isUsed);
} else {
return ITreeSpec(value, false);
}
}
const auto& name = obj["type"].get<std::string>();
NodeDef nodeDefCache;
getPytreeNodeRegistry().withLock([&](auto& registry) {
TORCH_CHECK(registry.hasNodeDef(name), "Unknown pytree node type: ", name);
nodeDefCache = registry.getNodeDef(name);
});
auto context = nodeDefCache.contextLoadFn(obj["context"].get<std::string>());
const auto& childrenSpec = obj["children_spec"];
TORCH_CHECK(childrenSpec.is_array());
std::vector<ITreeSpec> children;
int offset = 0;
for (const auto& child : childrenSpec) {
children.push_back(makeITreeSpec(child, values, start + offset));
// NOLINTNEXTLINE(*-narrowing-conversions)
offset += children.back().numIValues();
}
return ITreeSpec(name, context, std::move(children), nodeDefCache);
}
} // namespace
void registerPytreeNode(std::string_view typeName, NodeDef nodeDef) {
getPytreeNodeRegistry().withLock([&](auto& registry) {
registry.registerNode(typeName, std::move(nodeDef));
});
}
ITreeSpec itreeSpecLoads(
std::string_view json,
const std::vector<const Value*>& values) {
const auto obj = nlohmann::json::parse(json);
TORCH_CHECK(obj.is_array());
TORCH_CHECK_EQ(obj.size(), 2);
TORCH_CHECK_EQ(obj[0].get<int64_t>(), kDefaultTreeSpecSerializationProtocol);
auto result = makeITreeSpec(obj[1], values, 0);
TORCH_CHECK_EQ(result.numIValues(), values.size());
return result;
}
c10::IValue itreeUnflatten(
std::vector<c10::IValue> ivalues,
const ITreeSpec& spec) {
RECORD_USER_SCOPE("nativert::itreeUnflatten");
TORCH_CHECK_EQ(ivalues.size(), spec.numIValues());
if (spec.isIValue()) {
return std::move(ivalues[0]);
}
auto unflattenFn = spec.nodeDefCache().unflattenFn;
if (spec.allIValues()) {
return unflattenFn(std::move(ivalues), spec.context());
}
size_t start = 0;
std::vector<c10::IValue> childrenPytrees;
for (const auto& child : spec.children()) {
if (child.isIValue()) {
childrenPytrees.push_back(std::move(ivalues[start]));
start++;
continue;
}
size_t numIValues = child.numIValues();
std::vector<c10::IValue> slice(
// NOLINTNEXTLINE(*-narrowing-conversions)
std::make_move_iterator(ivalues.begin() + start),
// NOLINTNEXTLINE(*-narrowing-conversions)
std::make_move_iterator(ivalues.begin() + start + numIValues));
childrenPytrees.push_back(itreeUnflatten(std::move(slice), child));
start += numIValues;
}
return unflattenFn(std::move(childrenPytrees), spec.context());
}
std::vector<c10::IValue> itreeFlatten(
const c10::IValue& nested,
const ITreeSpec& spec) {
std::vector<c10::IValue> ivalues;
ivalues.reserve(spec.numIValues());
itreeFlatten(nested, spec, ivalues);
return ivalues;
}
std::vector<c10::IValue> itreeFlattenFromArgs(
const std::vector<c10::IValue>& args,
const std::unordered_map<std::string, c10::IValue>& kwargs,
const ITreeSpec& spec) {
RECORD_USER_SCOPE("nativert::itreeFlattenFromArgs");
TORCH_CHECK(!spec.isIValue());
TORCH_CHECK_EQ(spec.children().size(), 2);
std::vector<c10::IValue> ivalues;
ivalues.reserve(spec.numIValues());
const auto& specArgs = spec.children(0);
TORCH_CHECK(!specArgs.isIValue());
TORCH_CHECK_EQ(specArgs.children().size(), args.size());
for (size_t i = 0; i < args.size(); i++) {
itreeFlatten(args[i], specArgs.children(i), ivalues);
}
const auto& specKwargs = spec.children(1);
TORCH_CHECK(!specKwargs.isIValue());
TORCH_CHECK_EQ(specKwargs.context().size(), kwargs.size());
for (size_t i = 0; i < specKwargs.context().size(); i++) {
itreeFlatten(
kwargs.at(specKwargs.context()[i].get_ref<const std::string&>()),
specKwargs.children(i),
ivalues);
}
return ivalues;
}
void ivalueApplyFromArgs(
ITreeMapNoReturnFn fn,
const std::vector<c10::IValue>& args,
const std::unordered_map<std::string, c10::IValue>& kwargs,
const ITreeSpec& spec) {
RECORD_USER_SCOPE("nativert::ivalueApplyFromArgs");
TORCH_CHECK(!spec.isIValue());
TORCH_CHECK_EQ(spec.children().size(), 2);
const auto& specArgs = spec.children(0);
TORCH_CHECK(!specArgs.isIValue());
TORCH_CHECK_EQ(specArgs.children().size(), args.size());
for (size_t i = 0; i < args.size(); i++) {
ivalueApply(fn, args[i], specArgs.children(i));
}
const auto& specKwargs = spec.children(1);
TORCH_CHECK(!specKwargs.isIValue());
const auto& ctx = specKwargs.context();
TORCH_CHECK_EQ(ctx.size(), kwargs.size());
for (size_t i = 0; i < ctx.size(); i++) {
ivalueApply(
fn,
kwargs.at(ctx[i].get_ref<const std::string&>()),
specKwargs.children(i));
}
}
std::vector<at::Tensor> itreeFlattenToTensorList(
const c10::IValue& nested,
const ITreeSpec& spec) {
auto flats = itreeFlatten(nested, spec);
std::vector<at::Tensor> tensors;
tensors.reserve(flats.size());
for (const auto& flat : flats) {
tensors.push_back(flat.toTensor());
}
return tensors;
}
c10::IValue itreeMap(
ITreeMapFn f,
const c10::IValue& nested,
const ITreeSpec& spec) {
const auto flats = itreeFlatten(nested, spec);
std::vector<c10::IValue> mapped;
mapped.reserve(flats.size());
for (const auto& flat : flats) {
mapped.push_back(f(flat));
}
return itreeUnflatten(std::move(mapped), spec);
}
c10::IValue argsToIValue(
const std::vector<c10::IValue>& args,
const std::unordered_map<std::string, c10::IValue>& kwargs) {
c10::Dict<c10::IValue, c10::IValue> dict(
c10::StringType::get(), c10::AnyType::get());
for (const auto& [key, arg] : kwargs) {
dict.insert(key, arg);
}
return c10::ivalue::Tuple::create({c10::ivalue::Tuple::create(args), dict});
}
std::
pair<std::vector<c10::IValue>, std::unordered_map<std::string, c10::IValue>>
itreeMapArgs(
ITreeMapFn f,
const std::vector<c10::IValue>& args,
const std::unordered_map<std::string, c10::IValue>& kwargs,
const ITreeSpec& spec) {
const auto val = argsToIValue(args, kwargs);
const auto mapVal = itreeMap(f, val, spec);
auto mapArgs =
mapVal.toTupleRef().elements()[0].toTupleRef().elements().vec();
std::unordered_map<std::string, c10::IValue> mapKwargs;
for (const auto& entry : mapVal.toTupleRef().elements()[1].toGenericDict()) {
mapKwargs.emplace(entry.key().toStringRef(), entry.value());
}
return {std::move(mapArgs), std::move(mapKwargs)};
}
void ivalueApply(
ITreeMapNoReturnFn fn,
const c10::IValue& nested,
const ITreeSpec& spec) {
if (spec.isIValue()) {
if (spec.isUsed()) {
fn(nested, spec.value());
}
return;
}
auto ivalueApplyFn = spec.nodeDefCache().ivalueApplyFn;
ivalueApplyFn(fn, nested, spec);
}
nlohmann::json defaultContextLoadFn(std::string_view context) {
return nlohmann::json::parse(context);
}
ITreeSpec::ITreeSpec(
std::string_view uniformName,
nlohmann::json context,
std::vector<ITreeSpec> children,
NodeDef nodeDefCache)
: uniformName_(uniformName),
context_(std::move(context)),
children_(std::move(children)),
nodeDefCache_(nodeDefCache),
numIValues_(0),
value_(nullptr),
isUsed_(false) {
for (auto& child : children_) {
numIValues_ += child.numIValues();
allIValues_ &= child.isIValue();
isUsed_ |= child.isUsed();
}
if (uniformName_ == "builtins.dict" ||
uniformName_ == "torch.fx.immutable_collections.immutable_dict") {
for (const auto& keyObj : context_) {
contextKeys_.push_back(dynamicToIValue(keyObj));
}
}
}
c10::TypePtr ITreeSpec::toAtenType() const {
if (isIValue()) {
return c10::AnyType::get();
} else if (uniformName_ == "builtins.tuple") {
std::vector<c10::TypePtr> childrenType;
childrenType.reserve(children_.size());
for (const auto& childrenSpec : children_) {
childrenType.emplace_back(childrenSpec.toAtenType());
}
return c10::TupleType::create(std::move(childrenType));
} else if (
uniformName_ == "builtins.list" ||
uniformName_ == "torch.fx.immutable_collections.immutable_list") {
if (children_.empty()) {
return c10::ListType::create(c10::AnyType::get());
} else {
return c10::ListType::create(children_[0].toAtenType());
}
} else if (
uniformName_ == "builtins.dict" ||
uniformName_ == "torch.fx.immutable_collections.immutable_dict") {
if (children_.empty()) {
return c10::DictType::create(c10::AnyType::get(), c10::AnyType::get());
} else {
return c10::DictType::create(
dynamicToIValue(context_[0]).type(), children_[0].toAtenType());
}
} else {
TORCH_CHECK(false, "Unsupported uniform name: ", uniformName());
}
}
} // namespace torch::nativert::detail