153 lines
3.8 KiB
Python
153 lines
3.8 KiB
Python
import ast
|
|
import copy
|
|
from types import CodeType, ModuleType
|
|
from typing import Any, Dict, Mapping, Sequence, Union
|
|
|
|
# Top-level module names that sandboxed code may import.
ALLOWED_IMPORTS = {
    "math",
    "time",
    "datetime",
    "pandas",
    "scipy",
    "numpy",
    "matplotlib",
    "plotly",
    "seaborn",
}


def _restricted_import(
    name: str,
    globals: Union[Mapping[str, object], None] = None,
    locals: Union[Mapping[str, object], None] = None,
    fromlist: Sequence[str] = (),
    level: int = 0,
) -> ModuleType:
    """Stand-in for ``__import__`` that only admits whitelisted modules.

    The parameter names mirror the builtin ``__import__`` so this function can
    be exposed under that name to sandboxed code.

    Args:
        name: Module name from the import statement. It must match an entry of
            ``ALLOWED_IMPORTS`` exactly, so a dotted import such as
            ``import numpy.linalg`` is rejected while ``from numpy import
            linalg`` (name ``"numpy"``) is accepted.
        globals: Forwarded unchanged to ``__import__``.
        locals: Forwarded unchanged to ``__import__``.
        fromlist: Forwarded unchanged to ``__import__``.
        level: Relative-import depth; only absolute imports (``0``) are allowed.

    Returns:
        The module returned by the real ``__import__``.

    Raises:
        ImportError: If the import is relative or ``name`` is not whitelisted.
    """
    # A relative import (e.g. ``from .math import x``) resolves against the
    # package recorded in ``globals`` and could load an arbitrary local module
    # that merely shares a whitelisted name, so refuse it outright.
    if level != 0:
        raise ImportError(f"Relative import of module '{name}' is not allowed")
    if name in ALLOWED_IMPORTS:
        return __import__(name, globals, locals, fromlist, level)
    raise ImportError(f"Import of module '{name}' is not allowed")
|
|
|
|
|
|
def _checked_attribute_name(name: Any) -> Any:
    """Return ``name`` ready for attribute access, refusing private names.

    The source-level check only sees attribute names written literally in
    the code, so without this guard ``getattr(obj, "__class__")`` would slip
    past it.

    Raises:
        RuntimeError: If ``name`` is a string starting with ``_``.
    """
    if isinstance(name, str):
        # Reduce str subclasses to an exact ``str`` so an overridden
        # ``startswith`` cannot disguise the name that is actually looked up.
        name = str.__str__(name)
        if name.startswith("_"):
            raise RuntimeError(
                f"Access to private or dunder attribute '{name}' is forbidden!"
            )
    # Non-string names pass through so the real builtin raises its usual
    # TypeError.
    return name


def _safe_getattr(obj: object, name: Any, *default: Any) -> Any:
    """``getattr`` that refuses names starting with an underscore.

    A private name raises even when a default is supplied.
    """
    return getattr(obj, _checked_attribute_name(name), *default)


def _safe_hasattr(obj: object, name: Any) -> bool:
    """``hasattr`` that refuses names starting with an underscore."""
    return hasattr(obj, _checked_attribute_name(name))


def _safe_setattr(obj: object, name: Any, value: Any) -> None:
    """``setattr`` that refuses names starting with an underscore."""
    setattr(obj, _checked_attribute_name(name), value)


# Builtins exposed to sandboxed code. ``getattr``/``hasattr``/``setattr`` are
# wrapped so string-spelled private names cannot sidestep the source check.
ALLOWED_BUILTINS = {
    "abs": abs,
    "all": all,
    "any": any,
    "ascii": ascii,
    "bin": bin,
    "bool": bool,
    "bytearray": bytearray,
    "bytes": bytes,
    "chr": chr,
    "complex": complex,
    "divmod": divmod,
    "enumerate": enumerate,
    "filter": filter,
    "float": float,
    "format": format,
    "frozenset": frozenset,
    "getattr": _safe_getattr,
    "hasattr": _safe_hasattr,
    "hash": hash,
    "hex": hex,
    "int": int,
    "isinstance": isinstance,
    "issubclass": issubclass,
    "iter": iter,
    "len": len,
    "list": list,
    "map": map,
    "max": max,
    "min": min,
    "next": next,
    "oct": oct,
    "ord": ord,
    "pow": pow,
    "print": print,
    "range": range,
    "repr": repr,
    "reversed": reversed,
    "round": round,
    "set": set,
    "setattr": _safe_setattr,
    "slice": slice,
    "sorted": sorted,
    "str": str,
    "sum": sum,
    "tuple": tuple,
    "type": type,
    "zip": zip,
    # Constants. Kept for backward compatibility only: these are keywords in
    # Python 3, so the compiler never looks them up by name.
    "True": True,
    "False": False,
    "None": None,
    # Routes ``import`` statements through the whitelist check.
    "__import__": _restricted_import,
}
|
|
|
|
|
|
def _get_restricted_globals(__globals: Union[dict, None]) -> Any:
    """Build the ``globals`` mapping handed to sandboxed ``eval``/``exec``.

    The whitelisted builtins are installed under ``__builtins__``. When the
    globals given to ``eval``/``exec`` lack that key, Python inserts the real
    ``builtins`` module. That would expose ``open``, ``exec``, etc. and an
    unrestricted ``__import__``, because an ``import`` statement looks
    ``__import__`` up in the builtins, not in the globals. With this mapping,
    only names from ``ALLOWED_BUILTINS`` (or ``__globals``) resolve.

    Args:
        __globals: Optional caller-supplied globals. They are copied into a
            new dict, so the caller's dict itself is not modified.

    Returns:
        A fresh dict suitable as the ``globals`` argument of ``eval``/``exec``.
    """
    restricted_globals: Dict[str, Any] = {}
    if __globals:
        restricted_globals.update(__globals)
    # Assigned after the caller's values so a ``__builtins__`` entry passed in
    # (e.g. a module's ``globals()``) cannot re-expose the real builtins.
    restricted_globals["__builtins__"] = dict(ALLOWED_BUILTINS)
    return restricted_globals
|
|
|
|
|
|
class DunderVisitor(ast.NodeVisitor):
    """AST visitor recording whether any private or dunder identifier is used.

    After ``visit(tree)`` returns, ``has_access_to_private_entity`` is ``True``
    if the tree contains an ``ast.Name`` whose ``id``, or an ``ast.Attribute``
    whose ``attr``, starts with ``_``.

    NOTE(review): only those two node kinds are inspected. Names spelled inside
    string literals (e.g. ``getattr(x, "__class__")`` or ``d["__x__"]``) are
    not flagged.
    """

    def __init__(self) -> None:
        # Set on the first offending node; never reset by the visitor.
        self.has_access_to_private_entity = False

    def _flag_if_private(self, identifier: str) -> None:
        """Raise the flag when ``identifier`` begins with an underscore."""
        if identifier.startswith("_"):
            self.has_access_to_private_entity = True

    def visit_Name(self, node: ast.Name) -> None:
        self._flag_if_private(node.id)
        self.generic_visit(node)

    def visit_Attribute(self, node: ast.Attribute) -> None:
        self._flag_if_private(node.attr)
        # Keep descending so the attribute's value expression is checked too.
        self.generic_visit(node)
|
|
|
|
|
|
def _contains_protected_access(code: str) -> bool:
    """Return whether ``code`` references a name or attribute starting with ``_``.

    The source is parsed with ``ast.parse`` (its errors, e.g. ``SyntaxError``,
    propagate) and walked with ``DunderVisitor``.
    """
    visitor = DunderVisitor()
    visitor.visit(ast.parse(code))
    return visitor.has_access_to_private_entity
|
|
|
|
|
|
def _verify_source_safety(__source: Union[str, bytes, CodeType]) -> None:
    """
    Verify that the source is safe to execute. For now, this means that it
    does not contain any references to private or dunder methods.

    Args:
        __source: Source as ``str`` or ``bytes``. Bytes are handed to the
            parser undecoded so they are read exactly as ``eval``/``exec``
            will read them.

    Raises:
        RuntimeError: If ``__source`` is a code object (its source cannot be
            inspected), or if it references a name or attribute whose
            identifier starts with ``_``.
        SyntaxError: Propagated from ``ast.parse`` when the source does not
            parse.

    NOTE(review): the check is syntactic and best-effort. It inspects
    identifiers only, so names spelled inside string literals (e.g. a
    subscript key) are not caught. Attributes without a leading underscore
    that still expose interpreter internals (e.g. a generator's ``gi_frame``)
    are not caught either. Do not treat this as a hardened security boundary.
    """
    if isinstance(__source, CodeType):
        raise RuntimeError("Direct execution of CodeType is forbidden!")
    # Bytes are deliberately NOT decoded here. ``ast.parse`` and
    # ``eval``/``exec`` interpret bytes source the same way (UTF-8 BOM,
    # PEP 263 coding declaration), so the checked text is the text that runs.
    # A blanket ``.decode()`` can disagree with that. A UTF-8 BOM would survive
    # as U+FEFF and make valid source fail to parse. Non-UTF-8 source with a
    # matching coding declaration would raise UnicodeDecodeError.
    if _contains_protected_access(__source):
        raise RuntimeError(
            "Execution of code containing references to private or dunder methods is forbidden!"
        )
|
|
|
|
|
|
def safe_eval(
    __source: Union[str, bytes, CodeType],
    __globals: Union[Dict[str, Any], None] = None,
    __locals: Union[Mapping[str, object], None] = None,
) -> Any:
    """
    eval within safe global context.

    The source is first checked by ``_verify_source_safety``. That check
    rejects code objects and any name or attribute starting with ``_``. The
    source is then evaluated with globals built by ``_get_restricted_globals``
    from ``__globals``.

    Args:
        __source: Expression source as ``str`` or ``bytes``. ``CodeType`` is
            accepted by the signature but always rejected.
        __globals: Optional extra globals; copied, so the caller's dict itself
            is not modified.
        __locals: Optional locals mapping passed straight to ``eval``.

    Returns:
        The value of the evaluated expression.

    Raises:
        RuntimeError: If the source fails the safety check.
    """
    _verify_source_safety(__source)
    sandbox_globals = _get_restricted_globals(__globals)
    return eval(__source, sandbox_globals, __locals)
|
|
|
|
|
|
def safe_exec(
    __source: Union[str, bytes, CodeType],
    __globals: Union[Dict[str, Any], None] = None,
    __locals: Union[Mapping[str, object], None] = None,
) -> None:
    """
    exec within safe global context.

    The source is first checked by ``_verify_source_safety``. That check
    rejects code objects and any name or attribute starting with ``_``. The
    source is then executed with globals built by ``_get_restricted_globals``
    from ``__globals``.

    Args:
        __source: Statement source as ``str`` or ``bytes``. ``CodeType`` is
            accepted by the signature but always rejected.
        __globals: Optional extra globals; copied, so the caller's dict itself
            is not modified by the executed code.
        __locals: Optional locals mapping passed straight to ``exec``. When
            given, top-level assignments made by the executed code land in it.

    Raises:
        RuntimeError: If the source fails the safety check.
    """
    _verify_source_safety(__source)
    # ``exec`` always returns None, so there is nothing to forward.
    exec(__source, _get_restricted_globals(__globals), __locals)
|