#include #include #include #include #include #include #include #include #include #include namespace torch::nativert { namespace { bool isSymbolicOutput(torch::_export::Argument::Tag t) { switch (t) { case torch::_export::Argument::Tag::AS_TENSOR: case torch::_export::Argument::Tag::AS_TENSORS: case torch::_export::Argument::Tag::AS_OPTIONAL_TENSORS: case torch::_export::Argument::Tag::AS_SYM_BOOL: case torch::_export::Argument::Tag::AS_SYM_BOOLS: case torch::_export::Argument::Tag::AS_SYM_INT: case torch::_export::Argument::Tag::AS_SYM_INTS: case torch::_export::Argument::Tag::AS_SYM_FLOAT: case torch::_export::Argument::Tag::AS_SYM_FLOATS: case torch::_export::Argument::Tag::AS_CUSTOM_OBJ: return true; default: return false; } } std::pair getSpecDetails( const torch::_export::InputSpec& inputSpec) { // Retrieve the argument name and spec tag name switch (inputSpec.tag()) { case torch::_export::InputSpec::Tag::PARAMETER: return std::make_pair( inputSpec.get_parameter().get_arg().get_name(), "PARAMETER"); break; case torch::_export::InputSpec::Tag::BUFFER: return std::make_pair( inputSpec.get_buffer().get_arg().get_name(), "BUFFER"); break; case torch::_export::InputSpec::Tag::TENSOR_CONSTANT: return std::make_pair( inputSpec.get_tensor_constant().get_arg().get_name(), "TENSOR_CONSTANT"); break; case torch::_export::InputSpec::Tag::CUSTOM_OBJ: return std::make_pair( inputSpec.get_custom_obj().get_arg().get_name(), "CUSTOM_OBJ"); break; case torch::_export::InputSpec::Tag::USER_INPUT: if (inputSpec.get_user_input().get_arg().tag() == torch::_export::Argument::Tag::AS_TENSOR) { return std::make_pair( inputSpec.get_user_input().get_arg().get_as_tensor().get_name(), "USER_INPUT"); } else if ( inputSpec.get_user_input().get_arg().tag() == torch::_export::Argument::Tag::AS_CUSTOM_OBJ) { return std::make_pair( inputSpec.get_user_input().get_arg().get_as_custom_obj().get_name(), "USER_INPUT"); } else { TORCH_CHECK(false, "Unsupported USER_INPUT argument type."); } break; case torch::_export::InputSpec::Tag::CONSTANT_INPUT: return std::make_pair( inputSpec.get_constant_input().get_name(), "CONSTANT_INPUT"); break; case torch::_export::InputSpec::Tag::TOKEN: TORCH_CHECK(false, "Token inputs not implemented yet."); default: TORCH_CHECK(false, "Unknown InputSpec tag encountered."); } } void checkInputOrders( const std::vector& inputSpecs) { // Map each tag to its index in the expected order static constexpr std:: array, 5> tagOrderArray = { {{torch::_export::InputSpec::Tag::TOKEN, 0}, {torch::_export::InputSpec::Tag::PARAMETER, 1}, {torch::_export::InputSpec::Tag::BUFFER, 2}, {torch::_export::InputSpec::Tag::TENSOR_CONSTANT, 3}, {torch::_export::InputSpec::Tag::CUSTOM_OBJ, 4}}}; uint32_t currentOrderIndex = 0; bool seenNonPersistentBuffer = false; for (const auto& inputSpec : inputSpecs) { if (inputSpec.tag() == torch::_export::InputSpec::Tag::USER_INPUT || inputSpec.tag() == torch::_export::InputSpec::Tag::CONSTANT_INPUT) { continue; } auto it = std::find_if( tagOrderArray.begin(), tagOrderArray.end(), [&inputSpec](const auto& pair) { return pair.first == inputSpec.tag(); }); TORCH_CHECK( it != tagOrderArray.end(), "Unknown InputSpec tag encountered."); uint32_t tagIndex = it->second; if (tagIndex < currentOrderIndex) { auto [argName, tagName] = getSpecDetails(inputSpec); TORCH_CHECK( false, fmt::format( "Input arg {} with InputSpec {} is out of order!", argName, tagName)); } currentOrderIndex = tagIndex; // Additional check for buffers if (inputSpec.tag() == torch::_export::InputSpec::Tag::BUFFER) { if (!inputSpec.get_buffer().get_persistent()) { seenNonPersistentBuffer = true; } else { TORCH_CHECK( !seenNonPersistentBuffer, "Persistent buffers must come before non-persistent buffers."); } } } } void checkInputNames( const c10::FastSet& sigNames, const c10::FastSet& graphNames) { if (sigNames == graphNames) { return; } std::string errorMsg = fmt::format( "Error: Value name difference detected between graph signature and graph nodes:\n" "Signature value names:\n[{}]\n" "Graph node names:\n[{}]", fmt::join(sigNames, ", "), fmt::join(graphNames, ", ")); TORCH_CHECK(false, errorMsg); } void checkOutputNames( const c10::FastSet>& sigNames, const c10::FastSet& graphNames) { std::vector validNames; for (const auto& nameOpt : sigNames) { if (nameOpt.has_value()) { validNames.push_back(*nameOpt); } } for (const auto& name : validNames) { if (graphNames.find(name) == graphNames.end()) { std::string errorMsg = fmt::format( "Error: Value name difference detected between graph signature and graph nodes:\n" "Signature value names:\n[{}]\n" "Graph node names:\n[{}]", fmt::join(validNames, ", "), fmt::join(graphNames, ", ")); TORCH_CHECK(false, errorMsg); } } } void replaceInMap( c10::FastMap& map, std::string_view old, std::string_view replacement) { auto it = map.find(std::string{old}); if (it == map.end()) { return; } std::string value = std::move(it->second); map.erase(it); map.emplace(replacement, std::move(value)); } } // namespace GraphSignature::GraphSignature(const torch::_export::GraphSignature& storage) { checkInputOrders(storage.get_input_specs()); for (const torch::_export::InputSpec& inputSpec : storage.get_input_specs()) { switch (inputSpec.tag()) { case torch::_export::InputSpec::Tag::USER_INPUT: { const auto& userInputArg = inputSpec.get_user_input().get_arg(); if (userInputArg.tag() == torch::_export::Argument::Tag::AS_TENSOR) { userInputs_.emplace_back(userInputArg.get_as_tensor().get_name()); } else if ( userInputArg.tag() == torch::_export::Argument::Tag::AS_CUSTOM_OBJ) { userInputs_.emplace_back(userInputArg.get_as_custom_obj().get_name()); } else { // TODO: handle other types TORCH_CHECK(false, "Non tensor inputs not implemented yet."); } break; } case torch::_export::InputSpec::Tag::PARAMETER: { numParameters_++; const auto& inputName = inputSpec.get_parameter().get_arg().get_name(); const auto& weightName = inputSpec.get_parameter().get_parameter_name(); inputsToWeights_.emplace_back(inputName, weightName); break; } case torch::_export::InputSpec::Tag::BUFFER: { const bool isPersistent = inputSpec.get_buffer().get_persistent(); const auto& inputName = inputSpec.get_buffer().get_arg().get_name(); const auto& weightName = inputSpec.get_buffer().get_buffer_name(); if (isPersistent) { numPersistentBuffers_++; } else { numNonPersistentBuffers_++; } inputsToWeights_.emplace_back(inputName, weightName); break; } case torch::_export::InputSpec::Tag::TENSOR_CONSTANT: { numTensorConstants_++; const auto& inputName = inputSpec.get_tensor_constant().get_arg().get_name(); const auto& weightName = inputSpec.get_tensor_constant().get_tensor_constant_name(); inputsToWeights_.emplace_back(inputName, weightName); break; } case torch::_export::InputSpec::Tag::CUSTOM_OBJ: { numCustomObjs_++; const auto& inputName = inputSpec.get_custom_obj().get_arg().get_name(); const auto& customObjName = inputSpec.get_custom_obj().get_custom_obj_name(); inputsToCustomObjs_.emplace_back(inputName, customObjName); break; } case torch::_export::InputSpec::Tag::CONSTANT_INPUT: { break; } case torch::_export::InputSpec::Tag::TOKEN: { TORCH_CHECK(false, "Token inputs not implemented yet."); } default: TORCH_CHECK(false, "Unknown InputSpec tag encountered."); break; } } for (const torch::_export::OutputSpec& outputSpec : storage.get_output_specs()) { switch (outputSpec.tag()) { case torch::_export::OutputSpec::Tag::LOSS_OUTPUT: lossOutput_ = outputSpec.get_loss_output().get_arg().get_name(); break; case torch::_export::OutputSpec::Tag::USER_OUTPUT: if (isSymbolicOutput(outputSpec.get_user_output().get_arg().tag())) { switch (outputSpec.get_user_output().get_arg().tag()) { case torch::_export::Argument::Tag::AS_TENSOR: { userOutputs_.emplace_back(outputSpec.get_user_output() .get_arg() .get_as_tensor() .get_name()); break; } case torch::_export::Argument::Tag::AS_SYM_INT: { userOutputs_.emplace_back(outputSpec.get_user_output() .get_arg() .get_as_sym_int() .get_as_name()); break; } default: { TORCH_CHECK( false, "Unsupported symbolic user output type encountered."); } } } else { // for constant outputs, we don't have a name userOutputs_.emplace_back(std::nullopt); } break; case torch::_export::OutputSpec::Tag::BUFFER_MUTATION: buffersToMutate_.emplace( outputSpec.get_buffer_mutation().get_arg().get_name(), outputSpec.get_buffer_mutation().get_buffer_name()); break; case torch::_export::OutputSpec::Tag::GRADIENT_TO_PARAMETER: gradientsToParameters_.emplace( outputSpec.get_gradient_to_parameter().get_arg().get_name(), outputSpec.get_gradient_to_parameter().get_parameter_name()); break; case torch::_export::OutputSpec::Tag::GRADIENT_TO_USER_INPUT: gradientsToUserInputs_.emplace( outputSpec.get_gradient_to_user_input().get_arg().get_name(), outputSpec.get_gradient_to_user_input().get_user_input_name()); break; case torch::_export::OutputSpec::Tag::USER_INPUT_MUTATION: userInputsToMutate_.emplace( outputSpec.get_user_input_mutation().get_arg().get_name(), outputSpec.get_user_input_mutation().get_user_input_name()); break; case torch::_export::OutputSpec::Tag::TOKEN: { TORCH_CHECK(false, "Token outputs not implemented yet."); } default: TORCH_CHECK(false, "Unknown OutputSpec tag encountered."); } } if (FLAGS_caffe2_log_level > 2) { std::cout << *this << "\n"; } } c10::FastSet GraphSignature::inputNames() const { c10::FastSet ret; size_t numInputs = userInputs().size() + inputsToWeights().size() + inputsToCustomObjs().size(); ret.reserve(numInputs); for (const auto& name : userInputs()) { ret.insert(name); } for (const auto& [inputName, _] : inputsToWeights()) { ret.insert(inputName); } for (const auto& [inputName, _] : inputsToCustomObjs()) { ret.insert(inputName); } return ret; } c10::FastSet> GraphSignature::outputNames() const { c10::FastSet> ret; size_t numOutputs = userOutputs().size() + buffersToMutate().size() + userInputsToMutate().size() + (hasBackward() ? gradientsToParameters().size() + gradientsToUserInputs().size() + (lossOutput().empty() ? 0 : 1) : 0); ret.reserve(numOutputs); for (const auto& name : userOutputs()) { ret.insert(name); } for (const auto& [outputName, _] : buffersToMutate()) { ret.insert(outputName); } for (const auto& [outputName, _] : userInputsToMutate()) { ret.insert(outputName); } if (hasBackward()) { if (!gradientsToParameters().empty()) { for (const auto& [outputName, _] : gradientsToParameters()) { ret.insert(outputName); } } if (!gradientsToUserInputs().empty()) { for (const auto& [outputName, _] : gradientsToUserInputs()) { ret.insert(outputName); } } if (!lossOutput().empty()) { ret.insert(lossOutput()); } } return ret; } void GraphSignature::lint( const c10::FastSet& graphInputs, const c10::FastSet& graphOutputs) const { checkInputNames(inputNames(), graphInputs); checkOutputNames(outputNames(), graphOutputs); } void GraphSignature::replaceAllUses( std::string_view old, std::string_view replacement) { if (old == replacement) { return; } for (auto& name : userOutputs_) { if (name == old) { name = replacement; } } replaceInMap(buffersToMutate_, old, replacement); if (hasBackward()) { replaceInMap(gradientsToParameters_, old, replacement); replaceInMap(gradientsToUserInputs_, old, replacement); if (old == lossOutput_) { lossOutput_ = replacement; } } } std::ostream& operator<<(std::ostream& out, const GraphSignature& sig) { if (!sig.inputsToParameters().empty()) { out << "inputsToParameters: {\n"; for (const auto& [inputName, paramName] : sig.inputsToParameters()) { out << "\t" << inputName << " : " << paramName << "\n"; } out << "}\n"; } if (!sig.inputsToBuffers().empty()) { out << "inputsToBuffers: {\n"; for (const auto& [inputName, bufferName] : sig.inputsToBuffers()) { out << "\t" << inputName << " : " << bufferName << "\n"; } out << "}\n"; } if (!sig.inputsToTensorConstants().empty()) { out << "inputsToTensorConstants: {\n"; for (const auto& [inputName, tensorConstantName] : sig.inputsToTensorConstants()) { out << "\t" << inputName << " : " << tensorConstantName << "\n"; } out << "}\n"; } if (!sig.inputsToCustomObjs().empty()) { out << "inputsToCustomObjs: {\n"; for (const auto& [inputName, customObjName] : sig.inputsToCustomObjs()) { out << "\t" << inputName << " : " << customObjName << "\n"; } out << "}\n"; } if (!sig.userOutputs().empty()) { out << "userOutputs: {\n"; for (const auto& outputName : sig.userOutputs()) { out << "\t" << outputName.value_or("Constant") << "\n"; } out << "}\n"; } if (!sig.buffersToMutate().empty()) { out << "buffersToMutate: {\n"; for (const auto& [outputName, mutatedBufferName] : sig.buffersToMutate()) { out << "\t" << outputName << " : " << mutatedBufferName << "\n"; } out << "}\n"; } if (!sig.userInputsToMutate().empty()) { out << "userInputsToMutate: {\n"; for (const auto& [outputName, mutatedUserInputName] : sig.userInputsToMutate()) { out << "\t" << outputName << " : " << mutatedUserInputName << "\n"; } out << "}\n"; } if (sig.hasBackward()) { if (!sig.gradientsToParameters().empty()) { out << "gradientsToParameters: {\n"; for (const auto& [outputName, paramName] : sig.gradientsToParameters()) { out << "\t" << outputName << " : " << paramName << "\n"; } out << "}\n"; } if (!sig.gradientsToUserInputs().empty()) { out << "gradientsToUserInputs: {\n"; for (const auto& [outputName, userInputName] : sig.gradientsToUserInputs()) { out << "\t" << outputName << " : " << userInputName << "\n"; } out << "}\n"; } out << "lossOutput: " << sig.lossOutput() << "\n"; } return out; } } // namespace torch::nativert