|
|
@ -17,7 +17,9 @@ from sqlalchemy.exc import OperationalError |
|
|
import json |
|
|
import json |
|
|
import re |
|
|
import re |
|
|
from decimal import Decimal |
|
|
from decimal import Decimal |
|
|
|
|
|
import sqlparse |
|
|
|
|
|
from sqlparse.sql import Identifier, IdentifierList, Function, Token |
|
|
|
|
|
from sqlparse.tokens import Keyword, DML |
|
|
class MetaSecurityService: |
|
|
class MetaSecurityService: |
|
|
""" |
|
|
""" |
|
|
数据源安全管理模块服务层 |
|
|
数据源安全管理模块服务层 |
|
|
@ -304,7 +306,7 @@ class MetaSecurityService: |
|
|
dbConnent= cls.get_db_engine(dsDataResource["type"],dsDataResource["connectionParams"]) |
|
|
dbConnent= cls.get_db_engine(dsDataResource["type"],dsDataResource["connectionParams"]) |
|
|
# await test_connection(dbConnent) |
|
|
# await test_connection(dbConnent) |
|
|
#3获取sql中涉及的表名 |
|
|
#3获取sql中涉及的表名 |
|
|
sqlScheamAndTable =await cls.get_tables_from_sql(page_object.sqlStr) |
|
|
sqlScheamAndTable =await get_tables_from_sql(page_object.sqlStr) |
|
|
oldStrSql=generate_pagination_sql(page_object,dsDataResource["type"]) |
|
|
oldStrSql=generate_pagination_sql(page_object,dsDataResource["type"]) |
|
|
#4.执行原始sql |
|
|
#4.执行原始sql |
|
|
result = await cls.execute_sql(dbConnent, oldStrSql,"原始") |
|
|
result = await cls.execute_sql(dbConnent, oldStrSql,"原始") |
|
|
@ -405,74 +407,6 @@ class MetaSecurityService: |
|
|
except SQLAlchemyError as e: |
|
|
except SQLAlchemyError as e: |
|
|
raise RuntimeError(f"{sql_type}执行 SQL 查询时发生错误: {e}") |
|
|
raise RuntimeError(f"{sql_type}执行 SQL 查询时发生错误: {e}") |
|
|
|
|
|
|
|
|
# async def get_tables_from_sql(sql_query: str): |
|
|
|
|
|
# """ |
|
|
|
|
|
# 解析 SQL 查询,提取所有 Schema 和 Table 名称,并确保表名包含模式名(schema.table)。 |
|
|
|
|
|
|
|
|
|
|
|
# :param sql_query: SQL 查询字符串 |
|
|
|
|
|
# :return: {'schemas': [...], 'table_names': [...]} |
|
|
|
|
|
# :raises ServiceException: 如果 SQL 未使用 schema.table 结构,则抛出异常 |
|
|
|
|
|
# """ |
|
|
|
|
|
# # ✅ 改进正则:支持 `FROM a.o, b.x JOIN c.y` |
|
|
|
|
|
# table_section_pattern = r"(?i)(?:FROM|JOIN|INTO|UPDATE)\s+([\w\.\s,]+)" |
|
|
|
|
|
|
|
|
|
|
|
# table_sections = re.findall(table_section_pattern, sql_query, re.DOTALL) |
|
|
|
|
|
|
|
|
|
|
|
# if not table_sections: |
|
|
|
|
|
# raise ServiceException(data='', message='SQL 解析失败,未找到表名') |
|
|
|
|
|
# # 解析多个表(用 `,` 和 `JOIN` 拆分) |
|
|
|
|
|
# for section in table_sections: |
|
|
|
|
|
# tables = re.split(r"\s*,\s*|\s+JOIN\s+", section, flags=re.IGNORECASE) |
|
|
|
|
|
# for table in tables: |
|
|
|
|
|
# table = table.strip().split()[0] # 取 `schema.table`,忽略别名 |
|
|
|
|
|
# if "." not in table: |
|
|
|
|
|
# raise ServiceException( |
|
|
|
|
|
# data='', |
|
|
|
|
|
# message=f"SQL 中的表名必须携带模式名(schema.table),但发现了无模式的表:{table}" |
|
|
|
|
|
# ) |
|
|
|
|
|
# return table_sections |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_tables_from_sql(sql_query: str):
    """
    Extract every schema-qualified table name (``schema.table``) from a SQL string.

    SQL comments are removed and whitespace is collapsed, then each identifier
    that directly follows FROM / JOIN / INTO / UPDATE is inspected. Targets
    that start with "(" (subqueries) and names without a schema prefix do not
    match the pattern and are therefore ignored.

    :param sql_query: SQL text to scan
    :return: de-duplicated list of "schema.table" strings (order not guaranteed)
    :raises ServiceException: if no schema.table reference is found
    """
    # 1. Drop "--" line comments and "/* ... */" block comments, then squeeze
    #    all whitespace runs into single spaces.
    sql_query = re.sub(r"--.*?$", "", sql_query, flags=re.MULTILINE)
    sql_query = re.sub(r"/\*.*?\*/", "", sql_query, flags=re.DOTALL)
    sql_query = " ".join(sql_query.split())

    # 2. Match "<keyword> schema.table". The pattern ends with a negative
    #    lookahead instead of "\b": "\b" cannot match right after a closing
    #    quote or "$", which made the engine backtrack and capture a truncated
    #    identifier (e.g. '"t' for '"s"."t"', or 'tab' for 's.tab$').
    pattern = re.compile(
        r"""(?ix)
        (?:FROM|JOIN|INTO|UPDATE)\s+          # SQL 关键字
        (?!\()                                # 排除子查询
        (?P<schema>["'`]?[A-Za-z_][\w\$]*["'`]?)   # schema
        \.                                    # .
        (?P<table>["'`]?[A-Za-z_][\w\$]*["'`]?)    # table
        (?![\w\$])
        """,
        re.VERBOSE,
    )

    # 3. Both named groups always capture at least one character, so every
    #    match yields a usable pair; quotes are removed by unquote_ident.
    table_names = {
        f"{unquote_ident(m.group('schema'))}.{unquote_ident(m.group('table'))}"
        for m in pattern.finditer(sql_query)
    }

    # 4. An empty result is reported to the caller as a parse failure.
    if not table_names:
        raise ServiceException(data='', message="SQL 解析失败,未找到任何 schema.table 结构")

    return list(table_names)
|
|
|
|
|
|
|
|
|
|
|
@classmethod |
|
|
@classmethod |
|
|
async def get_columns_from_tables(cls, dbConnent, table_names, db_type: str): |
|
|
async def get_columns_from_tables(cls, dbConnent, table_names, db_type: str): |
|
|
@ -734,13 +668,6 @@ async def replace_table_with_subquery(ctrSqlDict, oldStrSql): |
|
|
original_table = match.group(2) |
|
|
original_table = match.group(2) |
|
|
alias_part = match.group(3) # " AS xxx" 或 " xxx" |
|
|
alias_part = match.group(3) # " AS xxx" 或 " xxx" |
|
|
alias_name = match.group(4) # xxx |
|
|
alias_name = match.group(4) # xxx |
|
|
|
|
|
|
|
|
# 动态获取子查询 |
|
|
|
|
|
if original_table in ctrSqlDict: |
|
|
|
|
|
# 使用 ctrSqlDict 中的子查询替换表名 |
|
|
|
|
|
replaced = f"{keyword} ({ctrSqlDict[original_table]}) {alias_part}" |
|
|
|
|
|
else: |
|
|
|
|
|
# 默认处理逻辑:判断 alias 是否为关键字 |
|
|
|
|
|
sql_keywords = { |
|
|
sql_keywords = { |
|
|
"SELECT", "INSERT", "UPDATE", "DELETE", "MERGE", "TRUNCATE", |
|
|
"SELECT", "INSERT", "UPDATE", "DELETE", "MERGE", "TRUNCATE", |
|
|
"VALUES", "RETURNING", "FROM", "WHERE", "GROUP", "HAVING", "ORDER", |
|
|
"VALUES", "RETURNING", "FROM", "WHERE", "GROUP", "HAVING", "ORDER", |
|
|
@ -753,6 +680,13 @@ async def replace_table_with_subquery(ctrSqlDict, oldStrSql): |
|
|
"ALL", "SOME", "FETCH", "NEXT", "ONLY", "ASC", "DESC", "GRANT", "REVOKE", "ROLE", |
|
|
"ALL", "SOME", "FETCH", "NEXT", "ONLY", "ASC", "DESC", "GRANT", "REVOKE", "ROLE", |
|
|
"USER", "CURRENT_DATE", "CURRENT_TIME", "CURRENT_TIMESTAMP", |
|
|
"USER", "CURRENT_DATE", "CURRENT_TIME", "CURRENT_TIMESTAMP", |
|
|
} |
|
|
} |
|
|
|
|
|
# 动态获取子查询 |
|
|
|
|
|
if original_table in ctrSqlDict and alias_name not in sql_keywords: |
|
|
|
|
|
# 使用 ctrSqlDict 中的子查询替换表名 |
|
|
|
|
|
replaced = f"{keyword} ({ctrSqlDict[original_table]}) {alias_part}" |
|
|
|
|
|
else: |
|
|
|
|
|
# 默认处理逻辑:判断 alias 是否为关键字 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if alias_name and alias_name.upper().split()[0] not in sql_keywords: |
|
|
if alias_name and alias_name.upper().split()[0] not in sql_keywords: |
|
|
replaced = f"{keyword} ({subquery}) {alias_part}" |
|
|
replaced = f"{keyword} ({subquery}) {alias_part}" |
|
|
@ -835,3 +769,93 @@ def generate_pagination_sql(page_object: MetaSecurityApiModel, db_type: str) -> |
|
|
raise ValueError(f"不支持的数据库类型: {db_type}") |
|
|
raise ValueError(f"不支持的数据库类型: {db_type}") |
|
|
|
|
|
|
|
|
return newStrSql |
|
|
return newStrSql |
|
|
|
|
|
def _extract_identifiers(token):
    """
    Build the ``schema.table`` name for a sqlparse ``Identifier``.

    Returns ``"<parent>.<real>"`` when ``token`` is an ``Identifier`` whose
    ``get_parent_name()`` and ``get_real_name()`` are both non-empty. Any other
    token (an ``IdentifierList`` included) and any identifier missing either
    part yields ``None``.
    """
    # Guard clause: only a single Identifier can carry a schema/table pair.
    if not isinstance(token, Identifier):
        return None

    table_part = token.get_real_name()
    schema_part = token.get_parent_name()
    # An unqualified name has no parent, so it is rejected here.
    return f"{schema_part}.{table_part}" if table_part and schema_part else None
|
|
|
|
|
|
|
|
|
|
|
async def get_tables_from_sql(sql_query: str):
    """
    Parse SQL with sqlparse and return the schema-qualified tables it references.

    Every parsed statement is walked, including nested groups such as
    subqueries, and the token that follows FROM / INTO / UPDATE or any JOIN
    keyword is inspected. Only names for which ``_extract_identifiers`` yields
    a ``schema.table`` string are kept.

    :param sql_query: SQL text to analyse
    :return: sorted, de-duplicated list of "schema.table" strings
    :raises ValueError: if no schema-qualified table is found
    """
    tables = set()

    def _introduces_table(tok):
        """Return True when ``tok`` is a keyword that is followed by a table reference."""
        word = tok.normalized.upper()
        if tok.ttype is DML:
            return word == "UPDATE"
        # sqlparse emits "LEFT JOIN", "INNER JOIN", ... as ONE keyword token,
        # so every JOIN flavour is matched by its suffix.
        return tok.ttype is Keyword and (word in ("FROM", "INTO") or word.endswith("JOIN"))

    def _collect(target):
        """Record the schema.table name(s) carried by the token after a keyword."""
        if isinstance(target, IdentifierList):
            names = [_extract_identifiers(ident) for ident in target.get_identifiers()]
        elif isinstance(target, Identifier):
            names = [_extract_identifiers(target)]
        elif not target.is_group and "." in target.value:
            # Plain token that was not grouped into an Identifier. Group tokens
            # (e.g. a parenthesised subquery) are excluded on purpose: their
            # text may contain a dot without being a table name.
            parts = target.value.strip().strip('`"\'').split(".")
            names = [f"{parts[0]}.{parts[1]}"] if len(parts) == 2 else []
        else:
            names = []
        tables.update(name for name in names if name)

    def _walk(token_list, scan_keywords):
        """Visit ``token_list`` and, recursively, every nested group except Function."""
        for idx, tok in enumerate(token_list.tokens):
            if tok.is_group:
                # Function groups are skipped, matching the previous behaviour
                # for constructs like EXTRACT(... FROM ...). In any other
                # nested group, keywords only count when the group holds a DML
                # keyword of its own, i.e. it looks like a sub-select.
                if not isinstance(tok, Function):
                    _walk(tok, any(child.ttype is DML for child in tok.tokens))
                continue
            if not (scan_keywords and _introduces_table(tok)):
                continue
            # token_next returns (None, None) when nothing follows the keyword.
            _, target = token_list.token_next(idx, skip_ws=True, skip_cm=True)
            if target is not None and not isinstance(target, Function):
                _collect(target)

    for stmt in sqlparse.parse(sql_query):
        # Keywords at statement level are always honoured.
        _walk(stmt, True)

    if not tables:
        raise ValueError("SQL 解析失败,未找到任何 schema.table 结构")

    return sorted(tables)