|
|
@ -13,13 +13,14 @@ from sqlalchemy.exc import SQLAlchemyError, DBAPIError |
|
|
from sqlalchemy import text |
|
|
from sqlalchemy import text |
|
|
from config.env import AppConfig |
|
|
from config.env import AppConfig |
|
|
import requests |
|
|
import requests |
|
|
from sqlalchemy.exc import OperationalError |
|
|
from sqlparse.sql import Identifier, IdentifierList, Parenthesis |
|
|
import json |
|
|
import json |
|
|
import asyncio |
|
|
import asyncio |
|
|
import re |
|
|
import re |
|
|
from decimal import Decimal |
|
|
from decimal import Decimal |
|
|
import sqlparse |
|
|
import sqlglot |
|
|
from sqlparse.sql import Identifier, IdentifierList, Function, Token |
|
|
from sqlglot.expressions import Table |
|
|
|
|
|
from typing import Set |
|
|
from sqlparse.tokens import Keyword, DML |
|
|
from sqlparse.tokens import Keyword, DML |
|
|
class MetaSecurityService: |
|
|
class MetaSecurityService: |
|
|
""" |
|
|
""" |
|
|
@ -916,79 +917,118 @@ def _extract_identifiers(token): |
|
|
# token.get_name() 返回 alias 或 table,根据需要可扩展 |
|
|
# token.get_name() 返回 alias 或 table,根据需要可扩展 |
|
|
return None |
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
# async def get_tables_from_sql(sql_query: str): |
|
|
|
|
|
# """ |
|
|
|
|
|
# 使用 sqlparse 解析 SQL 并返回 schema.table 列表(去重) |
|
|
|
|
|
# 支持嵌套子查询、函数、别名、JOIN、INTO、UPDATE 等。 |
|
|
|
|
|
# 只返回包含 schema 的标识符(即有点号的)。 |
|
|
|
|
|
# """ |
|
|
|
|
|
# parsed = sqlparse.parse(sql_query) |
|
|
|
|
|
# tables = set() |
|
|
|
|
|
|
|
|
|
|
|
# for stmt in parsed: |
|
|
|
|
|
# # 遍历语句的 token 树,寻找顶层的 FROM/JOIN/INTO/UPDATE 子句 |
|
|
|
|
|
# for token in stmt.tokens: |
|
|
|
|
|
# # 忽略函数、子查询整体(它们会在自己的 stmt 中被处理) |
|
|
|
|
|
# # 但我们需要遍历整个树以捕获顶层的 Identifier/IdentifierList |
|
|
|
|
|
# if token.is_group: |
|
|
|
|
|
# # 递归遍历 group 内的 token |
|
|
|
|
|
# for t in token.flatten(): |
|
|
|
|
|
# # 跳过在函数内部的 FROM(例如 EXTRACT(... FROM ...)) |
|
|
|
|
|
# # 方法:判断最近的父级 group 是否为 Function(我们用了 flatten,故这里检查 parent types is hard) |
|
|
|
|
|
# # 简化办法:如果 token 的上层类型是 Function 的一部分,skip(handled by checking surrounding tokens) |
|
|
|
|
|
# pass |
|
|
|
|
|
|
|
|
|
|
|
# # 更好的做法:直接按 sqlparse 提供的机制遍历并查找 Identifier/IdentifierList |
|
|
|
|
|
# for token in stmt.tokens: |
|
|
|
|
|
# if token.ttype is DML and token.normalized.upper() in ("UPDATE",): |
|
|
|
|
|
# # UPDATE table_name ... |
|
|
|
|
|
# # 下一个非空白 token 往往是 Identifier |
|
|
|
|
|
# nxt = stmt.token_next(stmt.token_index(token), skip_ws=True, skip_cm=True) |
|
|
|
|
|
# if nxt: |
|
|
|
|
|
# name = _extract_identifiers(nxt[1]) if isinstance(nxt[1], (Identifier, IdentifierList)) else None |
|
|
|
|
|
# if name: |
|
|
|
|
|
# tables.add(name) |
|
|
|
|
|
|
|
|
|
|
|
# # 使用遍历获取所有 Identifier / IdentifierList 出现在 FROM 或 JOIN 后面的情况 |
|
|
|
|
|
# # 这里遍历 token 序列并在遇到 FROM/JOIN/INTO/UPDATE 时提取后续 identifier |
|
|
|
|
|
# idx = 0 |
|
|
|
|
|
# tokens = list(stmt.tokens) |
|
|
|
|
|
# while idx < len(tokens): |
|
|
|
|
|
# t = tokens[idx] |
|
|
|
|
|
# if t.is_whitespace: |
|
|
|
|
|
# idx += 1 |
|
|
|
|
|
# continue |
|
|
|
|
|
# if t.ttype is Keyword and t.normalized.upper() in ("FROM", "JOIN", "INTO"): |
|
|
|
|
|
# # 找下一个有意义的 token(可能是 Identifier 或 IdentifierList 或 Parenthesis 表示子查询) |
|
|
|
|
|
# nxt = stmt.token_next(idx, skip_ws=True, skip_cm=True) |
|
|
|
|
|
# if nxt: |
|
|
|
|
|
# tok = nxt[1] |
|
|
|
|
|
# # 如果是 parenthesis -> 子查询,跳过 |
|
|
|
|
|
# if tok.is_group and isinstance(tok, Function): |
|
|
|
|
|
# # 函数内的 FROM(如 EXTRACT(...))会被解析为 Function 的一部分 —— 跳过 |
|
|
|
|
|
# pass |
|
|
|
|
|
# else: |
|
|
|
|
|
# # 处理 Identifier 或 IdentifierList |
|
|
|
|
|
# if isinstance(tok, Identifier): |
|
|
|
|
|
# name = _extract_identifiers(tok) |
|
|
|
|
|
# if name: |
|
|
|
|
|
# tables.add(name) |
|
|
|
|
|
# elif isinstance(tok, IdentifierList): |
|
|
|
|
|
# for ident in tok.get_identifiers(): |
|
|
|
|
|
# name = _extract_identifiers(ident) |
|
|
|
|
|
# if name: |
|
|
|
|
|
# tables.add(name) |
|
|
|
|
|
# else: |
|
|
|
|
|
# # 可能是直接的 Name token 'schema.table'(未被识别为 Identifier) |
|
|
|
|
|
# txt = tok.value |
|
|
|
|
|
# if "." in txt: |
|
|
|
|
|
# parts = txt.strip().strip('`"\'').split(".") |
|
|
|
|
|
# if len(parts) == 2: |
|
|
|
|
|
# tables.add(f"{parts[0]}.{parts[1]}") |
|
|
|
|
|
# idx += 1 |
|
|
|
|
|
# continue |
|
|
|
|
|
# idx += 1 |
|
|
|
|
|
|
|
|
|
|
|
# if not tables: |
|
|
|
|
|
# raise ValueError("SQL 解析失败,未找到任何 schema.table 结构") |
|
|
|
|
|
# return sorted(tables) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_tables_from_sql(sql_query: str): |
|
|
async def get_tables_from_sql(sql_query: str): |
|
|
""" |
|
|
""" |
|
|
使用 sqlparse 解析 SQL 并返回 schema.table 列表(去重) |
|
|
使用 sqlglot 解析 SQL 并返回 schema.table 列表(去重,排序) |
|
|
支持嵌套子查询、函数、别名、JOIN、INTO、UPDATE 等。 |
|
|
支持嵌套子查询、JOIN、别名、UPDATE、INSERT 等 |
|
|
只返回包含 schema 的标识符(即有点号的)。 |
|
|
|
|
|
""" |
|
|
""" |
|
|
parsed = sqlparse.parse(sql_query) |
|
|
try: |
|
|
tables = set() |
|
|
# 解析 SQL(自动支持多语句) |
|
|
|
|
|
parsed = sqlglot.parse(sql_query) |
|
|
|
|
|
except Exception as e: |
|
|
|
|
|
raise ValueError(f"SQL 解析失败: {e}") |
|
|
|
|
|
|
|
|
|
|
|
tables: Set[str] = set() |
|
|
|
|
|
|
|
|
|
|
|
def extract_tables(expr): |
|
|
|
|
|
""" |
|
|
|
|
|
递归提取 Table 表达式 |
|
|
|
|
|
""" |
|
|
|
|
|
if isinstance(expr, Table): |
|
|
|
|
|
# 只取 schema.table |
|
|
|
|
|
if expr.args.get("db") and expr.args.get("this"): |
|
|
|
|
|
tables.add(f"{expr.args['db']}.{expr.args['this']}") |
|
|
|
|
|
# 递归子节点 |
|
|
|
|
|
for child in expr.args.values(): |
|
|
|
|
|
if isinstance(child, list): |
|
|
|
|
|
for c in child: |
|
|
|
|
|
if isinstance(c, sqlglot.Expression): |
|
|
|
|
|
extract_tables(c) |
|
|
|
|
|
elif isinstance(child, sqlglot.Expression): |
|
|
|
|
|
extract_tables(child) |
|
|
|
|
|
|
|
|
for stmt in parsed: |
|
|
for stmt in parsed: |
|
|
# 遍历语句的 token 树,寻找顶层的 FROM/JOIN/INTO/UPDATE 子句 |
|
|
extract_tables(stmt) |
|
|
for token in stmt.tokens: |
|
|
|
|
|
# 忽略函数、子查询整体(它们会在自己的 stmt 中被处理) |
|
|
|
|
|
# 但我们需要遍历整个树以捕获顶层的 Identifier/IdentifierList |
|
|
|
|
|
if token.is_group: |
|
|
|
|
|
# 递归遍历 group 内的 token |
|
|
|
|
|
for t in token.flatten(): |
|
|
|
|
|
# 跳过在函数内部的 FROM(例如 EXTRACT(... FROM ...)) |
|
|
|
|
|
# 方法:判断最近的父级 group 是否为 Function(我们用了 flatten,故这里检查 parent types is hard) |
|
|
|
|
|
# 简化办法:如果 token 的上层类型是 Function 的一部分,skip(handled by checking surrounding tokens) |
|
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
# 更好的做法:直接按 sqlparse 提供的机制遍历并查找 Identifier/IdentifierList |
|
|
|
|
|
for token in stmt.tokens: |
|
|
|
|
|
if token.ttype is DML and token.normalized.upper() in ("UPDATE",): |
|
|
|
|
|
# UPDATE table_name ... |
|
|
|
|
|
# 下一个非空白 token 往往是 Identifier |
|
|
|
|
|
nxt = stmt.token_next(stmt.token_index(token), skip_ws=True, skip_cm=True) |
|
|
|
|
|
if nxt: |
|
|
|
|
|
name = _extract_identifiers(nxt[1]) if isinstance(nxt[1], (Identifier, IdentifierList)) else None |
|
|
|
|
|
if name: |
|
|
|
|
|
tables.add(name) |
|
|
|
|
|
|
|
|
|
|
|
# 使用遍历获取所有 Identifier / IdentifierList 出现在 FROM 或 JOIN 后面的情况 |
|
|
|
|
|
# 这里遍历 token 序列并在遇到 FROM/JOIN/INTO/UPDATE 时提取后续 identifier |
|
|
|
|
|
idx = 0 |
|
|
|
|
|
tokens = list(stmt.tokens) |
|
|
|
|
|
while idx < len(tokens): |
|
|
|
|
|
t = tokens[idx] |
|
|
|
|
|
if t.is_whitespace: |
|
|
|
|
|
idx += 1 |
|
|
|
|
|
continue |
|
|
|
|
|
if t.ttype is Keyword and t.normalized.upper() in ("FROM", "JOIN", "INTO"): |
|
|
|
|
|
# 找下一个有意义的 token(可能是 Identifier 或 IdentifierList 或 Parenthesis 表示子查询) |
|
|
|
|
|
nxt = stmt.token_next(idx, skip_ws=True, skip_cm=True) |
|
|
|
|
|
if nxt: |
|
|
|
|
|
tok = nxt[1] |
|
|
|
|
|
# 如果是 parenthesis -> 子查询,跳过 |
|
|
|
|
|
if tok.is_group and isinstance(tok, Function): |
|
|
|
|
|
# 函数内的 FROM(如 EXTRACT(...))会被解析为 Function 的一部分 —— 跳过 |
|
|
|
|
|
pass |
|
|
|
|
|
else: |
|
|
|
|
|
# 处理 Identifier 或 IdentifierList |
|
|
|
|
|
if isinstance(tok, Identifier): |
|
|
|
|
|
name = _extract_identifiers(tok) |
|
|
|
|
|
if name: |
|
|
|
|
|
tables.add(name) |
|
|
|
|
|
elif isinstance(tok, IdentifierList): |
|
|
|
|
|
for ident in tok.get_identifiers(): |
|
|
|
|
|
name = _extract_identifiers(ident) |
|
|
|
|
|
if name: |
|
|
|
|
|
tables.add(name) |
|
|
|
|
|
else: |
|
|
|
|
|
# 可能是直接的 Name token 'schema.table'(未被识别为 Identifier) |
|
|
|
|
|
txt = tok.value |
|
|
|
|
|
if "." in txt: |
|
|
|
|
|
parts = txt.strip().strip('`"\'').split(".") |
|
|
|
|
|
if len(parts) == 2: |
|
|
|
|
|
tables.add(f"{parts[0]}.{parts[1]}") |
|
|
|
|
|
idx += 1 |
|
|
|
|
|
continue |
|
|
|
|
|
idx += 1 |
|
|
|
|
|
|
|
|
|
|
|
if not tables: |
|
|
if not tables: |
|
|
raise ValueError("SQL 解析失败,未找到任何 schema.table 结构") |
|
|
raise ValueError("SQL 解析失败,未找到任何 schema.table 结构") |
|
|
|
|
|
|
|
|
return sorted(tables) |
|
|
return sorted(tables) |