#include <fmt/format.h>
#include <fmt/ostream.h>
#include <fmt/ranges.h>

#include <torch/nativert/graph/Serialization.h>

#include <limits>

namespace torch::nativert {

namespace {

std::unique_ptr<Graph> jsonToSubgraph(
    const torch::_export::Graph& jsonGraph,
    const torch::_export::GraphSignature* signature,
    bool loadNodeMetadata);

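// Converts a symbolic serialized Argument (a tensor, tensor list, symbolic
// scalar, or custom object reference) into the corresponding graph Value.
// List-typed arguments don't map to a single Value, so a ListPack node is
// inserted in front of `insertBefore` and its output is returned instead.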
Value* symbolicToValue(
    const torch::_export::Argument& arg,
    Graph& graph,
    Node* insertBefore) {
  switch (arg.tag()) {
    case torch::_export::Argument::Tag::AS_TENSOR:
      return graph.getValue(arg.get_as_tensor().get_name());
    case torch::_export::Argument::Tag::AS_TENSORS: {
      // Need to insert a list pack node.
      std::vector<Value*> listValue;
      for (const auto& listEl : arg.get_as_tensors()) {
        listValue.push_back(graph.getValue(listEl.get_name()));
      }
      auto listPack =
          graph.createListPack(std::move(listValue), Type::Kind::Tensor);
      return graph.insertBefore(listPack, insertBefore)->outputs()[0];
    }
    case torch::_export::Argument::Tag::AS_OPTIONAL_TENSORS: {
      // Need to insert a list pack node.
      std::vector<Value*> listValue;
      for (const auto& listEl : arg.get_as_optional_tensors()) {
        switch (listEl.tag()) {
          case torch::_export::OptionalTensorArgument::Tag::AS_TENSOR: {
            listValue.push_back(
                graph.getValue(listEl.get_as_tensor().get_name()));
            break;
          }
          case torch::_export::OptionalTensorArgument::Tag::AS_NONE: {
            listValue.push_back(
                graph.addValue(std::nullopt, Type::Kind::None, nullptr));
            break;
          }
          default:
            TORCH_CHECK(
                false,
                fmt::format(
                    "Unknown OptionalTensorArgument type: {}",
                    torch::_export::printEnum(listEl.tag())));
        }
      }
      auto listPack = graph.createOptionalListPack(std::move(listValue));
      return graph.insertBefore(listPack, insertBefore)->outputs()[0];
    }
    case torch::_export::Argument::Tag::AS_SYM_INT: {
      return graph.getValue(arg.get_as_sym_int().get_as_name());
    }
    case torch::_export::Argument::Tag::AS_SYM_INTS: {
      // Need to insert a list pack node.
      std::vector<Value*> listValue;
      for (const auto& listEl : arg.get_as_sym_ints()) {
        switch (listEl.tag()) {
          case torch::_export::SymIntArgument::Tag::AS_NAME: {
            listValue.push_back(graph.getValue(listEl.get_as_name()));
            break;
          }
          case torch::_export::SymIntArgument::Tag::AS_INT: {
            // These are concrete int values in the SymIntList, e.g. [s0, 8].
            // We convert them into constant Values in the graph. These values
            // don't have a producer node.
            int64_t value = listEl.get_as_int();
            TORCH_CHECK(
                value >= std::numeric_limits<int>::min() &&
                value <= std::numeric_limits<int>::max());
            Value* symintValue =
                graph.createConstantSymIntValue(static_cast<int>(value));
            listValue.push_back(symintValue);
            break;
          }
          default:
            TORCH_CHECK(
                false,
                fmt::format(
                    "Unknown SymIntArgument type: {}",
                    torch::_export::printEnum(listEl.tag())));
        }
      }
      auto listPack =
          graph.createListPack(std::move(listValue), Type::Kind::SymInt);
      return graph.insertBefore(listPack, insertBefore)->outputs()[0];
    }
    case torch::_export::Argument::Tag::AS_CUSTOM_OBJ: {
      return graph.getValue(arg.get_as_custom_obj().get_name());
    }
    case torch::_export::Argument::Tag::AS_SYM_BOOL: {
      return graph.getValue(arg.get_as_sym_bool().get_as_name());
    }
    case torch::_export::Argument::Tag::AS_SYM_FLOAT: {
      return graph.getValue(arg.get_as_sym_float().get_as_name());
    }
    default:
      TORCH_CHECK(
          false,
          fmt::format(
              "This function should only be called with symbolic arguments, got {} instead",
              torch::_export::printEnum(arg.tag())));
  }
}

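// Reorders inputSpecs and graphInputs in lockstep (the i-th spec describes
// the i-th graph input) into the canonical order documented below. Returns
// the reordered pair; the caller's vectors are left untouched.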
std::pair<
    std::vector<torch::_export::InputSpec>,
    std::vector<torch::_export::Argument>>
enforceInputOrder(
    const std::vector<torch::_export::InputSpec>& inputSpecs,
    const std::vector<torch::_export::Argument>& graphInputs) {
  // Enforce the order of inputSpecs and graphInputs to be the following:
  // 1. token
  // 2. parameter
  // 3. persistent buffer, non-persistent buffer
  // 4. tensor_constant
  // 5. custom_obj
  // 6. user_input/constant_input
  std::vector<torch::_export::InputSpec> reorderedInputSpecs;
  std::vector<torch::_export::Argument> reorderedGraphInputs;
  std::vector<torch::_export::InputSpec::Tag> desiredOrder = {
      torch::_export::InputSpec::Tag::TOKEN,
      torch::_export::InputSpec::Tag::PARAMETER,
      torch::_export::InputSpec::Tag::BUFFER,
      torch::_export::InputSpec::Tag::TENSOR_CONSTANT,
      torch::_export::InputSpec::Tag::CUSTOM_OBJ};

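  // Stable filter: appends every (spec, input) pair matching `condition`,
  // preserving the pairs' original relative order.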
  auto reorder = [&](auto condition) {
    for (size_t i = 0; i < inputSpecs.size(); ++i) {
      if (condition(inputSpecs[i])) {
        reorderedInputSpecs.push_back(inputSpecs[i]);
        reorderedGraphInputs.push_back(graphInputs[i]);
      }
    }
  };

  for (const auto& tag : desiredOrder) {
    if (tag == torch::_export::InputSpec::Tag::BUFFER) {
      // Add persistent buffers first, then non-persistent ones.
      reorder([&](const auto& spec) {
        return spec.tag() == tag && spec.get_buffer().get_persistent();
      });
      reorder([&](const auto& spec) {
        return spec.tag() == tag && !spec.get_buffer().get_persistent();
      });
    } else {
      reorder([&](const auto& spec) { return spec.tag() == tag; });
    }
  }

  // Append USER_INPUT and CONSTANT_INPUT without reordering.
  for (size_t i = 0; i < inputSpecs.size(); ++i) {
    auto tag = inputSpecs[i].tag();
    if (tag == torch::_export::InputSpec::Tag::USER_INPUT ||
        tag == torch::_export::InputSpec::Tag::CONSTANT_INPUT) {
      reorderedInputSpecs.push_back(inputSpecs[i]);
      reorderedGraphInputs.push_back(graphInputs[i]);
    }
  }
  return {std::move(reorderedInputSpecs), std::move(reorderedGraphInputs)};
}

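// Deserializes a torch::_export::Graph into a nativert Graph. `signature` is
// the owning module's GraphSignature for the toplevel graph, or nullptr for
// nested subgraphs, which get a synthesized user-input/user-output signature
// below. `loadNodeMetadata` controls whether per-node metadata from the
// serialized form is carried over.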
std::unique_ptr<Graph> jsonToSubgraph(
    const torch::_export::Graph& jsonGraph,
    const torch::_export::GraphSignature* signature,
    bool loadNodeMetadata) {
  auto graphInputs = jsonGraph.get_inputs();
  auto graph = Graph::createGraph();

  if (signature) {
    // Enforce a canonical order on the signature's input specs and the graph
    // inputs before constructing the graph.
    const auto& inputSpecs = signature->get_input_specs();

    auto [reorderedInputSpecs, reorderedGraphInputs] =
        enforceInputOrder(inputSpecs, graphInputs);

    graphInputs = std::move(reorderedGraphInputs);
    auto reorderedSignature = *signature;
    reorderedSignature.set_input_specs(reorderedInputSpecs);
    graph->setSignature(torch::nativert::GraphSignature{reorderedSignature});
  }

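  // Register graph inputs: symbolic inputs (tensors, custom objects) become
  // named input Values, while constant inputs only get placeholders.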
  for (const auto& input : graphInputs) {
    if (isSymbolic(input)) {
      switch (input.tag()) {
        case torch::_export::Argument::Tag::AS_TENSOR: {
          const auto& asTensor = input.get_as_tensor();
          const auto& name = asTensor.get_name();
          graph->addInput(name, Type::Kind::Tensor);
          break;
        }
        case torch::_export::Argument::Tag::AS_CUSTOM_OBJ: {
          const auto& asCustomObj = input.get_as_custom_obj();
          const std::string& name = asCustomObj.get_name();
          const std::string& classFqn = asCustomObj.get_class_fqn();
          graph->addInput(name, Type(Type::Kind::CustomObj, classFqn));
          break;
        }
        default:
          TORCH_CHECK(
              false,
              fmt::format(
                  "Unsupported symbolic graph input type: {}",
                  torch::_export::printEnum(input.tag())));
      }
    } else {
      switch (input.tag()) {
        case torch::_export::Argument::Tag::AS_INT:
        case torch::_export::Argument::Tag::AS_FLOAT:
        case torch::_export::Argument::Tag::AS_STRING:
        case torch::_export::Argument::Tag::AS_BOOL:
        case torch::_export::Argument::Tag::AS_NONE: {
          // Constant graph inputs are specialized into the graph; here we
          // simply add a null Value placeholder to the graph input node.
          graph->addInput();
          break;
        }
        default:
          TORCH_CHECK(
              false,
              fmt::format(
                  "Unsupported constant graph input type: {}",
                  torch::_export::printEnum(input.tag())));
      }
    }
  }

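  // Deserialize each node. Symbolic inputs become graph edges, constant
  // inputs become node attributes, and list-typed outputs are expanded
  // through a trailing prim.ListUnpack node.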
  for (const auto& jsonNode : jsonGraph.get_nodes()) {
    auto node = graph->insertNode(
        jsonNode.get_target(),
        {},
        loadNodeMetadata ? jsonNode.get_metadata()
                         : std::unordered_map<std::string, std::string>());

    std::vector<NamedArgument> args;
    std::vector<Attribute> attributes;
    for (const auto& input : jsonNode.get_inputs()) {
      // We handle constants and symbolic inputs differently.
      const auto& arg = input.get_arg();
      if (isSymbolic(arg)) {
        // Symbolic values are made part of the inputs to the node.
        node->addInput(NamedArgument{
            input.get_name(), symbolicToValue(input.get_arg(), *graph, node)});
      } else if (arg.tag() == torch::_export::Argument::Tag::AS_NONE) {
        node->addInput(NamedArgument{
            input.get_name(),
            graph->addValue(std::nullopt, Type::Kind::None, node)});
      } else {
        // Constant values are added as "attributes" to the node.
        node->addAttribute(Attribute{
            input.get_name(),
            constantToValue(input.get_arg(), loadNodeMetadata)});
      }
    }

    std::vector<Value*> outputs;
    std::vector<Value*> listUnpacksToCreate;
    for (const auto& output : jsonNode.get_outputs()) {
      switch (output.tag()) {
        case torch::_export::Argument::Tag::AS_NONE: {
          node->addOutput(Type::Kind::None);
          break;
        }
        case torch::_export::Argument::Tag::AS_TENSOR: {
          const auto name = output.get_as_tensor().get_name();
          node->addOutput(name, Type::Kind::Tensor);
          break;
        }
        case torch::_export::Argument::Tag::AS_TENSORS: {
          auto outputValue = node->addOutput(
              graph->getUniqueValueName(), Type::Kind::TensorList);

          Node* listUnpack =
              graph->insertNode("prim.ListUnpack", {{"input", outputValue}});
          for (const auto& arg : output.get_as_tensors()) {
            listUnpack->addOutput(arg.get_name(), Type::Kind::Tensor);
          }
          break;
        }
        case torch::_export::Argument::Tag::AS_SYM_INT: {
          const auto name = output.get_as_sym_int().get_as_name();
          node->addOutput(name, Type::Kind::SymInt);
          break;
        }
        case torch::_export::Argument::Tag::AS_SYM_INTS: {
          TORCH_CHECK(
              false,
              "SymInts NYI. We currently don't have ops that produce SymInts as output");
        }
        case torch::_export::Argument::Tag::AS_SYM_BOOL: {
          const auto name = output.get_as_sym_bool().get_as_name();
          node->addOutput(name, Type::Kind::SymBool);
          break;
        }
        case torch::_export::Argument::Tag::AS_SYM_BOOLS: {
          TORCH_CHECK(
              false,
              "SymBools NYI. We currently don't have ops that produce SymBools as output");
        }
        case torch::_export::Argument::Tag::AS_SYM_FLOAT: {
          const auto name = output.get_as_sym_float().get_as_name();
          node->addOutput(name, Type::Kind::SymFloat);
          break;
        }
        case torch::_export::Argument::Tag::AS_SYM_FLOATS: {
          TORCH_CHECK(
              false,
              "SymFloats NYI. We currently don't have ops that produce SymFloats as output");
        }
        default:
          TORCH_CHECK(
              false,
              fmt::format(
                  "Unsupported node output type: {}",
                  torch::_export::printEnum(output.tag())));
      }
    }
  }

  for (const auto& output : jsonGraph.get_outputs()) {
    // Handle symbolic outputs and constant outputs differently.
    if (isSymbolic(output)) {
      switch (output.tag()) {
        case torch::_export::Argument::Tag::AS_TENSOR: {
          const auto& asTensor = output.get_as_tensor();
          const auto& name = asTensor.get_name();
          Value* outputValue = graph->getValue(name);
          graph->addOutput(outputValue);
          break;
        }
        case torch::_export::Argument::Tag::AS_SYM_INT: {
          const auto& asSymInt = output.get_as_sym_int();
          TORCH_CHECK(
              asSymInt.tag() == torch::_export::SymIntArgument::Tag::AS_NAME);
          const auto& name = asSymInt.get_as_name();
          Value* outputValue = graph->getValue(name);
          graph->addOutput(outputValue);
          break;
        }
        default:
          TORCH_CHECK(
              false,
              fmt::format(
                  "Unsupported graph output type: {}",
                  torch::_export::printEnum(output.tag())));
      }
    } else {
      Constant constValue = constantToValue(output, loadNodeMetadata);
      graph->addConstantOutput(std::move(constValue));
    }
  }

  auto jsonTensorValue = jsonGraph.get_tensor_values();

  if (!signature) {
    // For subgraphs we just need to derive a graph signature that only
    // contains user inputs and outputs, because we don't need to handle any
    // special semantics for them, e.g. mutation or gradients.
    torch::_export::GraphSignature sig;
    std::vector<torch::_export::InputSpec> inputSpecs;
    for (const auto& input : graph->inputs()) {
      torch::_export::Argument arg;
      if (input->type().kind() == Type::Kind::Tensor) {
        torch::_export::TensorArgument targ;
        targ.set_name(std::string{input->name()});
        arg.set_as_tensor(std::move(targ));
      } else {
        TORCH_CHECK(
            false,
            fmt::format(
                "Unsupported subgraph input type {}",
                fmt::streamed(input->type())));
      }
      torch::_export::UserInputSpec userInputSpec;
      userInputSpec.set_arg(std::move(arg));
      torch::_export::InputSpec inputSpec;
      inputSpec.set_user_input(std::move(userInputSpec));
      inputSpecs.push_back(std::move(inputSpec));
    }
    sig.set_input_specs(std::move(inputSpecs));

    std::vector<torch::_export::OutputSpec> outputSpecs;
    for (const auto& output : graph->outputs()) {
      torch::_export::Argument arg;
      if (output->type().kind() == Type::Kind::Tensor) {
        torch::_export::TensorArgument targ;
        targ.set_name(std::string{output->name()});
        arg.set_as_tensor(std::move(targ));
      } else {
        TORCH_CHECK(
            false,
            fmt::format(
                "Unsupported subgraph output type {}",
                fmt::streamed(output->type())));
      }
      torch::_export::UserOutputSpec userOutputSpec;
      userOutputSpec.set_arg(std::move(arg));
      torch::_export::OutputSpec outputSpec;
      outputSpec.set_user_output(std::move(userOutputSpec));
      outputSpecs.push_back(std::move(outputSpec));
    }
    sig.set_output_specs(std::move(outputSpecs));

    graph->setSignature(torch::nativert::GraphSignature{sig});
  }

  // weightsTensorMeta is indexed by the weight's name, not the graph input's
  // name.
  std::unordered_map<std::string, torch::_export::TensorMeta> weightsTensorMeta;
  for (const auto& [inputName, weightName] :
       graph->signature().inputsToWeights()) {
    auto value = graph->getValue(inputName);
    if (value->type().kind() == Type::Kind::CustomObj) {
      // Skip setting meta for non-tensor inputs.
      continue;
    }

    auto it = jsonTensorValue.find(inputName);
    CHECK(it != jsonTensorValue.end())
        << "Missing tensor metadata for " << inputName
        << " in jsonGraph.tensor_values";
    weightsTensorMeta[weightName] = it->second;
  }
  graph->setWeightsMeta(weightsTensorMeta);

  graph->setTensorValuesMeta(jsonTensorValue);

  graph->finalize();

  graph->lint();
  return graph;
}

} // namespace
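
// Returns true if the argument refers to a runtime value in the graph
// (tensors, symbolic scalars, and custom objects) rather than an embedded
// constant. Symbolic arguments are resolved via symbolicToValue; everything
// else goes through constantToValue.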
bool isSymbolic(const torch::_export::Argument& arg) {
  switch (arg.tag()) {
    case torch::_export::Argument::Tag::AS_TENSOR:
    case torch::_export::Argument::Tag::AS_TENSORS:
    case torch::_export::Argument::Tag::AS_OPTIONAL_TENSORS:
    case torch::_export::Argument::Tag::AS_SYM_INT:
    case torch::_export::Argument::Tag::AS_SYM_INTS:
    case torch::_export::Argument::Tag::AS_SYM_BOOL:
    case torch::_export::Argument::Tag::AS_SYM_BOOLS:
    case torch::_export::Argument::Tag::AS_SYM_FLOAT:
    case torch::_export::Argument::Tag::AS_SYM_FLOATS:
    case torch::_export::Argument::Tag::AS_CUSTOM_OBJ:
      return true;
    default:
      return false;
  }
}
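
// Converts a constant (non-symbolic) serialized Argument into a Constant.
// Nested graph arguments (AS_GRAPH) recurse into jsonToSubgraph with no
// signature, yielding an owned subgraph.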
Constant constantToValue(
    const torch::_export::Argument& jsonArg,
    bool loadNodeMetadata) {
  switch (jsonArg.tag()) {
    case torch::_export::Argument::Tag::AS_NONE:
      return torch::nativert::None();
    case torch::_export::Argument::Tag::AS_INT:
      return jsonArg.get_as_int();
    case torch::_export::Argument::Tag::AS_INTS: {
      std::vector<int64_t> ret;
      for (const auto& arg : jsonArg.get_as_ints()) {
        ret.push_back(arg);
      }
      return ret;
    }
    case torch::_export::Argument::Tag::AS_FLOAT:
      return jsonArg.get_as_float().get();
    case torch::_export::Argument::Tag::AS_FLOATS: {
      std::vector<double> ret;
      for (const auto& arg : jsonArg.get_as_floats()) {
        ret.push_back(arg.get());
      }
      return ret;
    }
    case torch::_export::Argument::Tag::AS_STRING:
      return jsonArg.get_as_string();
    case torch::_export::Argument::Tag::AS_STRINGS: {
      std::vector<std::string> ret;
      for (const auto& arg : jsonArg.get_as_strings()) {
        ret.push_back(arg);
      }
      return ret;
    }
    case torch::_export::Argument::Tag::AS_SCALAR_TYPE:
      return torch::nativert::convertJsonScalarType(
          jsonArg.get_as_scalar_type());
    case torch::_export::Argument::Tag::AS_MEMORY_FORMAT:
      return torch::nativert::convertJsonMemoryFormat(
          jsonArg.get_as_memory_format());
    case torch::_export::Argument::Tag::AS_LAYOUT:
      return torch::nativert::convertJsonLayout(jsonArg.get_as_layout());
    case torch::_export::Argument::Tag::AS_DEVICE:
      return torch::nativert::convertJsonDevice(jsonArg.get_as_device());
    case torch::_export::Argument::Tag::AS_BOOL:
      return jsonArg.get_as_bool();
    case torch::_export::Argument::Tag::AS_BOOLS: {
      std::vector<bool> ret;
      for (const auto& arg : jsonArg.get_as_bools()) {
        ret.push_back(arg);
      }
      return ret;
    }
    case torch::_export::Argument::Tag::AS_GRAPH: {
      return jsonToSubgraph(
          *jsonArg.get_as_graph().get_graph(), nullptr, loadNodeMetadata);
    }
    case torch::_export::Argument::Tag::AS_TENSOR:
    case torch::_export::Argument::Tag::AS_TENSORS:
    case torch::_export::Argument::Tag::AS_OPTIONAL_TENSORS:
      TORCH_CHECK(false, "Tensor values are symbolic, not constant.");
    case torch::_export::Argument::Tag::AS_SYM_INT:
    case torch::_export::Argument::Tag::AS_SYM_INTS:
    case torch::_export::Argument::Tag::AS_SYM_BOOL:
    case torch::_export::Argument::Tag::AS_SYM_BOOLS:
      TORCH_CHECK(false, "SymInt/SymBool values are symbolic, not constant.");
    case torch::_export::Argument::Tag::AS_CUSTOM_OBJ:
      TORCH_CHECK(false, "Custom objects are symbolic, not constant.");
    case torch::_export::Argument::Tag::AS_OPERATOR:
      return jsonArg.get_as_operator();
    case torch::_export::Argument::Tag::AS_SYM_FLOAT:
      TORCH_CHECK(false, "SymFloat is not yet implemented.");
    case torch::_export::Argument::Tag::AS_SYM_FLOATS:
      TORCH_CHECK(false, "SymFloats is not yet implemented.");
    default:
      TORCH_CHECK(false, "Got unknown json argument");
  }
}

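// Public entry point: deserializes a full torch::_export::GraphModule,
// threading its GraphSignature through to the toplevel subgraph.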
std::unique_ptr<Graph> jsonToGraph(
    const torch::_export::GraphModule& jsonGraphModule,
    bool loadNodeMetadata) {
  auto graph = jsonToSubgraph(
      jsonGraphModule.get_graph(),
      &jsonGraphModule.get_signature(),
      loadNodeMetadata);
  return graph;
}

} // namespace torch::nativert