|
|
@ -304,7 +304,9 @@ class MetaSecurityService: |
|
|
|
# dbConnent= cls.get_db_engine("postgresql",dataParams]) |
|
|
|
dbConnent= cls.get_db_engine(dsDataResource["type"],dsDataResource["connectionParams"]) |
|
|
|
# await test_connection(dbConnent) |
|
|
|
#3.执行原始sql |
|
|
|
#3获取sql中涉及的表名 |
|
|
|
sqlScheamAndTable =await cls.get_tables_from_sql(page_object.sqlStr) |
|
|
|
#4.执行原始sql |
|
|
|
result = await cls.execute_sql(dbConnent, page_object.sqlStr,"原始") |
|
|
|
if 3 in role_id_list: |
|
|
|
resultDict={ |
|
|
@ -314,16 +316,15 @@ class MetaSecurityService: |
|
|
|
} |
|
|
|
return resultDict |
|
|
|
|
|
|
|
#4.获取sql中涉及的表名 |
|
|
|
sqlTableNames =await cls.get_tables_from_sql(page_object.sqlStr) |
|
|
|
|
|
|
|
#5.根据表名获取数据库中的字段名 |
|
|
|
table_columns = await cls.get_columns_from_tables(dbConnent, sqlTableNames,dsDataResource["type"],) |
|
|
|
table_columns = await cls.get_columns_from_tables(dbConnent, sqlScheamAndTable,dsDataResource["type"],) |
|
|
|
|
|
|
|
#6.查询用户及该用户角色下的所有行列配置 |
|
|
|
tablesRowCol = {} |
|
|
|
|
|
|
|
# 遍历每个表名,获取对应的配置 |
|
|
|
for table_name in sqlTableNames: |
|
|
|
for table_name in sqlScheamAndTable: |
|
|
|
table_configs = await get_table_configs(query_db, page_object, user, role_id_list, table_name) |
|
|
|
tablesRowCol[table_name] = table_configs |
|
|
|
|
|
|
@ -403,15 +404,64 @@ class MetaSecurityService: |
|
|
|
return result_dict |
|
|
|
except SQLAlchemyError as e: |
|
|
|
raise RuntimeError(f"{sql_type}执行 SQL 查询时发生错误: {e}") |
|
|
|
async def get_tables_from_sql(sql_query:str): |
|
|
|
table_pattern = r"(FROM|JOIN|INTO|UPDATE)\s+([a-zA-Z0-9_]+(?:\.[a-zA-Z0-9_]+))" |
|
|
|
|
|
|
|
table_matches = re.findall(table_pattern, sql_query, re.IGNORECASE) |
|
|
|
|
|
|
|
table_names = [match[1] for match in table_matches] |
|
|
|
table_names = [match[1].split('.')[-1] for match in table_matches] # `split('.')[-1]` 取最后一部分,即表名 |
|
|
|
|
|
|
|
return table_names |
|
|
|
# async def get_tables_from_sql(sql_query: str): |
|
|
|
# """ |
|
|
|
# 解析 SQL 查询,提取所有 Schema 和 Table 名称,并确保表名包含模式名(schema.table)。 |
|
|
|
|
|
|
|
# :param sql_query: SQL 查询字符串 |
|
|
|
# :return: {'schemas': [...], 'table_names': [...]} |
|
|
|
# :raises ServiceException: 如果 SQL 未使用 schema.table 结构,则抛出异常 |
|
|
|
# """ |
|
|
|
# # ✅ 改进正则:支持 `FROM a.o, b.x JOIN c.y` |
|
|
|
# table_section_pattern = r"(?i)(?:FROM|JOIN|INTO|UPDATE)\s+([\w\.\s,]+)" |
|
|
|
|
|
|
|
# table_sections = re.findall(table_section_pattern, sql_query, re.DOTALL) |
|
|
|
|
|
|
|
# if not table_sections: |
|
|
|
# raise ServiceException(data='', message='SQL 解析失败,未找到表名') |
|
|
|
# # 解析多个表(用 `,` 和 `JOIN` 拆分) |
|
|
|
# for section in table_sections: |
|
|
|
# tables = re.split(r"\s*,\s*|\s+JOIN\s+", section, flags=re.IGNORECASE) |
|
|
|
# for table in tables: |
|
|
|
# table = table.strip().split()[0] # 取 `schema.table`,忽略别名 |
|
|
|
# if "." not in table: |
|
|
|
# raise ServiceException( |
|
|
|
# data='', |
|
|
|
# message=f"SQL 中的表名必须携带模式名(schema.table),但发现了无模式的表:{table}" |
|
|
|
# ) |
|
|
|
# return table_sections |
|
|
|
async def get_tables_from_sql(sql_query: str): |
|
|
|
""" |
|
|
|
解析 SQL 查询,提取所有 Schema 和 Table 名称,并确保表名包含模式名(schema.table)。 |
|
|
|
|
|
|
|
:param sql_query: SQL 查询字符串 |
|
|
|
:return: {'schemas': [...], 'table_names': [...]} |
|
|
|
:raises ServiceException: 如果 SQL 未使用 schema.table 结构,则抛出异常 |
|
|
|
""" |
|
|
|
# ✅ 改进正则,支持 `FROM ... JOIN ...`,并适配换行符 |
|
|
|
table_section_pattern = r"(?i)(?:FROM|JOIN|INTO|UPDATE)\s+([\w\.\s,]+)" |
|
|
|
|
|
|
|
# 允许 `.` 匹配换行符,确保提取完整的 `FROM` 和 `JOIN` 语句 |
|
|
|
table_sections = re.findall(table_section_pattern, sql_query, re.DOTALL) |
|
|
|
|
|
|
|
if not table_sections: |
|
|
|
raise ServiceException(data='', message='SQL 解析失败,未找到表名') |
|
|
|
|
|
|
|
table_names = set() # 使用集合去重 |
|
|
|
for section in table_sections: |
|
|
|
# 按 `,` 或 `JOIN` 拆分,提取表名 |
|
|
|
tables = re.split(r"\s*,\s*|\s+JOIN\s+", section.strip(), flags=re.IGNORECASE) |
|
|
|
for table in tables: |
|
|
|
table = table.strip().split()[0] # 取 `schema.table`,忽略别名 |
|
|
|
if "." not in table: |
|
|
|
raise ServiceException( |
|
|
|
data='', |
|
|
|
message=f"SQL 中的表名必须携带模式名(schema.table),但发现了无模式的表:{table}" |
|
|
|
) |
|
|
|
table_names.add(table) |
|
|
|
|
|
|
|
return list(table_names) |
|
|
|
|
|
|
|
@classmethod |
|
|
|
async def get_columns_from_tables(cls, dbConnent, table_names, db_type: str): |
|
|
@ -419,12 +469,21 @@ class MetaSecurityService: |
|
|
|
columns = {} |
|
|
|
query="" |
|
|
|
for table_name in table_names: |
|
|
|
schema, table = table_name.split(".") |
|
|
|
if db_type.lower() == "postgresql": |
|
|
|
# PostgreSQL: 使用 information_schema.columns 查询字段 |
|
|
|
query= f"SELECT column_name FROM information_schema.columns WHERE table_name ='{table_name}'" |
|
|
|
# PostgreSQL: 查询指定 schema 下的表字段 |
|
|
|
query = f""" |
|
|
|
SELECT column_name |
|
|
|
FROM information_schema.columns |
|
|
|
WHERE table_schema = '{schema}' AND table_name = '{table}' |
|
|
|
""" |
|
|
|
elif db_type.lower() == "mysql": |
|
|
|
# MySQL: 使用 INFORMATION_SCHEMA.COLUMNS 查询字段 |
|
|
|
query= f"SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME ='{table_name}'" |
|
|
|
# MySQL: 直接查询表字段(MySQL 没有 schema 的概念) |
|
|
|
query = f""" |
|
|
|
SELECT COLUMN_NAME |
|
|
|
FROM INFORMATION_SCHEMA.COLUMNS |
|
|
|
WHERE TABLE_NAME = '{table}' |
|
|
|
""" |
|
|
|
else: |
|
|
|
raise ValueError(f"暂不支持数据库类型: {db_type}") |
|
|
|
|
|
|
@ -447,29 +506,32 @@ def convert_decimal(obj): |
|
|
|
return [convert_decimal(item) for item in obj] |
|
|
|
return obj # 返回非 Decimal、dict 或 list 的值 |
|
|
|
async def get_table_configs(query_db, page_object, user, role_id_list, table_name): |
|
|
|
parts = table_name.split(".") |
|
|
|
schema, table = parts |
|
|
|
# 获取用户的列配置 |
|
|
|
user_col_list = await MetaSecurityDao.get_api_col_list( |
|
|
|
query_db, page_object.dbRCode, table_name, '0', user[0].user_id |
|
|
|
query_db, page_object.dbRCode, schema,table, '0', user[0].user_id |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 获取角色的列配置 |
|
|
|
role_col_list = [] |
|
|
|
for role_id in role_id_list: |
|
|
|
role_cols = await MetaSecurityDao.get_api_col_list( |
|
|
|
query_db, page_object.dbRCode, table_name, '1', role_id |
|
|
|
query_db, page_object.dbRCode, schema,table, '1', role_id |
|
|
|
) |
|
|
|
role_col_list.extend(role_cols) # 将每个角色的列配置合并到列表中 |
|
|
|
|
|
|
|
# 获取用户的行配置 |
|
|
|
user_row_list = await MetaSecurityDao.get_api_row_list( |
|
|
|
query_db, page_object.dbRCode, table_name, '0', user[0].user_id |
|
|
|
query_db, page_object.dbRCode, schema,table, '0', user[0].user_id |
|
|
|
) |
|
|
|
|
|
|
|
# 获取角色的行配置 |
|
|
|
role_row_list = [] |
|
|
|
for role_id in role_id_list: |
|
|
|
role_rows = await MetaSecurityDao.get_api_row_list( |
|
|
|
query_db, page_object.dbRCode, table_name, '1', role_id |
|
|
|
query_db, page_object.dbRCode,schema,table, '1', role_id |
|
|
|
) |
|
|
|
role_row_list.extend(role_rows) # 将每个角色的行配置合并到列表中 |
|
|
|
isHave = any([ |
|
|
@ -567,7 +629,7 @@ async def generate_sql(tablesRowCol:dict, table_columns:dict): |
|
|
|
tab_col_value=row.ctrl_value.split(".") |
|
|
|
if len(tab_col_value) != 2: |
|
|
|
raise RuntimeError(f"{row.dbCName}字段控制类型为表字段,未维护正确的值") |
|
|
|
select_rows[row.dbCName] = f"{row.dbCName} in (select {tab_col_value[1]} from {tab_col_value[0]})" |
|
|
|
select_rows[row.dbCName] = f"{row.dbCName} in (select {tab_col_value[1]} from {row.dbSName}.{tab_col_value[0]} where user_id = '1')" |
|
|
|
# 处理用户行配置 |
|
|
|
for row in config["user_row_list"]: |
|
|
|
# 仅仅对固定值有效,不加行限制 |
|
|
@ -587,7 +649,7 @@ async def generate_sql(tablesRowCol:dict, table_columns:dict): |
|
|
|
tab_col_value=row.ctrl_value.split(".") |
|
|
|
if len(tab_col_value) != 2: |
|
|
|
raise RuntimeError(f"{row.dbCName}字段控制类型为表字段,未维护正确的值") |
|
|
|
select_rows[row.dbCName] = f"{row.dbCName} in (select {tab_col_value[1]} from {tab_col_value[0]})" |
|
|
|
select_rows[row.dbCName] = f"{row.dbCName} in (select {tab_col_value[1]} from {row.dbSName}.{tab_col_value[0]} where user_id = '1')" |
|
|
|
if select_rows.values(): |
|
|
|
where_conditions = " AND ".join(select_rows.values()) |
|
|
|
if where_conditions: |
|
|
|