diff --git a/vue-fastapi-backend/module_admin/service/metasecurity_service.py b/vue-fastapi-backend/module_admin/service/metasecurity_service.py index 518a3fd..62c31d4 100644 --- a/vue-fastapi-backend/module_admin/service/metasecurity_service.py +++ b/vue-fastapi-backend/module_admin/service/metasecurity_service.py @@ -13,13 +13,14 @@ from sqlalchemy.exc import SQLAlchemyError, DBAPIError from sqlalchemy import text from config.env import AppConfig import requests -from sqlalchemy.exc import OperationalError +from sqlparse.sql import Identifier, IdentifierList, Parenthesis import json import asyncio import re from decimal import Decimal -import sqlparse -from sqlparse.sql import Identifier, IdentifierList, Function, Token +import sqlglot +from sqlglot.expressions import Table +from typing import Set from sqlparse.tokens import Keyword, DML class MetaSecurityService: """ @@ -916,79 +917,118 @@ def _extract_identifiers(token): # token.get_name() 返回 alias 或 table,根据需要可扩展 return None +# async def get_tables_from_sql(sql_query: str): +# """ +# 使用 sqlparse 解析 SQL 并返回 schema.table 列表(去重) +# 支持嵌套子查询、函数、别名、JOIN、INTO、UPDATE 等。 +# 只返回包含 schema 的标识符(即有点号的)。 +# """ +# parsed = sqlparse.parse(sql_query) +# tables = set() + +# for stmt in parsed: +# # 遍历语句的 token 树,寻找顶层的 FROM/JOIN/INTO/UPDATE 子句 +# for token in stmt.tokens: +# # 忽略函数、子查询整体(它们会在自己的 stmt 中被处理) +# # 但我们需要遍历整个树以捕获顶层的 Identifier/IdentifierList +# if token.is_group: +# # 递归遍历 group 内的 token +# for t in token.flatten(): +# # 跳过在函数内部的 FROM(例如 EXTRACT(... FROM ...)) +# # 方法:判断最近的父级 group 是否为 Function(我们用了 flatten,故这里检查 parent types is hard) +# # 简化办法:如果 token 的上层类型是 Function 的一部分,skip(handled by checking surrounding tokens) +# pass + +# # 更好的做法:直接按 sqlparse 提供的机制遍历并查找 Identifier/IdentifierList +# for token in stmt.tokens: +# if token.ttype is DML and token.normalized.upper() in ("UPDATE",): +# # UPDATE table_name ... +# # 下一个非空白 token 往往是 Identifier +# nxt = stmt.token_next(stmt.token_index(token), skip_ws=True, skip_cm=True) +# if nxt: +# name = _extract_identifiers(nxt[1]) if isinstance(nxt[1], (Identifier, IdentifierList)) else None +# if name: +# tables.add(name) + +# # 使用遍历获取所有 Identifier / IdentifierList 出现在 FROM 或 JOIN 后面的情况 +# # 这里遍历 token 序列并在遇到 FROM/JOIN/INTO/UPDATE 时提取后续 identifier +# idx = 0 +# tokens = list(stmt.tokens) +# while idx < len(tokens): +# t = tokens[idx] +# if t.is_whitespace: +# idx += 1 +# continue +# if t.ttype is Keyword and t.normalized.upper() in ("FROM", "JOIN", "INTO"): +# # 找下一个有意义的 token(可能是 Identifier 或 IdentifierList 或 Parenthesis 表示子查询) +# nxt = stmt.token_next(idx, skip_ws=True, skip_cm=True) +# if nxt: +# tok = nxt[1] +# # 如果是 parenthesis -> 子查询,跳过 +# if tok.is_group and isinstance(tok, Function): +# # 函数内的 FROM(如 EXTRACT(...))会被解析为 Function 的一部分 —— 跳过 +# pass +# else: +# # 处理 Identifier 或 IdentifierList +# if isinstance(tok, Identifier): +# name = _extract_identifiers(tok) +# if name: +# tables.add(name) +# elif isinstance(tok, IdentifierList): +# for ident in tok.get_identifiers(): +# name = _extract_identifiers(ident) +# if name: +# tables.add(name) +# else: +# # 可能是直接的 Name token 'schema.table'(未被识别为 Identifier) +# txt = tok.value +# if "." in txt: +# parts = txt.strip().strip('`"\'').split(".") +# if len(parts) == 2: +# tables.add(f"{parts[0]}.{parts[1]}") +# idx += 1 +# continue +# idx += 1 + +# if not tables: +# raise ValueError("SQL 解析失败,未找到任何 schema.table 结构") +# return sorted(tables) + + async def get_tables_from_sql(sql_query: str): """ - 使用 sqlparse 解析 SQL 并返回 schema.table 列表(去重) - 支持嵌套子查询、函数、别名、JOIN、INTO、UPDATE 等。 - 只返回包含 schema 的标识符(即有点号的)。 + 使用 sqlglot 解析 SQL 并返回 schema.table 列表(去重,排序) + 支持嵌套子查询、JOIN、别名、UPDATE、INSERT 等 """ - parsed = sqlparse.parse(sql_query) - tables = set() + try: + # 解析 SQL(自动支持多语句) + parsed = sqlglot.parse(sql_query) + except Exception as e: + raise ValueError(f"SQL 解析失败: {e}") + + tables: Set[str] = set() + + def extract_tables(expr): + """ + 递归提取 Table 表达式 + """ + if isinstance(expr, Table): + # 只取 schema.table + if expr.args.get("db") and expr.args.get("this"): + tables.add(f"{expr.args['db']}.{expr.args['this']}") + # 递归子节点 + for child in expr.args.values(): + if isinstance(child, list): + for c in child: + if isinstance(c, sqlglot.Expression): + extract_tables(c) + elif isinstance(child, sqlglot.Expression): + extract_tables(child) for stmt in parsed: - # 遍历语句的 token 树,寻找顶层的 FROM/JOIN/INTO/UPDATE 子句 - for token in stmt.tokens: - # 忽略函数、子查询整体(它们会在自己的 stmt 中被处理) - # 但我们需要遍历整个树以捕获顶层的 Identifier/IdentifierList - if token.is_group: - # 递归遍历 group 内的 token - for t in token.flatten(): - # 跳过在函数内部的 FROM(例如 EXTRACT(... FROM ...)) - # 方法:判断最近的父级 group 是否为 Function(我们用了 flatten,故这里检查 parent types is hard) - # 简化办法:如果 token 的上层类型是 Function 的一部分,skip(handled by checking surrounding tokens) - pass - - # 更好的做法:直接按 sqlparse 提供的机制遍历并查找 Identifier/IdentifierList - for token in stmt.tokens: - if token.ttype is DML and token.normalized.upper() in ("UPDATE",): - # UPDATE table_name ... - # 下一个非空白 token 往往是 Identifier - nxt = stmt.token_next(stmt.token_index(token), skip_ws=True, skip_cm=True) - if nxt: - name = _extract_identifiers(nxt[1]) if isinstance(nxt[1], (Identifier, IdentifierList)) else None - if name: - tables.add(name) - - # 使用遍历获取所有 Identifier / IdentifierList 出现在 FROM 或 JOIN 后面的情况 - # 这里遍历 token 序列并在遇到 FROM/JOIN/INTO/UPDATE 时提取后续 identifier - idx = 0 - tokens = list(stmt.tokens) - while idx < len(tokens): - t = tokens[idx] - if t.is_whitespace: - idx += 1 - continue - if t.ttype is Keyword and t.normalized.upper() in ("FROM", "JOIN", "INTO"): - # 找下一个有意义的 token(可能是 Identifier 或 IdentifierList 或 Parenthesis 表示子查询) - nxt = stmt.token_next(idx, skip_ws=True, skip_cm=True) - if nxt: - tok = nxt[1] - # 如果是 parenthesis -> 子查询,跳过 - if tok.is_group and isinstance(tok, Function): - # 函数内的 FROM(如 EXTRACT(...))会被解析为 Function 的一部分 —— 跳过 - pass - else: - # 处理 Identifier 或 IdentifierList - if isinstance(tok, Identifier): - name = _extract_identifiers(tok) - if name: - tables.add(name) - elif isinstance(tok, IdentifierList): - for ident in tok.get_identifiers(): - name = _extract_identifiers(ident) - if name: - tables.add(name) - else: - # 可能是直接的 Name token 'schema.table'(未被识别为 Identifier) - txt = tok.value - if "." in txt: - parts = txt.strip().strip('`"\'').split(".") - if len(parts) == 2: - tables.add(f"{parts[0]}.{parts[1]}") - idx += 1 - continue - idx += 1 + extract_tables(stmt) if not tables: raise ValueError("SQL 解析失败,未找到任何 schema.table 结构") - return sorted(tables) \ No newline at end of file + + return sorted(tables)