#include #include #include #include #include namespace torch::nativert { namespace { std::unique_ptr jsonToSubgraph( const torch::_export::Graph& jsonGraph, const torch::_export::GraphSignature* signature, bool loadNodeMetadata); Value* symbolicToValue( const torch::_export::Argument& arg, Graph& graph, Node* insertBefore) { switch (arg.tag()) { case torch::_export::Argument::Tag::AS_TENSOR: return graph.getValue(arg.get_as_tensor().get_name()); case torch::_export::Argument::Tag::AS_TENSORS: { // Need to insert a list pack node std::vector listValue; for (const auto& listEl : arg.get_as_tensors()) { listValue.push_back(graph.getValue(listEl.get_name())); } auto listPack = graph.createListPack(std::move(listValue), Type::Kind::Tensor); return graph.insertBefore(listPack, insertBefore)->outputs()[0]; } case torch::_export::Argument::Tag::AS_OPTIONAL_TENSORS: { // Need to insert a list pack node std::vector listValue; for (const auto& listEl : arg.get_as_optional_tensors()) { switch (listEl.tag()) { case torch::_export::OptionalTensorArgument::Tag::AS_TENSOR: { listValue.push_back( graph.getValue(listEl.get_as_tensor().get_name())); break; } case torch::_export::OptionalTensorArgument::Tag::AS_NONE: { listValue.push_back( graph.addValue(std::nullopt, Type::Kind::None, nullptr)); break; } default: TORCH_CHECK( false, fmt::format( "Unknown OptionalTensorArgument type: {}", torch::_export::printEnum(listEl.tag()))); } } auto listPack = graph.createOptionalListPack(std::move(listValue)); return graph.insertBefore(listPack, insertBefore)->outputs()[0]; } case torch::_export::Argument::Tag::AS_SYM_INT: { return graph.getValue(arg.get_as_sym_int().get_as_name()); } case torch::_export::Argument::Tag::AS_SYM_INTS: { // Need to insert a list pack node std::vector listValue; for (const auto& listEl : arg.get_as_sym_ints()) { switch (listEl.tag()) { case torch::_export::SymIntArgument::Tag::AS_NAME: { listValue.push_back(graph.getValue(listEl.get_as_name())); break; } case torch::_export::SymIntArgument::Tag::AS_INT: { // These are concrete int values in the SymIntList, e.g [s0, 8] // We convert them into a constant Value in graph. These value // doesn't have producer node int64_t value = listEl.get_as_int(); TORCH_CHECK( value >= std::numeric_limits::min() && value <= std::numeric_limits::max()); Value* symintValue = graph.createConstantSymIntValue(static_cast(value)); listValue.push_back(symintValue); break; } default: TORCH_CHECK( false, fmt::format( "Unknown SymIntArgument type: {}", torch::_export::printEnum(listEl.tag()))); } } auto listPack = graph.createListPack(std::move(listValue), Type::Kind::SymInt); return graph.insertBefore(listPack, insertBefore)->outputs()[0]; } case torch::_export::Argument::Tag::AS_CUSTOM_OBJ: { return graph.getValue(arg.get_as_custom_obj().get_name()); } case torch::_export::Argument::Tag::AS_SYM_BOOL: { return graph.getValue(arg.get_as_sym_bool().get_as_name()); } case torch::_export::Argument::Tag::AS_SYM_FLOAT: { return graph.getValue(arg.get_as_sym_float().get_as_name()); } default: TORCH_CHECK( false, fmt::format( "This function should only be called with symbolic arguments, got {} instead", torch::_export::printEnum(arg.tag()))); } } std::pair< std::vector, std::vector> enforceInputOrder( const std::vector& inputSpecs, const std::vector& graphInputs) { // Enforce the order of inputSpecs and graphInputs to be the following: // 1. token // 2. parameter // 3. persistent buffer, non-persistent buffer // 4. tensor_constant // 5. custom_obj // 6. user_input/constant_input std::vector reorderedInputSpecs; std::vector reorderedGraphInputs; std::vector desiredOrder = { torch::_export::InputSpec::Tag::TOKEN, torch::_export::InputSpec::Tag::PARAMETER, torch::_export::InputSpec::Tag::BUFFER, torch::_export::InputSpec::Tag::TENSOR_CONSTANT, torch::_export::InputSpec::Tag::CUSTOM_OBJ}; auto reorder = [&](auto condition) { for (size_t i = 0; i < inputSpecs.size(); ++i) { if (condition(inputSpecs[i])) { reorderedInputSpecs.push_back(inputSpecs[i]); reorderedGraphInputs.push_back(graphInputs[i]); } } }; for (const auto& tag : desiredOrder) { if (tag == torch::_export::InputSpec::Tag::BUFFER) { // Add persistent buffers first, then non-persistent reorder([&](const auto& spec) { return spec.tag() == tag && spec.get_buffer().get_persistent(); }); reorder([&](const auto& spec) { return spec.tag() == tag && !spec.get_buffer().get_persistent(); }); } else { reorder([&](const auto& spec) { return spec.tag() == tag; }); } } // Append USER_INPUT and CONSTANT_INPUT without reordering for (size_t i = 0; i < inputSpecs.size(); ++i) { auto tag = inputSpecs[i].tag(); if (tag == torch::_export::InputSpec::Tag::USER_INPUT || tag == torch::_export::InputSpec::Tag::CONSTANT_INPUT) { reorderedInputSpecs.push_back(inputSpecs[i]); reorderedGraphInputs.push_back(graphInputs[i]); } } return {std::move(reorderedInputSpecs), std::move(reorderedGraphInputs)}; } std::unique_ptr jsonToSubgraph( const torch::_export::Graph& jsonGraph, const torch::_export::GraphSignature* signature, bool loadNodeMetadata) { auto graphInputs = jsonGraph.get_inputs(); auto graph = Graph::createGraph(); if (signature) { // enforcing the order signature inputspecs and graph inputs const auto& inputSpecs = signature->get_input_specs(); auto [reorderedInputSpecs, reorderedGraphInputs] = enforceInputOrder(inputSpecs, graphInputs); graphInputs = std::move(reorderedGraphInputs); auto reorderedSignature = *signature; reorderedSignature.set_input_specs(reorderedInputSpecs); graph->setSignature(torch::nativert::GraphSignature{reorderedSignature}); } for (const auto& input : graphInputs) { if (isSymbolic(input)) { switch (input.tag()) { case torch::_export::Argument::Tag::AS_TENSOR: { const auto& asTensor = input.get_as_tensor(); const auto& name = asTensor.get_name(); graph->addInput(name, Type::Kind::Tensor); break; } case torch::_export::Argument::Tag::AS_CUSTOM_OBJ: { const auto& asCustomObj = input.get_as_custom_obj(); const std::string& name = asCustomObj.get_name(); const std::string& classFqn = asCustomObj.get_class_fqn(); graph->addInput(name, Type(Type::Kind::CustomObj, classFqn)); break; } default: TORCH_CHECK( false, fmt::format( "Unsupported symbolic graph input type: {}", torch::_export::printEnum(input.tag()))); } } else { switch (input.tag()) { case torch::_export::Argument::Tag::AS_INT: case torch::_export::Argument::Tag::AS_FLOAT: case torch::_export::Argument::Tag::AS_STRING: case torch::_export::Argument::Tag::AS_BOOL: case torch::_export::Argument::Tag::AS_NONE: { // Constant graph inputs are specialized in the graph, here we simply // add a nullptr of Value to the graph input node. graph->addInput(); break; } default: TORCH_CHECK( false, fmt::format( "Unsupported constant graph input type: {}", torch::_export::printEnum(input.tag()))); } } } for (const auto& jsonNode : jsonGraph.get_nodes()) { auto node = graph->insertNode( jsonNode.get_target(), {}, loadNodeMetadata ? jsonNode.get_metadata() : std::unordered_map()); std::vector args; std::vector attributes; for (const auto& input : jsonNode.get_inputs()) { // We handle constants and symbolic inputs differently. const auto& arg = input.get_arg(); if (isSymbolic(arg)) { // Symbolic values are made part of the inputs to the node node->addInput(NamedArgument{ input.get_name(), symbolicToValue(input.get_arg(), *graph, node)}); } else if (arg.tag() == torch::_export::Argument::Tag::AS_NONE) { node->addInput(NamedArgument{ input.get_name(), graph->addValue(std::nullopt, Type::Kind::None, node)}); } else { node->addAttribute(Attribute{ input.get_name(), constantToValue(input.get_arg(), loadNodeMetadata)}); // Constant values are added as "attributes" to the node. } } std::vector outputs; std::vector listUnpacksToCreate; for (const auto& output : jsonNode.get_outputs()) { switch (output.tag()) { case torch::_export::Argument::Tag::AS_NONE: { node->addOutput(Type::Kind::None); break; } case torch::_export::Argument::Tag::AS_TENSOR: { const auto name = output.get_as_tensor().get_name(); node->addOutput(name, Type::Kind::Tensor); break; } case torch::_export::Argument::Tag::AS_TENSORS: { auto outputValue = node->addOutput( graph->getUniqueValueName(), Type::Kind::TensorList); Node* listUnpack = graph->insertNode("prim.ListUnpack", {{"input", outputValue}}); for (const auto& arg : output.get_as_tensors()) { listUnpack->addOutput(arg.get_name(), Type::Kind::Tensor); } break; } case torch::_export::Argument::Tag::AS_SYM_INT: { const auto name = output.get_as_sym_int().get_as_name(); node->addOutput(name, Type::Kind::SymInt); break; } case torch::_export::Argument::Tag::AS_SYM_INTS: { TORCH_CHECK( false, "SymInts NYI. We currently don't have ops that produce SymInts as output"); } case torch::_export::Argument::Tag::AS_SYM_BOOL: { const auto name = output.get_as_sym_bool().get_as_name(); node->addOutput(name, Type::Kind::SymBool); break; } case torch::_export::Argument::Tag::AS_SYM_BOOLS: { TORCH_CHECK( false, "SymBools NYI. We currently don't have ops that produce SymBools as output"); } case torch::_export::Argument::Tag::AS_SYM_FLOAT: { const auto name = output.get_as_sym_float().get_as_name(); node->addOutput(name, Type::Kind::SymFloat); break; } case torch::_export::Argument::Tag::AS_SYM_FLOATS: { TORCH_CHECK( false, "SymFloats NYI. We currently doesn't have op that produces SymFloats as output"); } default: TORCH_CHECK( false, fmt::format( "Unsupported graph output type: {}", torch::_export::printEnum(output.tag()))); } } } for (const auto& output : jsonGraph.get_outputs()) { // handle symbolic outputs and constant outputs differently if (isSymbolic(output)) { switch (output.tag()) { case torch::_export::Argument::Tag::AS_TENSOR: { const auto& asTensor = output.get_as_tensor(); const auto& name = asTensor.get_name(); Value* outputValue = graph->getValue(name); graph->addOutput(outputValue); break; } case torch::_export::Argument::Tag::AS_SYM_INT: { const auto& asSymInt = output.get_as_sym_int(); TORCH_CHECK( asSymInt.tag() == torch::_export::SymIntArgument::Tag::AS_NAME); const auto& name = asSymInt.get_as_name(); Value* outputValue = graph->getValue(name); graph->addOutput(outputValue); break; } default: TORCH_CHECK( false, fmt::format( "Unsupported graph output type: {}", torch::_export::printEnum(output.tag()))); } } else { Constant constValue = constantToValue(output, loadNodeMetadata); graph->addConstantOutput(std::move(constValue)); } } auto jsonTensorValue = jsonGraph.get_tensor_values(); if (!signature) { // For subgraphs we just need to derive a graph signature that only // contains user inputs and outputs, because we don't need to handle any // special semantics for them, e.g. mutation or gradients. torch::_export::GraphSignature sig; std::vector inputSpecs; for (const auto& input : graph->inputs()) { torch::_export::Argument arg; if (input->type().kind() == Type::Kind::Tensor) { torch::_export::TensorArgument targ; targ.set_name(std::string{input->name()}); arg.set_as_tensor(std::move(targ)); } else { TORCH_CHECK( false, fmt::format( "Unsupported subgraph input type {}", fmt::streamed(input->type()))); } torch::_export::UserInputSpec userInputSpec; userInputSpec.set_arg(std::move(arg)); torch::_export::InputSpec inputSpec; inputSpec.set_user_input(std::move(userInputSpec)); inputSpecs.push_back(std::move(inputSpec)); } sig.set_input_specs(std::move(inputSpecs)); std::vector outputSpecs; for (const auto& output : graph->outputs()) { torch::_export::Argument arg; if (output->type().kind() == Type::Kind::Tensor) { torch::_export::TensorArgument targ; targ.set_name(std::string{output->name()}); arg.set_as_tensor(std::move(targ)); } else { TORCH_CHECK( false, fmt::format( "Unsupported subgraph output type {}", fmt::streamed(output->type()))); } torch::_export::UserOutputSpec userOutputSpec; userOutputSpec.set_arg(std::move(arg)); torch::_export::OutputSpec outputSpec; outputSpec.set_user_output(std::move(userOutputSpec)); outputSpecs.push_back(std::move(outputSpec)); } sig.set_output_specs(std::move(outputSpecs)); graph->setSignature(torch::nativert::GraphSignature{sig}); } // weightsTensorMeta are indexed by weight's name, not graph input's name std::unordered_map weightsTensorMeta; for (const auto& [inputName, weightName] : graph->signature().inputsToWeights()) { auto value = graph->getValue(inputName); if (value->type().kind() == Type::Kind::CustomObj) { // skip setting meta for non-tensor inputs continue; } auto it = jsonTensorValue.find(inputName); CHECK(it != jsonTensorValue.end()) << "Missing tensor metadata for " << inputName << "in thriftGraph.tensorValue"; weightsTensorMeta[weightName] = it->second; } graph->setWeightsMeta(weightsTensorMeta); graph->setTensorValuesMeta(jsonTensorValue); graph->finalize(); graph->lint(); return graph; } } // namespace bool isSymbolic(const torch::_export::Argument& arg) { switch (arg.tag()) { case torch::_export::Argument::Tag::AS_TENSOR: case torch::_export::Argument::Tag::AS_TENSORS: case torch::_export::Argument::Tag::AS_OPTIONAL_TENSORS: case torch::_export::Argument::Tag::AS_SYM_INT: case torch::_export::Argument::Tag::AS_SYM_INTS: case torch::_export::Argument::Tag::AS_SYM_BOOL: case torch::_export::Argument::Tag::AS_SYM_BOOLS: case torch::_export::Argument::Tag::AS_SYM_FLOAT: case torch::_export::Argument::Tag::AS_SYM_FLOATS: case torch::_export::Argument::Tag::AS_CUSTOM_OBJ: return true; default: return false; } } Constant constantToValue( const torch::_export::Argument& jsonArg, bool loadNodeMetadata) { switch (jsonArg.tag()) { case torch::_export::Argument::Tag::AS_NONE: return torch::nativert::None(); case torch::_export::Argument::Tag::AS_INT: return jsonArg.get_as_int(); case torch::_export::Argument::Tag::AS_INTS: { std::vector ret; for (const auto& arg : jsonArg.get_as_ints()) { ret.push_back(arg); } return ret; } case torch::_export::Argument::Tag::AS_FLOAT: return jsonArg.get_as_float().get(); case torch::_export::Argument::Tag::AS_FLOATS: { std::vector ret; for (const auto& arg : jsonArg.get_as_floats()) { ret.push_back(arg.get()); } return ret; } case torch::_export::Argument::Tag::AS_STRING: return jsonArg.get_as_string(); case torch::_export::Argument::Tag::AS_STRINGS: { std::vector ret; for (const auto& arg : jsonArg.get_as_strings()) { ret.push_back(arg); } return ret; } case torch::_export::Argument::Tag::AS_SCALAR_TYPE: return torch::nativert::convertJsonScalarType( jsonArg.get_as_scalar_type()); case torch::_export::Argument::Tag::AS_MEMORY_FORMAT: return torch::nativert::convertJsonMemoryFormat( jsonArg.get_as_memory_format()); case torch::_export::Argument::Tag::AS_LAYOUT: return torch::nativert::convertJsonLayout(jsonArg.get_as_layout()); case torch::_export::Argument::Tag::AS_DEVICE: return torch::nativert::convertJsonDevice(jsonArg.get_as_device()); case torch::_export::Argument::Tag::AS_BOOL: return jsonArg.get_as_bool(); case torch::_export::Argument::Tag::AS_BOOLS: { std::vector ret; for (const auto& arg : jsonArg.get_as_bools()) { ret.push_back(arg); } return ret; } case torch::_export::Argument::Tag::AS_GRAPH: { return jsonToSubgraph( *jsonArg.get_as_graph().get_graph(), nullptr, loadNodeMetadata); } case torch::_export::Argument::Tag::AS_TENSOR: case torch::_export::Argument::Tag::AS_TENSORS: case torch::_export::Argument::Tag::AS_OPTIONAL_TENSORS: TORCH_CHECK(false, "Tensor values are symbolic, not constant."); case torch::_export::Argument::Tag::AS_SYM_INT: case torch::_export::Argument::Tag::AS_SYM_INTS: case torch::_export::Argument::Tag::AS_SYM_BOOL: case torch::_export::Argument::Tag::AS_SYM_BOOLS: TORCH_CHECK(false, "Symint/Symbool Values are symbolic, not constant."); case torch::_export::Argument::Tag::AS_CUSTOM_OBJ: TORCH_CHECK(false, "custom obj is symbolic, not constant"); case torch::_export::Argument::Tag::AS_OPERATOR: return jsonArg.get_as_operator(); case torch::_export::Argument::Tag::AS_SYM_FLOAT: { TORCH_CHECK(false, "SymFloat is not yet implemented"); } case torch::_export::Argument::Tag::AS_SYM_FLOATS: { TORCH_CHECK(false, "SymFloats is not yet implemented"); } default: TORCH_CHECK(false, "Got unknown json argument"); } } std::unique_ptr jsonToGraph( const torch::_export::GraphModule& jsonGraphModule, bool loadNodeMetadata) { auto graph = jsonToSubgraph( jsonGraphModule.get_graph(), &jsonGraphModule.get_signature(), loadNodeMetadata); return graph; } } // namespace torch::nativert