sglang_v0.5.2/pytorch_2.8.0/torch/nativert/graph/Serialization.cpp

#include <fmt/format.h>
#include <fmt/ostream.h>
#include <fmt/ranges.h>
#include <torch/nativert/graph/Serialization.h>
#include <limits>
namespace torch::nativert {
namespace {
std::unique_ptr<Graph> jsonToSubgraph(
const torch::_export::Graph& jsonGraph,
const torch::_export::GraphSignature* signature,
bool loadNodeMetadata);
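// Resolves a symbolic serialized argument to a graph Value. Scalar
// symbolic arguments (tensor, sym int/bool/float, custom object) look up
// an existing named Value; list arguments additionally insert a pack node
// before `insertBefore` so consumers see a single list-typed Value.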
Value* symbolicToValue(
const torch::_export::Argument& arg,
Graph& graph,
Node* insertBefore) {
switch (arg.tag()) {
case torch::_export::Argument::Tag::AS_TENSOR:
return graph.getValue(arg.get_as_tensor().get_name());
case torch::_export::Argument::Tag::AS_TENSORS: {
// Need to insert a list pack node
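// Illustrative example: a serialized argument naming tensors ["x", "y"]
// becomes a prim.ListPack node whose single output is the TensorList
// Value consumed by the node being built (names here are hypothetical).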
std::vector<Value*> listValue;
for (const auto& listEl : arg.get_as_tensors()) {
listValue.push_back(graph.getValue(listEl.get_name()));
}
auto listPack =
graph.createListPack(std::move(listValue), Type::Kind::Tensor);
return graph.insertBefore(listPack, insertBefore)->outputs()[0];
}
case torch::_export::Argument::Tag::AS_OPTIONAL_TENSORS: {
// Need to insert a list pack node
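// None entries get fresh unnamed None-typed Values so the packed list
// preserves element positions, e.g. [t0, None, t1] (illustrative).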
std::vector<Value*> listValue;
for (const auto& listEl : arg.get_as_optional_tensors()) {
switch (listEl.tag()) {
case torch::_export::OptionalTensorArgument::Tag::AS_TENSOR: {
listValue.push_back(
graph.getValue(listEl.get_as_tensor().get_name()));
break;
}
case torch::_export::OptionalTensorArgument::Tag::AS_NONE: {
listValue.push_back(
graph.addValue(std::nullopt, Type::Kind::None, nullptr));
break;
}
default:
TORCH_CHECK(
false,
fmt::format(
"Unknown OptionalTensorArgument type: {}",
torch::_export::printEnum(listEl.tag())));
}
}
auto listPack = graph.createOptionalListPack(std::move(listValue));
return graph.insertBefore(listPack, insertBefore)->outputs()[0];
}
case torch::_export::Argument::Tag::AS_SYM_INT: {
return graph.getValue(arg.get_as_sym_int().get_as_name());
}
case torch::_export::Argument::Tag::AS_SYM_INTS: {
// Need to insert a list pack node
std::vector<Value*> listValue;
for (const auto& listEl : arg.get_as_sym_ints()) {
switch (listEl.tag()) {
case torch::_export::SymIntArgument::Tag::AS_NAME: {
listValue.push_back(graph.getValue(listEl.get_as_name()));
break;
}
case torch::_export::SymIntArgument::Tag::AS_INT: {
// These are concrete int values in the SymInt list, e.g. [s0, 8].
// We convert each into a constant Value in the graph; such Values
// have no producer node. createConstantSymIntValue takes an int,
// hence the range check below.
int64_t value = listEl.get_as_int();
TORCH_CHECK(
value >= std::numeric_limits<int>::min() &&
value <= std::numeric_limits<int>::max());
Value* symintValue =
graph.createConstantSymIntValue(static_cast<int>(value));
listValue.push_back(symintValue);
break;
}
default:
TORCH_CHECK(
false,
fmt::format(
"Unknown SymIntArgument type: {}",
torch::_export::printEnum(listEl.tag())));
}
}
auto listPack =
graph.createListPack(std::move(listValue), Type::Kind::SymInt);
return graph.insertBefore(listPack, insertBefore)->outputs()[0];
}
case torch::_export::Argument::Tag::AS_CUSTOM_OBJ: {
return graph.getValue(arg.get_as_custom_obj().get_name());
}
case torch::_export::Argument::Tag::AS_SYM_BOOL: {
return graph.getValue(arg.get_as_sym_bool().get_as_name());
}
case torch::_export::Argument::Tag::AS_SYM_FLOAT: {
return graph.getValue(arg.get_as_sym_float().get_as_name());
}
default:
TORCH_CHECK(
false,
fmt::format(
"This function should only be called with symbolic arguments, got {} instead",
torch::_export::printEnum(arg.tag())));
}
}
std::pair<
std::vector<torch::_export::InputSpec>,
std::vector<torch::_export::Argument>>
enforceInputOrder(
const std::vector<torch::_export::InputSpec>& inputSpecs,
const std::vector<torch::_export::Argument>& graphInputs) {
// Enforce the order of inputSpecs and graphInputs to be the following:
// 1. token
// 2. parameter
// 3. persistent buffer, non-persistent buffer
// 4. tensor_constant
// 5. custom_obj
// 6. user_input/constant_input
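// Illustrative example: specs arriving as [user_input, buffer, parameter,
// token] are emitted as [token, parameter, buffer, user_input], with
// USER_INPUT/CONSTANT_INPUT entries keeping their original relative order
// at the end.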
std::vector<torch::_export::InputSpec> reorderedInputSpecs;
std::vector<torch::_export::Argument> reorderedGraphInputs;
std::vector<torch::_export::InputSpec::Tag> desiredOrder = {
torch::_export::InputSpec::Tag::TOKEN,
torch::_export::InputSpec::Tag::PARAMETER,
torch::_export::InputSpec::Tag::BUFFER,
torch::_export::InputSpec::Tag::TENSOR_CONSTANT,
torch::_export::InputSpec::Tag::CUSTOM_OBJ};
auto reorder = [&](auto condition) {
for (size_t i = 0; i < inputSpecs.size(); ++i) {
if (condition(inputSpecs[i])) {
reorderedInputSpecs.push_back(inputSpecs[i]);
reorderedGraphInputs.push_back(graphInputs[i]);
}
}
};
for (const auto& tag : desiredOrder) {
if (tag == torch::_export::InputSpec::Tag::BUFFER) {
// Add persistent buffers first, then non-persistent
reorder([&](const auto& spec) {
return spec.tag() == tag && spec.get_buffer().get_persistent();
});
reorder([&](const auto& spec) {
return spec.tag() == tag && !spec.get_buffer().get_persistent();
});
} else {
reorder([&](const auto& spec) { return spec.tag() == tag; });
}
}
// Append USER_INPUT and CONSTANT_INPUT without reordering
for (size_t i = 0; i < inputSpecs.size(); ++i) {
auto tag = inputSpecs[i].tag();
if (tag == torch::_export::InputSpec::Tag::USER_INPUT ||
tag == torch::_export::InputSpec::Tag::CONSTANT_INPUT) {
reorderedInputSpecs.push_back(inputSpecs[i]);
reorderedGraphInputs.push_back(graphInputs[i]);
}
}
return {std::move(reorderedInputSpecs), std::move(reorderedGraphInputs)};
}
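// Deserializes one serialized graph into a nativert Graph: inputs first
// (reordered canonically when a signature is provided), then nodes with
// list pack/unpack glue, then outputs; subgraphs get a minimal derived
// signature, and tensor metadata is attached at the end.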
std::unique_ptr<Graph> jsonToSubgraph(
const torch::_export::Graph& jsonGraph,
const torch::_export::GraphSignature* signature,
bool loadNodeMetadata) {
auto graphInputs = jsonGraph.get_inputs();
auto graph = Graph::createGraph();
if (signature) {
// Enforce the order of the signature's input specs and the graph inputs.
const auto& inputSpecs = signature->get_input_specs();
auto [reorderedInputSpecs, reorderedGraphInputs] =
enforceInputOrder(inputSpecs, graphInputs);
graphInputs = std::move(reorderedGraphInputs);
auto reorderedSignature = *signature;
reorderedSignature.set_input_specs(reorderedInputSpecs);
graph->setSignature(torch::nativert::GraphSignature{reorderedSignature});
}
for (const auto& input : graphInputs) {
if (isSymbolic(input)) {
switch (input.tag()) {
case torch::_export::Argument::Tag::AS_TENSOR: {
const auto& asTensor = input.get_as_tensor();
const auto& name = asTensor.get_name();
graph->addInput(name, Type::Kind::Tensor);
break;
}
case torch::_export::Argument::Tag::AS_CUSTOM_OBJ: {
const auto& asCustomObj = input.get_as_custom_obj();
const std::string& name = asCustomObj.get_name();
const std::string& classFqn = asCustomObj.get_class_fqn();
graph->addInput(name, Type(Type::Kind::CustomObj, classFqn));
break;
}
default:
TORCH_CHECK(
false,
fmt::format(
"Unsupported symbolic graph input type: {}",
torch::_export::printEnum(input.tag())));
}
} else {
switch (input.tag()) {
case torch::_export::Argument::Tag::AS_INT:
case torch::_export::Argument::Tag::AS_FLOAT:
case torch::_export::Argument::Tag::AS_STRING:
case torch::_export::Argument::Tag::AS_BOOL:
case torch::_export::Argument::Tag::AS_NONE: {
// Constant graph inputs are specialized into the graph, so here we
// simply add a nullptr Value as a placeholder graph input.
graph->addInput();
break;
}
default:
TORCH_CHECK(
false,
fmt::format(
"Unsupported constant graph input type: {}",
torch::_export::printEnum(input.tag())));
}
}
}
for (const auto& jsonNode : jsonGraph.get_nodes()) {
auto node = graph->insertNode(
jsonNode.get_target(),
{},
loadNodeMetadata ? jsonNode.get_metadata()
: std::unordered_map<std::string, std::string>());
std::vector<NamedArgument> args;
std::vector<Attribute> attributes;
for (const auto& input : jsonNode.get_inputs()) {
// We handle constants and symbolic inputs differently.
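// Symbolic arguments (and None) become node inputs backed by Values;
// all remaining tags are compile-time constants attached to the node
// as attributes (see isSymbolic() for the exact partition).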
const auto& arg = input.get_arg();
if (isSymbolic(arg)) {
// Symbolic values are made part of the inputs to the node
node->addInput(NamedArgument{
input.get_name(), symbolicToValue(arg, *graph, node)});
} else if (arg.tag() == torch::_export::Argument::Tag::AS_NONE) {
node->addInput(NamedArgument{
input.get_name(),
graph->addValue(std::nullopt, Type::Kind::None, node)});
} else {
// Constant values are added as "attributes" to the node.
node->addAttribute(Attribute{
input.get_name(), constantToValue(arg, loadNodeMetadata)});
}
}
std::vector<Value*> outputs;
std::vector<Value*> listUnpacksToCreate;
for (const auto& output : jsonNode.get_outputs()) {
switch (output.tag()) {
case torch::_export::Argument::Tag::AS_NONE: {
node->addOutput(Type::Kind::None);
break;
}
case torch::_export::Argument::Tag::AS_TENSOR: {
const auto name = output.get_as_tensor().get_name();
node->addOutput(name, Type::Kind::Tensor);
break;
}
case torch::_export::Argument::Tag::AS_TENSORS: {
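// A multi-tensor output is modeled as one TensorList output feeding a
// prim.ListUnpack node; the unpacked outputs carry the serialized names.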
auto outputValue = node->addOutput(
graph->getUniqueValueName(), Type::Kind::TensorList);
Node* listUnpack =
graph->insertNode("prim.ListUnpack", {{"input", outputValue}});
for (const auto& arg : output.get_as_tensors()) {
listUnpack->addOutput(arg.get_name(), Type::Kind::Tensor);
}
break;
}
case torch::_export::Argument::Tag::AS_SYM_INT: {
const auto name = output.get_as_sym_int().get_as_name();
node->addOutput(name, Type::Kind::SymInt);
break;
}
case torch::_export::Argument::Tag::AS_SYM_INTS: {
TORCH_CHECK(
false,
"SymInts NYI. We currently don't have ops that produce SymInts as output");
}
case torch::_export::Argument::Tag::AS_SYM_BOOL: {
const auto name = output.get_as_sym_bool().get_as_name();
node->addOutput(name, Type::Kind::SymBool);
break;
}
case torch::_export::Argument::Tag::AS_SYM_BOOLS: {
TORCH_CHECK(
false,
"SymBools NYI. We currently don't have ops that produce SymBools as output");
}
case torch::_export::Argument::Tag::AS_SYM_FLOAT: {
const auto name = output.get_as_sym_float().get_as_name();
node->addOutput(name, Type::Kind::SymFloat);
break;
}
case torch::_export::Argument::Tag::AS_SYM_FLOATS: {
TORCH_CHECK(
false,
"SymFloats NYI. We currently doesn't have op that produces SymFloats as output");
}
default:
TORCH_CHECK(
false,
fmt::format(
"Unsupported graph output type: {}",
torch::_export::printEnum(output.tag())));
}
}
}
for (const auto& output : jsonGraph.get_outputs()) {
// Handle symbolic outputs and constant outputs differently.
if (isSymbolic(output)) {
switch (output.tag()) {
case torch::_export::Argument::Tag::AS_TENSOR: {
const auto& asTensor = output.get_as_tensor();
const auto& name = asTensor.get_name();
Value* outputValue = graph->getValue(name);
graph->addOutput(outputValue);
break;
}
case torch::_export::Argument::Tag::AS_SYM_INT: {
const auto& asSymInt = output.get_as_sym_int();
TORCH_CHECK(
asSymInt.tag() == torch::_export::SymIntArgument::Tag::AS_NAME);
const auto& name = asSymInt.get_as_name();
Value* outputValue = graph->getValue(name);
graph->addOutput(outputValue);
break;
}
default:
TORCH_CHECK(
false,
fmt::format(
"Unsupported graph output type: {}",
torch::_export::printEnum(output.tag())));
}
} else {
Constant constValue = constantToValue(output, loadNodeMetadata);
graph->addConstantOutput(std::move(constValue));
}
}
auto jsonTensorValue = jsonGraph.get_tensor_values();
if (!signature) {
// For subgraphs we just need to derive a graph signature that only
// contains user inputs and outputs, because we don't need to handle any
// special semantics for them, e.g. mutation or gradients.
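// Sketch of the derived signature: each graph input becomes a USER_INPUT
// spec and each graph output a USER_OUTPUT spec; only tensor-typed
// values are supported here.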
torch::_export::GraphSignature sig;
std::vector<torch::_export::InputSpec> inputSpecs;
for (const auto& input : graph->inputs()) {
torch::_export::Argument arg;
if (input->type().kind() == Type::Kind::Tensor) {
torch::_export::TensorArgument targ;
targ.set_name(std::string{input->name()});
arg.set_as_tensor(std::move(targ));
} else {
TORCH_CHECK(
false,
fmt::format(
"Unsupported subgraph input type {}",
fmt::streamed(input->type())));
}
torch::_export::UserInputSpec userInputSpec;
userInputSpec.set_arg(std::move(arg));
torch::_export::InputSpec inputSpec;
inputSpec.set_user_input(std::move(userInputSpec));
inputSpecs.push_back(std::move(inputSpec));
}
sig.set_input_specs(std::move(inputSpecs));
std::vector<torch::_export::OutputSpec> outputSpecs;
for (const auto& output : graph->outputs()) {
torch::_export::Argument arg;
if (output->type().kind() == Type::Kind::Tensor) {
torch::_export::TensorArgument targ;
targ.set_name(std::string{output->name()});
arg.set_as_tensor(std::move(targ));
} else {
TORCH_CHECK(
false,
fmt::format(
"Unsupported subgraph output type {}",
fmt::streamed(output->type())));
}
torch::_export::UserOutputSpec userOutputSpec;
userOutputSpec.set_arg(std::move(arg));
torch::_export::OutputSpec outputSpec;
outputSpec.set_user_output(std::move(userOutputSpec));
outputSpecs.push_back(std::move(outputSpec));
}
sig.set_output_specs(std::move(outputSpecs));
graph->setSignature(torch::nativert::GraphSignature{sig});
}
// weightsTensorMeta is keyed by the weight's name, not the graph input's name.
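// Illustrative example: if an input spec maps graph input "p_fc_weight"
// to weight "fc.weight" (hypothetical names), the TensorMeta stored under
// "p_fc_weight" in tensor_values is re-keyed to "fc.weight".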
std::unordered_map<std::string, torch::_export::TensorMeta> weightsTensorMeta;
for (const auto& [inputName, weightName] :
graph->signature().inputsToWeights()) {
auto value = graph->getValue(inputName);
if (value->type().kind() == Type::Kind::CustomObj) {
// skip setting meta for non-tensor inputs
continue;
}
auto it = jsonTensorValue.find(inputName);
CHECK(it != jsonTensorValue.end())
<< "Missing tensor metadata for " << inputName
<< " in jsonGraph.tensor_values";
weightsTensorMeta[weightName] = it->second;
}
graph->setWeightsMeta(weightsTensorMeta);
graph->setTensorValuesMeta(jsonTensorValue);
graph->finalize();
graph->lint();
return graph;
}
} // namespace
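// Returns true for argument tags that refer to runtime Values (tensors,
// symbolic scalars and their lists, custom objects) rather than
// compile-time constants.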
bool isSymbolic(const torch::_export::Argument& arg) {
switch (arg.tag()) {
case torch::_export::Argument::Tag::AS_TENSOR:
case torch::_export::Argument::Tag::AS_TENSORS:
case torch::_export::Argument::Tag::AS_OPTIONAL_TENSORS:
case torch::_export::Argument::Tag::AS_SYM_INT:
case torch::_export::Argument::Tag::AS_SYM_INTS:
case torch::_export::Argument::Tag::AS_SYM_BOOL:
case torch::_export::Argument::Tag::AS_SYM_BOOLS:
case torch::_export::Argument::Tag::AS_SYM_FLOAT:
case torch::_export::Argument::Tag::AS_SYM_FLOATS:
case torch::_export::Argument::Tag::AS_CUSTOM_OBJ:
return true;
default:
return false;
}
}
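// Deserializes a non-symbolic argument into a Constant. Symbolic tags are
// rejected and must go through symbolicToValue() instead; AS_GRAPH
// recurses into jsonToSubgraph with a null signature to build an owned
// subgraph.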
Constant constantToValue(
const torch::_export::Argument& jsonArg,
bool loadNodeMetadata) {
switch (jsonArg.tag()) {
case torch::_export::Argument::Tag::AS_NONE:
return torch::nativert::None();
case torch::_export::Argument::Tag::AS_INT:
return jsonArg.get_as_int();
case torch::_export::Argument::Tag::AS_INTS: {
std::vector<int64_t> ret;
for (const auto& arg : jsonArg.get_as_ints()) {
ret.push_back(arg);
}
return ret;
}
case torch::_export::Argument::Tag::AS_FLOAT:
return jsonArg.get_as_float().get();
case torch::_export::Argument::Tag::AS_FLOATS: {
std::vector<double> ret;
for (const auto& arg : jsonArg.get_as_floats()) {
ret.push_back(arg.get());
}
return ret;
}
case torch::_export::Argument::Tag::AS_STRING:
return jsonArg.get_as_string();
case torch::_export::Argument::Tag::AS_STRINGS: {
std::vector<std::string> ret;
for (const auto& arg : jsonArg.get_as_strings()) {
ret.push_back(arg);
}
return ret;
}
case torch::_export::Argument::Tag::AS_SCALAR_TYPE:
return torch::nativert::convertJsonScalarType(
jsonArg.get_as_scalar_type());
case torch::_export::Argument::Tag::AS_MEMORY_FORMAT:
return torch::nativert::convertJsonMemoryFormat(
jsonArg.get_as_memory_format());
case torch::_export::Argument::Tag::AS_LAYOUT:
return torch::nativert::convertJsonLayout(jsonArg.get_as_layout());
case torch::_export::Argument::Tag::AS_DEVICE:
return torch::nativert::convertJsonDevice(jsonArg.get_as_device());
case torch::_export::Argument::Tag::AS_BOOL:
return jsonArg.get_as_bool();
case torch::_export::Argument::Tag::AS_BOOLS: {
std::vector<bool> ret;
for (const auto& arg : jsonArg.get_as_bools()) {
ret.push_back(arg);
}
return ret;
}
case torch::_export::Argument::Tag::AS_GRAPH: {
return jsonToSubgraph(
*jsonArg.get_as_graph().get_graph(), nullptr, loadNodeMetadata);
}
case torch::_export::Argument::Tag::AS_TENSOR:
case torch::_export::Argument::Tag::AS_TENSORS:
case torch::_export::Argument::Tag::AS_OPTIONAL_TENSORS:
TORCH_CHECK(false, "Tensor values are symbolic, not constant.");
case torch::_export::Argument::Tag::AS_SYM_INT:
case torch::_export::Argument::Tag::AS_SYM_INTS:
case torch::_export::Argument::Tag::AS_SYM_BOOL:
case torch::_export::Argument::Tag::AS_SYM_BOOLS:
TORCH_CHECK(false, "Symint/Symbool Values are symbolic, not constant.");
case torch::_export::Argument::Tag::AS_CUSTOM_OBJ:
TORCH_CHECK(false, "custom obj is symbolic, not constant");
case torch::_export::Argument::Tag::AS_OPERATOR:
return jsonArg.get_as_operator();
case torch::_export::Argument::Tag::AS_SYM_FLOAT: {
TORCH_CHECK(false, "SymFloat is not yet implemented");
}
case torch::_export::Argument::Tag::AS_SYM_FLOATS: {
TORCH_CHECK(false, "SymFloats is not yet implemented");
}
default:
TORCH_CHECK(false, "Got unknown json argument");
}
}
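// Public entry point: deserializes a top-level GraphModule, passing its
// GraphSignature through so jsonToSubgraph can canonicalize input order.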
std::unique_ptr<Graph> jsonToGraph(
const torch::_export::GraphModule& jsonGraphModule,
bool loadNodeMetadata) {
auto graph = jsonToSubgraph(
jsonGraphModule.get_graph(),
&jsonGraphModule.get_signature(),
loadNodeMetadata);
return graph;
}
} // namespace torch::nativert