|
|
@ -431,35 +431,46 @@ class MetaSecurityService: |
|
|
# message=f"SQL 中的表名必须携带模式名(schema.table),但发现了无模式的表:{table}" |
|
|
# message=f"SQL 中的表名必须携带模式名(schema.table),但发现了无模式的表:{table}" |
|
|
# ) |
|
|
# ) |
|
|
# return table_sections |
|
|
# return table_sections |
|
|
async def get_tables_from_sql(sql_query: str): |
|
|
|
|
|
""" |
|
|
|
|
|
解析 SQL 查询,提取所有 Schema 和 Table 名称,并确保表名包含模式名(schema.table)。 |
|
|
|
|
|
|
|
|
|
|
|
:param sql_query: SQL 查询字符串 |
|
|
|
|
|
:return: {'schemas': [...], 'table_names': [...]} |
|
|
|
|
|
:raises ServiceException: 如果 SQL 未使用 schema.table 结构,则抛出异常 |
|
|
|
|
|
""" |
|
|
|
|
|
# ✅ 改进正则,支持 `FROM ... JOIN ...`,并适配换行符 |
|
|
|
|
|
table_section_pattern = r"(?i)(?:FROM|JOIN|INTO|UPDATE)\s+([\w\.\s,]+)" |
|
|
|
|
|
|
|
|
|
|
|
# 允许 `.` 匹配换行符,确保提取完整的 `FROM` 和 `JOIN` 语句 |
|
|
|
|
|
table_sections = re.findall(table_section_pattern, sql_query, re.DOTALL) |
|
|
|
|
|
|
|
|
|
|
|
if not table_sections: |
|
|
|
|
|
raise ServiceException(data='', message='SQL 解析失败,未找到表名') |
|
|
|
|
|
|
|
|
|
|
|
table_names = set() # 使用集合去重 |
|
|
async def get_tables_from_sql(sql_query: str): |
|
|
for section in table_sections: |
|
|
""" |
|
|
# 按 `,` 或 `JOIN` 拆分,提取表名 |
|
|
解析 SQL 查询,提取所有 schema.table 名称(支持嵌套子查询、别名、JOIN、INTO、UPDATE)。 |
|
|
tables = re.split(r"\s*,\s*|\s+JOIN\s+", section.strip(), flags=re.IGNORECASE) |
|
|
自动排除字段引用与无模式表。 |
|
|
for table in tables: |
|
|
""" |
|
|
table = table.strip().split()[0] # 取 `schema.table`,忽略别名 |
|
|
# 1️⃣ 清理注释与多余空白 |
|
|
if "." not in table: |
|
|
sql_query = re.sub(r"--.*?$", "", sql_query, flags=re.MULTILINE) |
|
|
raise ServiceException( |
|
|
sql_query = re.sub(r"/\*.*?\*/", "", sql_query, flags=re.DOTALL) |
|
|
data='', |
|
|
sql_query = " ".join(sql_query.split()) |
|
|
message=f"SQL 中的表名必须携带模式名(schema.table),但发现了无模式的表:{table}" |
|
|
|
|
|
|
|
|
# 2️⃣ 匹配 FROM/JOIN/INTO/UPDATE 后面的 schema.table |
|
|
|
|
|
pattern = re.compile( |
|
|
|
|
|
r"""(?ix) |
|
|
|
|
|
(?:FROM|JOIN|INTO|UPDATE)\s+ # SQL 关键字 |
|
|
|
|
|
(?!\() # 排除子查询 |
|
|
|
|
|
(?P<schema>["'`]?[A-Za-z_][\w\$]*["'`]?) # schema |
|
|
|
|
|
\. # . |
|
|
|
|
|
(?P<table>["'`]?[A-Za-z_][\w\$]*["'`]?) # table |
|
|
|
|
|
\b |
|
|
|
|
|
""", |
|
|
|
|
|
re.VERBOSE |
|
|
) |
|
|
) |
|
|
table_names.add(table) |
|
|
|
|
|
|
|
|
# 3️⃣ 使用 finditer,逐个安全提取匹配项 |
|
|
|
|
|
table_names = set() |
|
|
|
|
|
for m in pattern.finditer(sql_query): |
|
|
|
|
|
schema_raw = m.group("schema") |
|
|
|
|
|
table_raw = m.group("table") |
|
|
|
|
|
if not schema_raw or not table_raw: |
|
|
|
|
|
continue |
|
|
|
|
|
schema =unquote_ident(schema_raw) |
|
|
|
|
|
table = unquote_ident(table_raw) |
|
|
|
|
|
table_names.add(f"{schema}.{table}") |
|
|
|
|
|
|
|
|
|
|
|
# 4️⃣ 检查结果 |
|
|
|
|
|
if not table_names: |
|
|
|
|
|
raise ServiceException(data='', message="SQL 解析失败,未找到任何 schema.table 结构") |
|
|
|
|
|
|
|
|
return list(table_names) |
|
|
return list(table_names) |
|
|
|
|
|
|
|
|
@ -494,7 +505,16 @@ class MetaSecurityService: |
|
|
columns[table_name] = [row["column_name"] for row in result] |
|
|
columns[table_name] = [row["column_name"] for row in result] |
|
|
|
|
|
|
|
|
return columns |
|
|
return columns |
|
|
|
|
|
def unquote_ident(name: str) -> str: |
|
|
|
|
|
""" |
|
|
|
|
|
去除标识符首尾的成对引号(" ' `),例如: |
|
|
|
|
|
'"public"' -> public, '`schema`' -> schema |
|
|
|
|
|
""" |
|
|
|
|
|
if not name: |
|
|
|
|
|
return name |
|
|
|
|
|
if len(name) >= 2 and name[0] in ('"', "'", '`') and name[-1] == name[0]: |
|
|
|
|
|
return name[1:-1] |
|
|
|
|
|
return name |
|
|
def convert_decimal(obj): |
|
|
def convert_decimal(obj): |
|
|
if isinstance(obj, Decimal): |
|
|
if isinstance(obj, Decimal): |
|
|
return float(obj) # 或者 str(obj) 来保留精度 |
|
|
return float(obj) # 或者 str(obj) 来保留精度 |
|
|
@ -552,14 +572,11 @@ async def generate_sql(tablesRowCol:dict, table_columns:dict): |
|
|
|
|
|
|
|
|
# 1. 列控制 |
|
|
# 1. 列控制 |
|
|
# 遍历每个表 |
|
|
# 遍历每个表 |
|
|
isHave=False |
|
|
|
|
|
no_configTable_name="" |
|
|
no_configTable_name="" |
|
|
for table_name, table_configs in tablesRowCol.items(): |
|
|
for table_name, table_configs in tablesRowCol.items(): |
|
|
if table_configs["isHave"]: |
|
|
if not table_configs.get("isHave", False): |
|
|
isHave=True |
|
|
|
|
|
else: |
|
|
|
|
|
no_configTable_name += table_name + "," |
|
|
no_configTable_name += table_name + "," |
|
|
if not isHave: |
|
|
if no_configTable_name: |
|
|
no_configTable_name = no_configTable_name.rstrip(',') |
|
|
no_configTable_name = no_configTable_name.rstrip(',') |
|
|
raise ValueError(f"表:{no_configTable_name}均未配置行列数据安全") |
|
|
raise ValueError(f"表:{no_configTable_name}均未配置行列数据安全") |
|
|
for table_name, config in tablesRowCol.items(): |
|
|
for table_name, config in tablesRowCol.items(): |
|
|
|