/* * A C++ extension bridge with the Python pytree * serialization/unserialization format for torch.export. */ #pragma once #include #include #include #include #include #include #include namespace torch::nativert::detail { using torch::nativert::Value; class ITreeSpec; using ITreeFlattenFn = void (*)(const c10::IValue&, const ITreeSpec&, std::vector&); using ITreeUnflattenFn = c10::IValue (*)(std::vector, const nlohmann::json&); using ContextLoadFn = nlohmann::json (*)(std::string_view); using ITreeMapFn = c10::function_ref; using ITreeMapNoReturnFn = c10::function_ref; using IValueApplyFn = void (*)(ITreeMapNoReturnFn, const c10::IValue&, const ITreeSpec&); nlohmann::json defaultContextLoadFn(std::string_view); struct NodeDef { ITreeFlattenFn flattenFn; ITreeUnflattenFn unflattenFn; IValueApplyFn ivalueApplyFn; ContextLoadFn contextLoadFn = defaultContextLoadFn; }; class ITreeSpec { public: // Leaf node. ITreeSpec(const Value* value = nullptr, bool isUsed = true) : numIValues_(1), value_(value), isUsed_(isUsed) {} // Non leaf node. ITreeSpec( std::string_view uniformName, nlohmann::json context, std::vector children, NodeDef nodeDefCache); bool isIValue() const { return !uniformName_; } std::string_view uniformName() const { TORCH_CHECK(uniformName_); return uniformName_.value(); } const nlohmann::json& context() const { return context_; } const std::vector& contextKeys() const { return contextKeys_; } const auto& children() const { return children_; } const ITreeSpec& children(size_t i) const { return children_[i]; } const NodeDef& nodeDefCache() const { return nodeDefCache_; } size_t numIValues() const { return numIValues_; } bool allIValues() const { return allIValues_; } c10::TypePtr toAtenType() const; bool isUsed() const { return isUsed_; } const Value* value() const { return value_; } private: // Only non leaf nodes have names. // Examples of uniform name: "builtins.tuple", "builtins.dict". std::optional uniformName_; nlohmann::json context_; std::vector children_; std::vector contextKeys_; // Cached fields. NodeDef nodeDefCache_; size_t numIValues_; bool allIValues_ = true; const Value* value_; bool isUsed_; }; void registerPytreeNode(std::string_view typeName, NodeDef nodeDef); // Serialized json tree spec should be dumped from treespec_dumps() in // torch.utils._pytree directly . ITreeSpec itreeSpecLoads( std::string_view json, const std::vector& values); c10::IValue itreeUnflatten( std::vector ivalues, const ITreeSpec& spec); std::vector itreeFlatten( const c10::IValue& nested, const ITreeSpec& spec); std::vector itreeFlattenFromArgs( const std::vector& args, const std::unordered_map& kwargs, const ITreeSpec& spec); std::vector itreeFlattenToTensorList( const c10::IValue& nested, const ITreeSpec& spec); c10::IValue itreeMap( ITreeMapFn f, const c10::IValue& nested, const ITreeSpec& spec); c10::IValue TORCH_API argsToIValue( const std::vector& args, const std::unordered_map& kwargs); std:: pair, std::unordered_map> itreeMapArgs( ITreeMapFn f, const std::vector& args, const std::unordered_map& kwargs, const ITreeSpec& spec); void ivalueApply( ITreeMapNoReturnFn f, const c10::IValue& nested, const ITreeSpec& spec); void ivalueApplyFromArgs( ITreeMapNoReturnFn fn, const std::vector& args, const std::unordered_map& kwargs, const ITreeSpec& spec); } // namespace torch::nativert::detail