Browse Source

数据安全升级

master
siyaqi 3 weeks ago
parent
commit
13c28481a3
  1. 180
      vue-fastapi-backend/module_admin/service/metasecurity_service.py

180
vue-fastapi-backend/module_admin/service/metasecurity_service.py

@ -13,13 +13,14 @@ from sqlalchemy.exc import SQLAlchemyError, DBAPIError
from sqlalchemy import text from sqlalchemy import text
from config.env import AppConfig from config.env import AppConfig
import requests import requests
from sqlalchemy.exc import OperationalError from sqlparse.sql import Identifier, IdentifierList, Parenthesis
import json import json
import asyncio import asyncio
import re import re
from decimal import Decimal from decimal import Decimal
import sqlparse import sqlglot
from sqlparse.sql import Identifier, IdentifierList, Function, Token from sqlglot.expressions import Table
from typing import Set
from sqlparse.tokens import Keyword, DML from sqlparse.tokens import Keyword, DML
class MetaSecurityService: class MetaSecurityService:
""" """
@ -916,79 +917,118 @@ def _extract_identifiers(token):
# token.get_name() 返回 alias 或 table,根据需要可扩展 # token.get_name() 返回 alias 或 table,根据需要可扩展
return None return None
# async def get_tables_from_sql(sql_query: str):
# """
# 使用 sqlparse 解析 SQL 并返回 schema.table 列表(去重)
# 支持嵌套子查询、函数、别名、JOIN、INTO、UPDATE 等。
# 只返回包含 schema 的标识符(即有点号的)。
# """
# parsed = sqlparse.parse(sql_query)
# tables = set()
# for stmt in parsed:
# # 遍历语句的 token 树,寻找顶层的 FROM/JOIN/INTO/UPDATE 子句
# for token in stmt.tokens:
# # 忽略函数、子查询整体(它们会在自己的 stmt 中被处理)
# # 但我们需要遍历整个树以捕获顶层的 Identifier/IdentifierList
# if token.is_group:
# # 递归遍历 group 内的 token
# for t in token.flatten():
# # 跳过在函数内部的 FROM(例如 EXTRACT(... FROM ...))
# # 方法:判断最近的父级 group 是否为 Function(我们用了 flatten,故这里检查 parent types is hard)
# # 简化办法:如果 token 的上层类型是 Function 的一部分,skip(handled by checking surrounding tokens)
# pass
# # 更好的做法:直接按 sqlparse 提供的机制遍历并查找 Identifier/IdentifierList
# for token in stmt.tokens:
# if token.ttype is DML and token.normalized.upper() in ("UPDATE",):
# # UPDATE table_name ...
# # 下一个非空白 token 往往是 Identifier
# nxt = stmt.token_next(stmt.token_index(token), skip_ws=True, skip_cm=True)
# if nxt:
# name = _extract_identifiers(nxt[1]) if isinstance(nxt[1], (Identifier, IdentifierList)) else None
# if name:
# tables.add(name)
# # 使用遍历获取所有 Identifier / IdentifierList 出现在 FROM 或 JOIN 后面的情况
# # 这里遍历 token 序列并在遇到 FROM/JOIN/INTO/UPDATE 时提取后续 identifier
# idx = 0
# tokens = list(stmt.tokens)
# while idx < len(tokens):
# t = tokens[idx]
# if t.is_whitespace:
# idx += 1
# continue
# if t.ttype is Keyword and t.normalized.upper() in ("FROM", "JOIN", "INTO"):
# # 找下一个有意义的 token(可能是 Identifier 或 IdentifierList 或 Parenthesis 表示子查询)
# nxt = stmt.token_next(idx, skip_ws=True, skip_cm=True)
# if nxt:
# tok = nxt[1]
# # 如果是 parenthesis -> 子查询,跳过
# if tok.is_group and isinstance(tok, Function):
# # 函数内的 FROM(如 EXTRACT(...))会被解析为 Function 的一部分 —— 跳过
# pass
# else:
# # 处理 Identifier 或 IdentifierList
# if isinstance(tok, Identifier):
# name = _extract_identifiers(tok)
# if name:
# tables.add(name)
# elif isinstance(tok, IdentifierList):
# for ident in tok.get_identifiers():
# name = _extract_identifiers(ident)
# if name:
# tables.add(name)
# else:
# # 可能是直接的 Name token 'schema.table'(未被识别为 Identifier)
# txt = tok.value
# if "." in txt:
# parts = txt.strip().strip('`"\'').split(".")
# if len(parts) == 2:
# tables.add(f"{parts[0]}.{parts[1]}")
# idx += 1
# continue
# idx += 1
# if not tables:
# raise ValueError("SQL 解析失败,未找到任何 schema.table 结构")
# return sorted(tables)
async def get_tables_from_sql(sql_query: str): async def get_tables_from_sql(sql_query: str):
""" """
使用 sqlparse 解析 SQL 并返回 schema.table 列表去重 使用 sqlglot 解析 SQL 并返回 schema.table 列表去重排序
支持嵌套子查询函数别名JOININTOUPDATE 支持嵌套子查询JOIN别名UPDATEINSERT
只返回包含 schema 的标识符即有点号的
""" """
parsed = sqlparse.parse(sql_query) try:
tables = set() # 解析 SQL(自动支持多语句)
parsed = sqlglot.parse(sql_query)
except Exception as e:
raise ValueError(f"SQL 解析失败: {e}")
tables: Set[str] = set()
def extract_tables(expr):
"""
递归提取 Table 表达式
"""
if isinstance(expr, Table):
# 只取 schema.table
if expr.args.get("db") and expr.args.get("this"):
tables.add(f"{expr.args['db']}.{expr.args['this']}")
# 递归子节点
for child in expr.args.values():
if isinstance(child, list):
for c in child:
if isinstance(c, sqlglot.Expression):
extract_tables(c)
elif isinstance(child, sqlglot.Expression):
extract_tables(child)
for stmt in parsed: for stmt in parsed:
# 遍历语句的 token 树,寻找顶层的 FROM/JOIN/INTO/UPDATE 子句 extract_tables(stmt)
for token in stmt.tokens:
# 忽略函数、子查询整体(它们会在自己的 stmt 中被处理)
# 但我们需要遍历整个树以捕获顶层的 Identifier/IdentifierList
if token.is_group:
# 递归遍历 group 内的 token
for t in token.flatten():
# 跳过在函数内部的 FROM(例如 EXTRACT(... FROM ...))
# 方法:判断最近的父级 group 是否为 Function(我们用了 flatten,故这里检查 parent types is hard)
# 简化办法:如果 token 的上层类型是 Function 的一部分,skip(handled by checking surrounding tokens)
pass
# 更好的做法:直接按 sqlparse 提供的机制遍历并查找 Identifier/IdentifierList
for token in stmt.tokens:
if token.ttype is DML and token.normalized.upper() in ("UPDATE",):
# UPDATE table_name ...
# 下一个非空白 token 往往是 Identifier
nxt = stmt.token_next(stmt.token_index(token), skip_ws=True, skip_cm=True)
if nxt:
name = _extract_identifiers(nxt[1]) if isinstance(nxt[1], (Identifier, IdentifierList)) else None
if name:
tables.add(name)
# 使用遍历获取所有 Identifier / IdentifierList 出现在 FROM 或 JOIN 后面的情况
# 这里遍历 token 序列并在遇到 FROM/JOIN/INTO/UPDATE 时提取后续 identifier
idx = 0
tokens = list(stmt.tokens)
while idx < len(tokens):
t = tokens[idx]
if t.is_whitespace:
idx += 1
continue
if t.ttype is Keyword and t.normalized.upper() in ("FROM", "JOIN", "INTO"):
# 找下一个有意义的 token(可能是 Identifier 或 IdentifierList 或 Parenthesis 表示子查询)
nxt = stmt.token_next(idx, skip_ws=True, skip_cm=True)
if nxt:
tok = nxt[1]
# 如果是 parenthesis -> 子查询,跳过
if tok.is_group and isinstance(tok, Function):
# 函数内的 FROM(如 EXTRACT(...))会被解析为 Function 的一部分 —— 跳过
pass
else:
# 处理 Identifier 或 IdentifierList
if isinstance(tok, Identifier):
name = _extract_identifiers(tok)
if name:
tables.add(name)
elif isinstance(tok, IdentifierList):
for ident in tok.get_identifiers():
name = _extract_identifiers(ident)
if name:
tables.add(name)
else:
# 可能是直接的 Name token 'schema.table'(未被识别为 Identifier)
txt = tok.value
if "." in txt:
parts = txt.strip().strip('`"\'').split(".")
if len(parts) == 2:
tables.add(f"{parts[0]}.{parts[1]}")
idx += 1
continue
idx += 1
if not tables: if not tables:
raise ValueError("SQL 解析失败,未找到任何 schema.table 结构") raise ValueError("SQL 解析失败,未找到任何 schema.table 结构")
return sorted(tables) return sorted(tables)
Loading…
Cancel
Save