# faiss_rag_enterprise/llama_index/exec_utils.py
# (file-viewer header from the original paste: 153 lines, 3.8 KiB, Python)

import ast
import copy
from types import CodeType, ModuleType
from typing import Any, Dict, Mapping, Sequence, Union
# Module names that `_restricted_import` is willing to load for sandboxed code.
ALLOWED_IMPORTS = {
    # Standard library
    "datetime",
    "math",
    "time",
    # Numeric / data analysis
    "numpy",
    "pandas",
    "scipy",
    # Plotting
    "matplotlib",
    "plotly",
    "seaborn",
}
def _restricted_import(
    name: str,
    globals: Union[Mapping[str, object], None] = None,
    locals: Union[Mapping[str, object], None] = None,
    fromlist: Sequence[str] = (),
    level: int = 0,
) -> ModuleType:
    """
    Drop-in replacement for ``__import__`` that only loads allow-listed packages.

    A dotted name such as ``"matplotlib.pyplot"`` is accepted when its
    top-level package (``"matplotlib"``) is in ``ALLOWED_IMPORTS``. This keeps
    ``import matplotlib.pyplot`` consistent with ``from matplotlib import
    pyplot``, which reaches the same submodule through ``fromlist``.

    Relative imports (``level > 0``) are always rejected: ``__import__``
    resolves them against the package recorded in ``globals``, not against
    the allow-list.

    Args:
        name: Module name as passed by the ``import`` statement.
        globals: Importing frame's globals (forwarded to ``__import__``).
        locals: Importing frame's locals (forwarded to ``__import__``).
        fromlist: Names requested by ``from ... import ...``.
        level: ``0`` for absolute imports, ``> 0`` for relative imports.

    Returns:
        Whatever ``__import__`` returns for the same arguments.

    Raises:
        ImportError: If the import is relative or its top-level package is
            not in ``ALLOWED_IMPORTS``.
    """
    top_level_package = name.partition(".")[0]
    if level == 0 and top_level_package in ALLOWED_IMPORTS:
        return __import__(name, globals, locals, fromlist, level)
    raise ImportError(f"Import of module '{name}' is not allowed")
def _checked_attribute_name(name: Any) -> Any:
    """
    Return ``name`` for ``getattr``/``setattr``, rejecting private names.

    ``DunderVisitor`` only inspects literal syntax such as ``obj.__class__``.
    A name assembled at runtime (e.g. ``getattr(obj, "__cl" + "ass__")``)
    slips past that static check, so the same "no leading underscore" rule
    is enforced again when the call actually executes.

    Non-``str`` values are returned unchanged, so the wrapped builtin raises
    its usual ``TypeError`` for them.

    Raises:
        RuntimeError: If ``name`` is a string starting with ``_``.
    """
    if isinstance(name, str):
        # Copy into an exact ``str``. A ``str`` subclass (buildable through the
        # allowed ``type`` builtin) could override ``startswith``, ``__eq__``
        # or ``__hash__`` to disguise a private name.
        name = str.__str__(name)
        if name.startswith("_"):
            raise RuntimeError(
                f"Access to private or dunder attribute '{name}' is forbidden!"
            )
    return name


def _safe_getattr(obj: Any, name: Any, *default: Any) -> Any:
    """``getattr`` that refuses attribute names starting with an underscore."""
    return getattr(obj, _checked_attribute_name(name), *default)


def _safe_setattr(obj: Any, name: Any, value: Any) -> None:
    """``setattr`` that refuses attribute names starting with an underscore."""
    setattr(obj, _checked_attribute_name(name), value)


# Name -> object table of builtins made available to code run through
# `safe_eval` / `safe_exec` (copied by `_get_restricted_globals`).
ALLOWED_BUILTINS = {
    "abs": abs,
    "all": all,
    "any": any,
    "ascii": ascii,
    "bin": bin,
    "bool": bool,
    "bytearray": bytearray,
    "bytes": bytes,
    "chr": chr,
    "complex": complex,
    # Same trust level as list/set/tuple/frozenset below.
    "dict": dict,
    "divmod": divmod,
    "enumerate": enumerate,
    "filter": filter,
    "float": float,
    "format": format,
    "frozenset": frozenset,
    # Wrapped: the real getattr would accept runtime-built dunder names.
    "getattr": _safe_getattr,
    # hasattr only yields a bool, so it is left unwrapped.
    "hasattr": hasattr,
    "hash": hash,
    "hex": hex,
    "int": int,
    "isinstance": isinstance,
    "issubclass": issubclass,
    "iter": iter,
    "len": len,
    "list": list,
    "map": map,
    "max": max,
    "min": min,
    "next": next,
    "oct": oct,
    "ord": ord,
    "pow": pow,
    "print": print,
    "range": range,
    "repr": repr,
    "reversed": reversed,
    "round": round,
    "set": set,
    # Wrapped for the same reason as getattr.
    "setattr": _safe_setattr,
    "slice": slice,
    "sorted": sorted,
    "str": str,
    "sum": sum,
    "tuple": tuple,
    "type": type,
    "zip": zip,
    # Constants
    "True": True,
    "False": False,
    "None": None,
    # `__import__` replacement that enforces ALLOWED_IMPORTS.
    "__import__": _restricted_import,
}
def _get_restricted_globals(__globals: Union[dict, None]) -> Dict[str, Any]:
    """
    Build the ``globals`` mapping handed to ``eval``/``exec`` by the safe helpers.

    The allow-listed builtins are installed under the ``__builtins__`` key.
    If that key is absent, ``eval``/``exec`` silently insert the real
    ``builtins`` dict, which has two consequences:

    * every builtin (``open``, ``exec``, ``compile``, ...) stays reachable;
    * ``import`` statements bypass ``_restricted_import`` entirely, because
      the interpreter looks up ``__import__`` in the frame's builtins, not
      in its globals.

    Args:
        __globals: Optional extra globals from the caller. They are copied
            into a fresh dict, so the caller's mapping is never mutated.
            A caller-supplied ``__builtins__`` is ignored.

    Returns:
        A new dict suitable as the ``globals`` argument of ``eval``/``exec``.
    """
    restricted_globals: Dict[str, Any] = {}
    if __globals:
        restricted_globals.update(__globals)
    # Assigned last so a caller passing e.g. ``globals()`` cannot re-open
    # access to the full builtins. A shallow copy keeps the shared
    # ALLOWED_BUILTINS table from being mutated by executed code.
    restricted_globals["__builtins__"] = dict(ALLOWED_BUILTINS)
    return restricted_globals
class DunderVisitor(ast.NodeVisitor):
    """
    AST visitor that detects references to private or dunder identifiers.

    After ``visit(tree)``, ``has_access_to_private_entity`` is ``True`` if any
    checked identifier (or any dotted component of it) starts with ``_``.

    Checked identifiers:

    * bare names (``ast.Name``), e.g. ``__builtins__``;
    * attribute names (``ast.Attribute``), e.g. ``x.__class__``;
    * imported names and aliases (``ast.alias``), e.g.
      ``from numpy import __builtins__ as b`` or ``import numpy._core``;
    * the source module of ``from ... import`` (``ast.ImportFrom``), e.g.
      ``from numpy._core import x``;
    * keyword attributes of class patterns (``ast.MatchClass``), e.g.
      ``case tuple(__class__=c)``, which perform attribute lookups.
    """

    def __init__(self) -> None:
        # Sticky flag: once set, it stays True for the rest of the walk.
        self.has_access_to_private_entity = False

    def _flag_if_private(self, identifier: Union[str, None]) -> None:
        """Set the flag if any dotted component of ``identifier`` starts with ``_``."""
        if identifier and any(
            part.startswith("_") for part in identifier.split(".")
        ):
            self.has_access_to_private_entity = True

    def visit_Name(self, node: ast.Name) -> None:
        self._flag_if_private(node.id)
        self.generic_visit(node)

    def visit_Attribute(self, node: ast.Attribute) -> None:
        self._flag_if_private(node.attr)
        self.generic_visit(node)

    def visit_alias(self, node: ast.alias) -> None:
        # Import targets are plain strings, not Name/Attribute nodes, so they
        # would otherwise never be inspected.
        self._flag_if_private(node.name)
        self._flag_if_private(node.asname)
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        # `module` is None for `from . import x`.
        self._flag_if_private(node.module)
        self.generic_visit(node)

    def visit_MatchClass(self, node: "ast.MatchClass") -> None:
        # String annotation: ast.MatchClass only exists on Python 3.10+, and on
        # older versions this method is simply never called.
        for attribute in node.kwd_attrs:
            self._flag_if_private(attribute)
        self.generic_visit(node)
def _contains_protected_access(code: str) -> bool:
    """
    Report whether ``code`` references any private or dunder identifier.

    ``code`` is parsed with ``ast.parse`` (so invalid Python raises
    ``SyntaxError``) and the resulting tree is walked by ``DunderVisitor``.
    """
    visitor = DunderVisitor()
    visitor.visit(ast.parse(code))
    return visitor.has_access_to_private_entity
def _verify_source_safety(__source: Union[str, bytes, CodeType]) -> None:
    """
    Refuse source that must not be executed by the safe helpers.

    Compiled code objects are rejected outright, since they cannot be
    inspected. ``bytes`` are decoded with the default codec. The resulting
    text is rejected if it references any private or dunder identifier.

    Raises:
        RuntimeError: For a ``CodeType`` argument, or for source that
            references a private or dunder identifier.
    """
    if isinstance(__source, CodeType):
        raise RuntimeError("Direct execution of CodeType is forbidden!")
    source_text = __source.decode() if isinstance(__source, bytes) else __source
    if not _contains_protected_access(source_text):
        return
    raise RuntimeError(
        "Execution of code containing references to private or dunder methods is forbidden!"
    )
def safe_eval(
    __source: Union[str, bytes, CodeType],
    __globals: Union[Dict[str, Any], None] = None,
    __locals: Union[Mapping[str, object], None] = None,
) -> Any:
    """
    Evaluate an expression inside the restricted global namespace.

    The source is vetted by ``_verify_source_safety`` first. The expression
    is then passed to ``eval`` with the globals built by
    ``_get_restricted_globals`` from ``__globals``; ``__locals`` is
    forwarded unchanged.

    Returns:
        The value of the evaluated expression.

    Raises:
        RuntimeError: If ``__source`` is a code object or references a
            private or dunder identifier.
    """
    _verify_source_safety(__source)
    sandbox_globals = _get_restricted_globals(__globals)
    return eval(__source, sandbox_globals, __locals)
def safe_exec(
    __source: Union[str, bytes, CodeType],
    __globals: Union[Dict[str, Any], None] = None,
    __locals: Union[Mapping[str, object], None] = None,
) -> None:
    """
    Execute statements inside the restricted global namespace.

    The source is vetted by ``_verify_source_safety`` first. It is then
    passed to ``exec`` with the globals built by ``_get_restricted_globals``
    from ``__globals``.

    That globals dict is a fresh copy, so top-level assignments made by the
    executed code are not written back into ``__globals``. Pass a mutable
    ``__locals`` mapping to collect results; it is forwarded to ``exec``
    unchanged.

    Raises:
        RuntimeError: If ``__source`` is a code object or references a
            private or dunder identifier.
    """
    _verify_source_safety(__source)
    exec(__source, _get_restricted_globals(__globals), __locals)