diff --git a/vue-fastapi-backend/module_admin/service/metasecurity_service.py b/vue-fastapi-backend/module_admin/service/metasecurity_service.py index e434dda..d66c155 100644 --- a/vue-fastapi-backend/module_admin/service/metasecurity_service.py +++ b/vue-fastapi-backend/module_admin/service/metasecurity_service.py @@ -431,37 +431,48 @@ class MetaSecurityService: # message=f"SQL 中的表名必须携带模式名(schema.table),但发现了无模式的表:{table}" # ) # return table_sections + + + async def get_tables_from_sql(sql_query: str): - """ - 解析 SQL 查询,提取所有 Schema 和 Table 名称,并确保表名包含模式名(schema.table)。 - - :param sql_query: SQL 查询字符串 - :return: {'schemas': [...], 'table_names': [...]} - :raises ServiceException: 如果 SQL 未使用 schema.table 结构,则抛出异常 - """ - # ✅ 改进正则,支持 `FROM ... JOIN ...`,并适配换行符 - table_section_pattern = r"(?i)(?:FROM|JOIN|INTO|UPDATE)\s+([\w\.\s,]+)" - - # 允许 `.` 匹配换行符,确保提取完整的 `FROM` 和 `JOIN` 语句 - table_sections = re.findall(table_section_pattern, sql_query, re.DOTALL) - - if not table_sections: - raise ServiceException(data='', message='SQL 解析失败,未找到表名') - - table_names = set() # 使用集合去重 - for section in table_sections: - # 按 `,` 或 `JOIN` 拆分,提取表名 - tables = re.split(r"\s*,\s*|\s+JOIN\s+", section.strip(), flags=re.IGNORECASE) - for table in tables: - table = table.strip().split()[0] # 取 `schema.table`,忽略别名 - if "." not in table: - raise ServiceException( - data='', - message=f"SQL 中的表名必须携带模式名(schema.table),但发现了无模式的表:{table}" - ) - table_names.add(table) - - return list(table_names) + """ + 解析 SQL 查询,提取所有 schema.table 名称(支持嵌套子查询、别名、JOIN、INTO、UPDATE)。 + 自动排除字段引用与无模式表。 + """ + # 1️⃣ 清理注释与多余空白 + sql_query = re.sub(r"--.*?$", "", sql_query, flags=re.MULTILINE) + sql_query = re.sub(r"/\*.*?\*/", "", sql_query, flags=re.DOTALL) + sql_query = " ".join(sql_query.split()) + + # 2️⃣ 匹配 FROM/JOIN/INTO/UPDATE 后面的 schema.table + pattern = re.compile( + r"""(?ix) + (?:FROM|JOIN|INTO|UPDATE)\s+ # SQL 关键字 + (?!\() # 排除子查询 + (?P["'`]?[A-Za-z_][\w\$]*["'`]?) # schema + \. # . + (?P["'`]?[A-Za-z_][\w\$]*["'`]?) # table + \b + """, + re.VERBOSE + ) + + # 3️⃣ 使用 finditer,逐个安全提取匹配项 + table_names = set() + for m in pattern.finditer(sql_query): + schema_raw = m.group("schema") + table_raw = m.group("table") + if not schema_raw or not table_raw: + continue + schema =unquote_ident(schema_raw) + table = unquote_ident(table_raw) + table_names.add(f"{schema}.{table}") + + # 4️⃣ 检查结果 + if not table_names: + raise ServiceException(data='', message="SQL 解析失败,未找到任何 schema.table 结构") + + return list(table_names) @classmethod async def get_columns_from_tables(cls, dbConnent, table_names, db_type: str): @@ -494,7 +505,16 @@ class MetaSecurityService: columns[table_name] = [row["column_name"] for row in result] return columns - +def unquote_ident(name: str) -> str: + """ + 去除标识符首尾的成对引号(" ' `),例如: + '"public"' -> public, '`schema`' -> schema + """ + if not name: + return name + if len(name) >= 2 and name[0] in ('"', "'", '`') and name[-1] == name[0]: + return name[1:-1] + return name def convert_decimal(obj): if isinstance(obj, Decimal): return float(obj) # 或者 str(obj) 来保留精度 @@ -552,14 +572,11 @@ async def generate_sql(tablesRowCol:dict, table_columns:dict): # 1. 列控制 # 遍历每个表 - isHave=False no_configTable_name="" for table_name, table_configs in tablesRowCol.items(): - if table_configs["isHave"]: - isHave=True - else: + if not table_configs.get("isHave", False): no_configTable_name += table_name + "," - if not isHave: + if no_configTable_name: no_configTable_name = no_configTable_name.rstrip(',') raise ValueError(f"表:{no_configTable_name}均未配置行列数据安全") for table_name, config in tablesRowCol.items():