You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 

1082 lines
47 KiB

from fastapi import Request
from exceptions.exception import ServiceException
from module_admin.dao.metaSecurity_dao import MetaSecurityDao
from module_admin.entity.vo.common_vo import CrudResponseModel
from module_admin.entity.vo.metasecurity_vo import MetaSecurityColModel, MetaSecurityRowModel,DeleteMetaSecurityModel,MetaSecurityApiModel
from utils.common_util import CamelCaseUtil
import uuid
from module_admin.dao.login_dao import login_by_account
from module_admin.dao.user_dao import UserDao
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.exc import SQLAlchemyError, DBAPIError
from sqlalchemy import text
from config.env import AppConfig
import requests
from sqlparse.sql import Identifier, IdentifierList, Parenthesis
import json
import asyncio
import re
from decimal import Decimal
import sqlglot
from sqlglot.expressions import Table
from sqlglot import exp ,parse_one
from typing import Set
from sqlparse.tokens import Keyword, DML
class MetaSecurityService:
    """
    Service layer for data-source security management: CRUD for row/column
    access rules, and execution of caller-supplied SQL with those rules applied.
    """

    # Write/DDL keywords rejected by getMetaSercuitybysql (whole words, case-insensitive).
    _FORBIDDEN_KEYWORDS = ("UPDATE", "DELETE", "INSERT", "DROP", "ALTER", "TRUNCATE")
    _FORBIDDEN_PATTERN = re.compile(r"\b(" + "|".join(_FORBIDDEN_KEYWORDS) + r")\b", re.IGNORECASE)
    # Users holding this role get unfiltered results (reported as "数据安全管理员权限").
    _SECURITY_ADMIN_ROLE_ID = 3

    @classmethod
    async def get_meta_security_col_list_services(
        cls, query_db: AsyncSession, query_object: MetaSecurityColModel, is_page: bool = False
    ):
        """
        Get the list of column-level security rules.

        :param query_db: ORM session
        :param query_object: query filter object
        :param is_page: whether to paginate the result
        :return: result of MetaSecurityDao.get_meta_security_col_list
        """
        return await MetaSecurityDao.get_meta_security_col_list(query_db, query_object, is_page)

    @classmethod
    async def get_meta_security_row_list_services(
        cls, query_db: AsyncSession, query_object: MetaSecurityRowModel, is_page: bool = False
    ):
        """
        Get the list of row-level security rules.

        :param query_db: ORM session
        :param query_object: query filter object
        :param is_page: whether to paginate the result
        :return: result of MetaSecurityDao.get_meta_security_row_list
        """
        return await MetaSecurityDao.get_meta_security_row_list(query_db, query_object, is_page)

    @classmethod
    async def get_meta_mdlName_list_services(cls, query_db: AsyncSession, ssys_id: int):
        """
        Get the schemas registered for a source system.

        :param query_db: ORM session
        :param ssys_id: source-system id
        :return: result of MetaSecurityDao.get_schema_by_system
        """
        return await MetaSecurityDao.get_schema_by_system(query_db, ssys_id)

    @classmethod
    async def get_meta_security_col_by_id_services(cls, query_db: AsyncSession, colId: str):
        """
        Get one column rule by id.

        :param query_db: ORM session
        :param colId: column rule id
        :return: the rule as MetaSecurityColModel, or an empty model when not found
        """
        col = await MetaSecurityDao.get_meta_security_col_by_id(query_db, colId)
        if col:
            return MetaSecurityColModel(**CamelCaseUtil.transform_result(col))
        return MetaSecurityColModel()

    @classmethod
    async def get_meta_security_row_by_id_services(cls, query_db: AsyncSession, rowId: str):
        """
        Get one row rule by id.

        :param query_db: ORM session
        :param rowId: row rule id
        :return: the rule as MetaSecurityRowModel, or an empty model when not found
        """
        row = await MetaSecurityDao.get_meta_security_row_by_id(query_db, rowId)
        if row:
            return MetaSecurityRowModel(**CamelCaseUtil.transform_result(row))
        return MetaSecurityRowModel()

    @staticmethod
    def _split_obj_pairs(page_object):
        """
        Split the comma-separated ``obj_value`` / ``obj_name`` of *page_object*
        into stripped ``(value, name)`` pairs.

        Returns an empty list when ``obj_value`` is not a non-empty string.

        :raises ServiceException: when the names are missing or their count
            differs from the values (zip() would otherwise drop or misalign entries)
        """
        if not (isinstance(page_object.obj_value, str) and page_object.obj_value):
            return []
        obj_values = page_object.obj_value.split(",")
        obj_names = page_object.obj_name.split(",") if isinstance(page_object.obj_name, str) else []
        if len(obj_names) != len(obj_values):
            raise ServiceException(message='授权对象名称与值的数量不一致')
        return [(value.strip(), name.strip()) for value, name in zip(obj_values, obj_names)]

    @classmethod
    async def add_meta_security_col_services(cls, request: Request, query_db: AsyncSession, page_object: MetaSecurityColModel):
        """
        Create one column rule per authorised object listed in *page_object*.

        ``obj_value`` / ``obj_name`` may be comma-separated lists; each pair
        becomes its own rule with a fresh UUID ``colId``.

        :param request: FastAPI request (not used)
        :param query_db: ORM session
        :param page_object: column rule to create
        :return: CrudResponseModel on success
        :raises ServiceException: if the value and name lists differ in length
        """
        pairs = cls._split_obj_pairs(page_object)
        try:
            for value, name in pairs:
                # Copy so the caller's object is left untouched.
                new_page_object = MetaSecurityColModel(**page_object.model_dump(by_alias=True))
                new_page_object.obj_value = value
                new_page_object.obj_name = name
                new_page_object.colId = str(uuid.uuid4())
                await MetaSecurityDao.add_meta_security_col(query_db, new_page_object)
            await query_db.commit()
            return CrudResponseModel(is_success=True, message='新增列配置成功')
        except Exception:
            await query_db.rollback()
            raise

    @classmethod
    async def add_meta_security_row_services(cls, request: Request, query_db: AsyncSession, page_object: MetaSecurityRowModel):
        """
        Create one row rule per authorised object listed in *page_object*.

        ``obj_value`` / ``obj_name`` may be comma-separated lists; each pair
        becomes its own rule with a fresh UUID ``rowId``.

        :param request: FastAPI request (not used)
        :param query_db: ORM session
        :param page_object: row rule to create
        :return: CrudResponseModel on success
        :raises ServiceException: if the value and name lists differ in length
        """
        pairs = cls._split_obj_pairs(page_object)
        try:
            for value, name in pairs:
                # Copy so the caller's object is left untouched.
                new_page_object = MetaSecurityRowModel(**page_object.model_dump(by_alias=True))
                new_page_object.obj_value = value
                new_page_object.obj_name = name
                new_page_object.rowId = str(uuid.uuid4())
                await MetaSecurityDao.add_meta_security_row(query_db, new_page_object)
            await query_db.commit()
            return CrudResponseModel(is_success=True, message='新增行配置成功')
        except Exception:
            await query_db.rollback()
            raise

    @classmethod
    async def col_detail_services(cls, query_db: AsyncSession, col: str):
        """
        Get one column rule by id (same result as get_meta_security_col_by_id_services).

        :param query_db: ORM session
        :param col: column rule id
        :return: MetaSecurityColModel, empty when not found
        """
        return await cls.get_meta_security_col_by_id_services(query_db, col)

    @classmethod
    async def row_detail_services(cls, query_db: AsyncSession, row_id: str):
        """
        Get one row rule by id (same result as get_meta_security_row_by_id_services).

        :param query_db: ORM session
        :param row_id: row rule id
        :return: MetaSecurityRowModel, empty when not found
        """
        return await cls.get_meta_security_row_by_id_services(query_db, row_id)

    @classmethod
    async def edit_meta_security_col_services(cls, request: Request, query_db: AsyncSession, page_object: MetaSecurityColModel):
        """
        Update an existing column rule with the fields set on *page_object*.

        :param request: FastAPI request (not used)
        :param query_db: ORM session
        :param page_object: rule carrying ``colId`` and the fields to change
        :return: CrudResponseModel on success
        :raises ServiceException: if no rule with ``colId`` exists
        """
        # Check the DAO result directly: the *_services getter always returns a
        # (truthy) model, which made the "not found" branch unreachable.
        existing = await MetaSecurityDao.get_meta_security_col_by_id(query_db, page_object.colId)
        if not existing:
            raise ServiceException(message=f'列配置{page_object.colId}不存在')
        edit_col = page_object.model_dump(exclude_unset=True)
        try:
            await MetaSecurityDao.update_meta_security_col(query_db, edit_col)
            await query_db.commit()
            return CrudResponseModel(is_success=True, message='编辑列配置成功')
        except Exception:
            await query_db.rollback()
            raise

    @classmethod
    async def edit_meta_security_row_services(cls, request: Request, query_db: AsyncSession, page_object: MetaSecurityRowModel):
        """
        Update an existing row rule with the fields set on *page_object*.

        :param request: FastAPI request (not used)
        :param query_db: ORM session
        :param page_object: rule carrying ``rowId`` and the fields to change
        :return: CrudResponseModel on success
        :raises ServiceException: if no rule with ``rowId`` exists
        """
        existing = await MetaSecurityDao.get_meta_security_row_by_id(query_db, page_object.rowId)
        if not existing:
            raise ServiceException(message=f'行配置{page_object.rowId}不存在')
        edit_row = page_object.model_dump(exclude_unset=True)
        try:
            await MetaSecurityDao.update_meta_security_row(query_db, edit_row)
            await query_db.commit()
            return CrudResponseModel(is_success=True, message='编辑行配置成功')
        except Exception:
            await query_db.rollback()
            raise

    @classmethod
    async def delete_meta_security_col_services(cls, request: Request, query_db: AsyncSession, page_object: DeleteMetaSecurityModel):
        """
        Delete the column rules whose ids are listed (comma-separated) in
        ``page_object.metaSecurity_ids``; ids that do not exist are skipped.

        :param request: FastAPI request (not used)
        :param query_db: ORM session
        :param page_object: carries the comma-separated ids
        :return: CrudResponseModel on success
        :raises ServiceException: if no ids were supplied
        """
        if not page_object.metaSecurity_ids:
            raise ServiceException(message='传入列配置ID为空')
        try:
            for col_id in page_object.metaSecurity_ids.split(','):
                if await MetaSecurityDao.get_meta_security_col_by_id(query_db, col_id):
                    await MetaSecurityDao.delete_meta_security_col(query_db, col_id)
            await query_db.commit()
            return CrudResponseModel(is_success=True, message='删除列配置成功')
        except Exception:
            await query_db.rollback()
            raise

    @classmethod
    async def delete_meta_security_row_services(cls, request: Request, query_db: AsyncSession, page_object: DeleteMetaSecurityModel):
        """
        Delete the row rules whose ids are listed (comma-separated) in
        ``page_object.metaSecurity_ids``; ids that do not exist are skipped.

        :param request: FastAPI request (not used)
        :param query_db: ORM session
        :param page_object: carries the comma-separated ids
        :return: CrudResponseModel on success
        :raises ServiceException: if no ids were supplied
        """
        if not page_object.metaSecurity_ids:
            raise ServiceException(message='传入行配置ID为空')
        try:
            for row_id in page_object.metaSecurity_ids.split(','):
                if await MetaSecurityDao.get_meta_security_row_by_id(query_db, row_id):
                    await MetaSecurityDao.delete_meta_security_row(query_db, row_id)
            await query_db.commit()
            return CrudResponseModel(is_success=True, message='删除行配置成功')
        except Exception:
            await query_db.rollback()
            raise

    @classmethod
    async def getMetaSercuitybysql(cls, request: Request, query_db: AsyncSession, page_object: MetaSecurityApiModel):
        """
        Run the caller's SQL against its data source with the caller's
        row/column security rules applied.

        Steps: authenticate, reject SQL containing write/DDL keywords, open a
        verified engine for the data source, optionally paginate, count the
        unrestricted result, then rewrite every ``schema.table`` into a
        filtered subquery and execute the rewritten SQL. Holders of the
        security-admin role get the unrestricted result instead.

        :return: dict with ``ctrlSql``, ``rawSqlRowCount``, ``data`` and either
            ``message`` (admin path) or ``tablesRowCol``
        :raises ServiceException: on authentication failure or forbidden SQL
        """
        # 1. Authenticate the caller.
        if not page_object.username:
            raise ServiceException(data='', message='用户名不能为空!')
        user = await login_by_account(query_db, page_object.username)
        if not user:
            raise ServiceException(data='', message='用户不存在')
        # NOTE(review): plain equality against the stored password; if stored
        # passwords are hashed this only matches when clients send the hash — confirm.
        if page_object.password != user[0].password:
            raise ServiceException(data='', message='用户密码错误!')
        if cls._FORBIDDEN_PATTERN.search(page_object.sqlStr):
            raise ServiceException(data='', message='SQL 中包含敏感词(UPDATE, DELETE, INSERT, DROP, ALTER, TRUNCATE),禁止执行!')
        query_user = await UserDao.get_user_by_id(query_db, user_id=user[0].user_id)
        role_id_list = [item.role_id for item in query_user.get('user_role_info')]

        # 2. Resolve the data source and open a verified engine.
        ds_resource = await get_data_source_tree(request, page_object)
        db_type = ds_resource["type"]
        engine = await cls.get_db_engine(db_type, ds_resource["connectionParams"])
        try:
            # 3. Tables referenced by the SQL.
            tables = await get_tables_from_sql(page_object.sqlStr)
            # SQL that already paginates itself must not get a second LIMIT.
            if has_pagination(page_object.sqlStr, db_type):
                page_object.isPage = False
            exec_sql = generate_pagination_sql(page_object, db_type) if page_object.isPage else page_object.sqlStr

            # 4. Row count of the unrestricted SQL.
            raw_row_count = await cls.execute_sql(engine, exec_sql, "原始结果", db_type)
            if cls._SECURITY_ADMIN_ROLE_ID in role_id_list:
                data = await cls.execute_sql(engine, exec_sql, "原始", db_type)
                return {
                    "ctrlSql": page_object.sqlStr,
                    "rawSqlRowCount": raw_row_count,
                    "data": data,
                    "message": "数据安全管理员权限",
                }

            # 5. Physical columns of every referenced table.
            table_columns = await cls.get_columns_from_tables(engine, tables, db_type)
            # 6. User- and role-level rules per table.
            tables_row_col = {}
            for table_name in tables:
                tables_row_col[table_name] = await get_table_configs(
                    query_db, page_object, user, role_id_list, table_name
                )
            # 7. Replace each table with its filtered subquery.
            ctrl_sql_dict = await generate_sql(tables_row_col, table_columns)
            new_sql = await replace_table_with_subquery(ctrl_sql_dict, exec_sql)
            # 8. Execute the controlled SQL.
            data = await cls.execute_sql(engine, new_sql, "控制后", db_type)
            return {
                "ctrlSql": new_sql,
                "rawSqlRowCount": raw_row_count,
                "data": data,
                "tablesRowCol": tables_row_col,
            }
        finally:
            # Every call builds its own engine; release its connection pool.
            await engine.dispose()

    @staticmethod
    async def get_db_engine(db_type: str, db_params: dict):
        """
        Create an async SQLAlchemy engine for a MySQL or PostgreSQL data source
        and verify it with ``SELECT 1`` before returning it.

        :param db_type: "mysql" or "postgresql" (case-insensitive)
        :param db_params: connection params with ``user``, ``password``,
            ``address`` (host:port, optionally prefixed with ``jdbc:mysql://``
            or ``jdbc:postgresql://``) and ``database``; the dict is not modified
        :return: a verified AsyncEngine; the caller must ``await engine.dispose()``
        :raises ConnectionError: on timeout or a SQLAlchemy connection failure
        :raises RuntimeError: on any other failure, including an unsupported db_type
        """
        from urllib.parse import quote

        engine = None
        try:
            address = db_params['address']
            for prefix in ("jdbc:mysql://", "jdbc:postgresql://"):
                if address.startswith(prefix):
                    address = address[len(prefix):]
                    break
            # Percent-encode credentials so "@", ":", "/" etc. cannot corrupt the URL.
            credentials = (
                f"{quote(str(db_params['user']), safe='')}:"
                f"{quote(str(db_params['password']), safe='')}"
            )
            kind = db_type.lower()
            if kind == "mysql":
                conn_str = f"mysql+aiomysql://{credentials}@{address}/{db_params['database']}"
                connect_args = {"connect_timeout": 5}
            elif kind == "postgresql":
                conn_str = f"postgresql+asyncpg://{credentials}@{address}/{db_params['database']}"
                connect_args = {"timeout": 5}
            else:
                raise ValueError("不支持的数据库类型")
            engine = create_async_engine(conn_str, pool_pre_ping=True, connect_args=connect_args)
            # Connect once now so an unreachable server fails fast here instead of on first use.
            async with engine.connect() as conn:
                await conn.execute(text("SELECT 1"))
            return engine
        except Exception as exc:
            if engine is not None:
                try:
                    await engine.dispose()
                except Exception:
                    pass  # best effort; the original failure is reported below
            if isinstance(exc, asyncio.TimeoutError):
                raise ConnectionError("数据库连接超时,请检查地址、端口或防火墙") from exc
            if isinstance(exc, SQLAlchemyError):
                raise ConnectionError(f"数据库连接失败: {exc}") from exc
            raise RuntimeError(f"连接过程中发生未知错误: {exc}") from exc

    @classmethod
    async def execute_sql(cls, engine, sql_query: str, sql_type: str, db_type: str, params: dict = None):
        """
        Execute *sql_query* on *engine* in a fresh session.

        When ``sql_type`` is "原始结果" the query is wrapped by build_count_sql and
        the scalar count is returned. Otherwise a row-returning statement yields
        a list of ``{column: value}`` dicts, and any other statement is
        committed and yields ``[]``. PostgreSQL sessions get a 30 s
        ``statement_timeout``.

        :param engine: AsyncEngine to run on
        :param sql_query: SQL text (``:name`` placeholders are bind parameters)
        :param sql_type: label used in error messages; "原始结果" selects count mode
        :param db_type: data-source type, e.g. "MYSQL" / "POSTGRESQL"
        :param params: optional values for the ``:name`` bind parameters
        :raises TimeoutError: PostgreSQL cancelled the statement (SQLSTATE 57014)
        :raises RuntimeError: on any other SQLAlchemy error
        """
        session_factory = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
        try:
            async with session_factory() as session:
                # statement_timeout is PostgreSQL-only; MySQL rejects it as an unknown variable.
                if str(db_type).upper() == "POSTGRESQL":
                    await session.execute(text("SET statement_timeout = 30000"))
                if sql_type == "原始结果":
                    count_sql = cls.build_count_sql(sql_query, db_type)
                    result = await session.execute(text(count_sql), params)
                    return result.scalar_one()
                result = await session.execute(text(sql_query), params)
                if result.returns_rows:
                    rows = result.fetchall()
                    columns = list(result.keys())
                    return [dict(zip(columns, row)) for row in rows]
                await session.commit()
                return []
        except DBAPIError as e:
            # SQLSTATE 57014 = query_canceled (raised when statement_timeout fires).
            sqlstate = getattr(getattr(e, "orig", None), "sqlstate", None)
            if sqlstate == "57014":
                raise TimeoutError(
                    "SQL 执行超过 30 秒已被数据库中断,请先查询数据量或使用分页。"
                ) from e
            raise RuntimeError(f"{sql_type} 执行 SQL 失败: {e}") from e
        except SQLAlchemyError as e:
            raise RuntimeError(f"{sql_type} 执行 SQL 失败: {e}") from e

    @classmethod
    def build_count_sql(cls, sql: str, db_type: str = "POSTGRESQL") -> str:
        """
        Wrap *sql* in ``SELECT COUNT(*)`` after stripping ORDER BY / LIMIT /
        OFFSET from its first SELECT found by ``find``, so the count covers the
        whole result rather than one page.

        :param sql: SELECT statement to count
        :param db_type: data-source type used to pick the sqlglot dialect
        :return: the counting SQL
        """
        dialect = db_type_to_sqlglot_dialect(db_type)
        parsed = parse_one(sql, read=dialect)
        select = parsed.find(exp.Select)
        if select:
            # OFFSET must go too: generate_pagination_sql adds it, and keeping it
            # would under-count every page after the first.
            for clause in ("order", "limit", "offset"):
                select.set(clause, None)
        cleaned_sql = parsed.sql(dialect=dialect)
        return f"SELECT COUNT(*) AS cnt FROM ({cleaned_sql}) t"

    @classmethod
    async def get_columns_from_tables(cls, dbConnent, table_names, db_type: str):
        """
        Look up the column names of each ``schema.table`` in information_schema.

        Schema and table names are passed as bind parameters (they come from
        user-supplied SQL). For MySQL the schema part is the database name.

        :param dbConnent: AsyncEngine of the data source
        :param table_names: iterable of "schema.table" strings
        :param db_type: "postgresql" or "mysql" (case-insensitive)
        :return: ``{table_name: [column, ...]}`` in ordinal order
        :raises ValueError: for any other db_type
        """
        queries = {
            "postgresql": (
                "SELECT column_name "
                "FROM information_schema.columns "
                "WHERE table_schema = :schema AND table_name = :table "
                "ORDER BY ordinal_position"
            ),
            # Alias to lower case: MySQL reports the column as COLUMN_NAME.
            "mysql": (
                "SELECT COLUMN_NAME AS column_name "
                "FROM INFORMATION_SCHEMA.COLUMNS "
                "WHERE TABLE_SCHEMA = :schema AND TABLE_NAME = :table "
                "ORDER BY ORDINAL_POSITION"
            ),
        }
        columns = {}
        for table_name in table_names:
            schema, table = table_name.split(".")
            query = queries.get(db_type.lower())
            if query is None:
                raise ValueError(f"暂不支持数据库类型: {db_type}")
            result = await cls.execute_sql(
                dbConnent, query, "字段查询", db_type, {"schema": schema, "table": table}
            )
            columns[table_name] = [row["column_name"] for row in result]
        return columns
def unquote_ident(name: str) -> str:
    """
    Strip one matching pair of surrounding quotes (", ' or `) from an identifier.

    Examples: '"public"' -> 'public', '`schema`' -> 'schema'. Anything without
    a matching pair (including '' and None) is returned unchanged.
    """
    if name and len(name) >= 2:
        opening, closing = name[0], name[-1]
        if opening == closing and opening in "\"'`":
            return name[1:-1]
    return name
def db_type_to_sqlglot_dialect(db_type: str) -> str:
    """
    Translate a data-source type name (case-insensitive) into a sqlglot dialect.

    :param db_type: MYSQL, POSTGRESQL, SQLSERVER or ORACLE
    :return: the sqlglot dialect name
    :raises ValueError: for any other type
    """
    key = db_type.upper()
    dialects = {
        "MYSQL": "mysql",
        "POSTGRESQL": "postgres",
        "SQLSERVER": "tsql",
        "ORACLE": "oracle",
    }
    if key not in dialects:
        raise ValueError(f"Unsupported db_type: {db_type}")
    return dialects[key]
def convert_decimal(obj):
    """
    Recursively replace Decimal values with floats inside dicts and lists.

    Dicts and lists are rebuilt (the input is not modified); every other
    value is returned as-is.
    """
    if isinstance(obj, list):
        return [convert_decimal(element) for element in obj]
    if isinstance(obj, dict):
        return {key: convert_decimal(val) for key, val in obj.items()}
    if isinstance(obj, Decimal):
        # float() may lose precision; str() would keep it
        return float(obj)
    return obj
async def get_table_configs(query_db, page_object, user, role_id_list, table_name):
    """
    Collect the user-level and role-level column/row rules for one table.

    :param query_db: ORM session
    :param page_object: request object; only ``dbRCode`` is read
    :param user: login_by_account result; ``user[0].user_id`` identifies the user
    :param role_id_list: ids of the user's roles
    :param table_name: "schema.table"
    :return: dict with user_col_list, role_col_list, role_row_list,
        user_row_list and isHave (True when any of those lists is non-empty)
    """
    schema, table = table_name.split(".")
    ds_code = page_object.dbRCode
    user_id = user[0].user_id

    # Subject type '0' is queried with the user id, '1' with each role id.
    user_col_list = await MetaSecurityDao.get_api_col_list(
        query_db, ds_code, schema, table, '0', user_id
    )
    role_col_list = []
    for role_id in role_id_list:
        role_col_list.extend(
            await MetaSecurityDao.get_api_col_list(query_db, ds_code, schema, table, '1', role_id)
        )

    user_row_list = await MetaSecurityDao.get_api_row_list(
        query_db, ds_code, schema, table, '0', user_id
    )
    role_row_list = []
    for role_id in role_id_list:
        role_row_list.extend(
            await MetaSecurityDao.get_api_row_list(query_db, ds_code, schema, table, '1', role_id)
        )

    has_any_rule = bool(user_col_list or role_col_list or user_row_list or role_row_list)
    return {
        "user_col_list": user_col_list,
        "role_col_list": role_col_list,
        "role_row_list": role_row_list,
        "user_row_list": user_row_list,
        "isHave": has_any_rule,
    }
# async def generate_sql(tablesRowCol:dict, table_columns:dict):
# sql_queries = {}
# # 1. 列控制
# # 遍历每个表
# no_configTable_name=""
# for table_name, table_configs in tablesRowCol.items():
# if not table_configs.get("isHave", False):
# no_configTable_name += table_name + ","
# if no_configTable_name:
# no_configTable_name = no_configTable_name.rstrip(',')
# raise ValueError(f"表:{no_configTable_name}均未配置行列数据安全")
# for table_name, config in tablesRowCol.items():
# # 获取该表的字段名
# columns = {col.lower(): col for col in table_columns[table_name]} # 将字段名转为小写
# # 初始化 SELECT 部分:用字典存储字段名,值是 null 字段名
# select_columns = {col: f"null as {col}" for col in columns}
# # 处理角色列配置
# for col in config["role_col_list"]:
# # If dbCName is "ALL", handle it as a special case
# if col.dbCName == "ALL":
# if col.ctrl_type == '0': # If ctrl_type is '0', prefix all columns with null
# for db_column in columns: # Assuming 'user' is the table name
# select_columns[db_column] = f"null as {db_column}" # 仍然保留 null 前缀
# elif col.ctrl_type == '1': # If ctrl_type is '1', use actual column names
# for db_column in columns:
# select_columns[db_column] = db_column # 使用实际字段名
# else:
# # Handle specific columns listed in dbCName
# db_columns = [db_column.strip().lower() for db_column in col.dbCName.split(",")]
# for db_column in db_columns:
# db_column = db_column.strip()
# if db_column in columns: # Check if the column exists in the table
# if col.ctrl_type == '0': # If ctrl_type is '0', prefix with null
# select_columns[db_column] = f"null as {db_column}" # 仍然保留 null 前缀
# elif col.ctrl_type == '1': # If ctrl_type is '1', use actual column name
# select_columns[db_column] = db_column # 使用实际字段名
# # 处理用户列配置
# for col in config["user_col_list"]:
# if col.dbCName == "ALL": # 如果 dbCName 为 "ALL"
# if col.ctrl_type == "0": # ctrlType 为 0,字符串字段
# for db_column in columns: # 对所有字段加上 null
# select_columns[db_column] = f"null as {db_column}" # 仍然保留 null 前缀
# elif col.ctrl_type == "1": # ctrlType 为 1,实际数据库字段
# for db_column in columns: # 使用实际字段名,不加 null
# select_columns[db_column] = db_column # 使用实际字段名
# else: # 处理 dbCName 不为 "ALL" 的情况
# db_columns = [db_column.strip().lower() for db_column in col.dbCName.split(",")]
# for db_column in db_columns:
# db_column = db_column.strip()
# if db_column in columns:
# if col.ctrl_type == "0":
# select_columns[db_column] = f"null as {db_column}" # 仍然保留 null 前缀
# elif col.ctrl_type == "1":
# select_columns[db_column] = db_column # 使用实际字段名
# # 生成 SQL 查询
# sql_queries[table_name] = f"SELECT {', '.join(select_columns.values())} FROM {table_name}"
# # 2.行控制
# select_rows={}
# # 处理角色行配置
# for row in config["role_row_list"]:
# # 仅仅对固定值有效,不加行限制
# if row.ctrl_value == "ALL" and row.ctrl_type == '0':
# # 控制方式 --固定值
# select_rows[row.dbCName] = ""
# else:
# if row.ctrl_type == '0':
# # row.ctrl_value 是逗号分隔的字符串时,改为 IN 语句
# if "," in row.ctrl_value:
# # 将 ctrl_value 按逗号分割,并用单引号包裹每个值
# values = [f"'{value.strip()}'" for value in row.ctrl_value.split(",")]
# select_rows[row.dbCName] = f"{row.dbCName} IN ({', '.join(values)})"
# else:
# select_rows[row.dbCName] = f"{row.dbCName} = '{row.ctrl_value}'"
# if row.ctrl_type == '1':
# tab_col_value=row.ctrl_value.split(".")
# if len(tab_col_value) != 2:
# raise RuntimeError(f"{row.dbCName}字段控制类型为表字段,未维护正确的值")
# select_rows[row.dbCName] = f"{row.dbCName} in (select {tab_col_value[1]} from {row.dbSName}.{tab_col_value[0]} where user_id = '1')"
# # 处理用户行配置
# for row in config["user_row_list"]:
# # 仅仅对固定值有效,不加行限制
# if row.ctrl_value == "ALL" and row.ctrl_type == '0':
# # 控制方式 --固定值
# select_rows[row.dbCName] = ""
# else:
# if row.ctrl_type == '0':
# # row.obj_value 是逗号分隔的字符串时,改为 IN 语句
# if "," in row.ctrl_value:
# # 将 obj_value 按逗号分割,并用单引号包裹每个值
# values = [f"'{value.strip()}'" for value in row.ctrl_value.split(",")]
# select_rows[row.dbCName] = f"{row.dbCName} IN ({', '.join(values)})"
# else:
# select_rows[row.dbCName] = f"{row.dbCName} = '{row.ctrl_value}'"
# if row.ctrl_type == '1':
# tab_col_value=row.ctrl_value.split(".")
# if len(tab_col_value) != 2:
# raise RuntimeError(f"{row.dbCName}字段控制类型为表字段,未维护正确的值")
# select_rows[row.dbCName] = f"{row.dbCName} in (select {tab_col_value[1]} from {row.dbSName}.{tab_col_value[0]} where user_id = '1')"
# if select_rows.values():
# where_conditions = " AND ".join(select_rows.values())
# if where_conditions:
# sql_queries[table_name] += " WHERE " + where_conditions
# else:
# sql_queries[table_name] += " WHERE 1 = 0"
# return sql_queries
def _sql_string_literal(value) -> str:
    """Render *value* as a single-quoted SQL string literal, doubling embedded quotes."""
    return "'" + str(value).replace("'", "''") + "'"


def _build_row_condition(row):
    """
    Translate one row rule into a WHERE predicate.

    :return: "__ALLOW_ALL__" for a fixed-value rule whose value is "ALL";
        a predicate for ctrl_type '0' (fixed, possibly comma-separated values)
        or '1' (lookup "table.column" inside schema ``row.dbSName``);
        None for any other ctrl_type
    :raises RuntimeError: if a ctrl_type '1' value is not "table.column"
    """
    if row.ctrl_type == '0':
        if row.ctrl_value == "ALL":
            return "__ALLOW_ALL__"
        if "," in row.ctrl_value:
            values = ", ".join(_sql_string_literal(v.strip()) for v in row.ctrl_value.split(","))
            return f"{row.dbCName} IN ({values})"
        return f"{row.dbCName} = {_sql_string_literal(row.ctrl_value)}"
    if row.ctrl_type == '1':
        tab_col = row.ctrl_value.split(".")
        if len(tab_col) != 2:
            raise RuntimeError(
                f"{row.dbCName} 字段控制类型为表字段,但未维护正确的值"
            )
        table, column = tab_col
        # NOTE(review): the lookup is hard-wired to user_id = '1' rather than the
        # requesting user — confirm this is intended.
        return (
            f"{row.dbCName} IN ("
            f"SELECT {column} FROM {row.dbSName}.{table} "
            f"WHERE user_id = '1')"
        )
    return None


async def generate_sql(tablesRowCol: dict, table_columns: dict):
    """
    Build a filtered ``SELECT ... FROM schema.table [WHERE ...]`` per table.

    Columns: a column is selected only when some rule makes it visible
    (ctrl_type '1') and none hides it (ctrl_type '0' wins). Hidden or
    unconfigured columns become ``null as <col>``.
    Rows: all produced predicates are AND-ed. With no predicates, an
    "ALL" fixed-value rule leaves the table unfiltered; otherwise
    ``WHERE 1 = 0`` denies every row.

    :param tablesRowCol: ``{table_name: get_table_configs(...) result}``
    :param table_columns: ``{table_name: [column, ...]}``
    :return: ``{table_name: sql}``
    :raises ValueError: if any table has no security rules at all
    """
    no_config_tables = [
        table_name
        for table_name, cfg in tablesRowCol.items()
        if not cfg.get("isHave", False)
    ]
    if no_config_tables:
        raise ValueError(f"您没有查看{','.join(no_config_tables)} 表记录的权限,请联系管理员配置相关行/列 数据安全策略")

    sql_queries = {}
    for table_name, config in tablesRowCol.items():
        # lower-cased column name -> original column name
        columns = {col.lower(): col for col in table_columns[table_name]}

        # Column visibility: '0' hidden, '1' visible, None unconfigured (hidden).
        visibility = dict.fromkeys(columns)
        for col_cfg in [*config.get("role_col_list", []), *config.get("user_col_list", [])]:
            if col_cfg.dbCName == "ALL":
                targets = list(columns)
            else:
                targets = [c.strip().lower() for c in col_cfg.dbCName.split(",")]
            for db_col in targets:
                if db_col not in columns:
                    continue
                if col_cfg.ctrl_type == '0':
                    visibility[db_col] = '0'
                elif col_cfg.ctrl_type == '1' and visibility[db_col] != '0':
                    visibility[db_col] = '1'

        select_list = ", ".join(
            col if visibility[col] == '1' else f"null as {col}" for col in columns
        )
        sql = f"SELECT {select_list} FROM {table_name}"

        # Row control: role rules first, then user rules.
        where_conditions = []
        allow_all_rows = False
        for row in [*config.get("role_row_list", []), *config.get("user_row_list", [])]:
            condition = _build_row_condition(row)
            if condition == "__ALLOW_ALL__":
                allow_all_rows = True
            elif condition:
                where_conditions.append(condition)

        if where_conditions:
            sql += " WHERE " + " AND ".join(where_conditions)
        elif not allow_all_rows:
            sql += " WHERE 1 = 0"
        sql_queries[table_name] = sql
    return sql_queries
# Words that, when they follow a table name, are SQL syntax rather than an alias.
_ALIAS_RESERVED_WORDS = frozenset({
    "SELECT", "INSERT", "UPDATE", "DELETE", "MERGE", "TRUNCATE",
    "VALUES", "RETURNING", "FROM", "WHERE", "GROUP", "HAVING", "ORDER",
    "LIMIT", "OFFSET", "DISTINCT", "ALL", "UNION", "INTERSECT", "EXCEPT",
    "JOIN", "INNER", "LEFT", "RIGHT", "FULL", "CROSS", "NATURAL", "USING", "ON",
    "TABLE", "VIEW", "INDEX", "PRIMARY", "KEY", "FOREIGN", "REFERENCES", "NOT",
    "NULL", "UNIQUE", "CHECK", "DEFAULT", "IF", "ELSE", "CASE", "WHEN", "THEN",
    "END", "LOOP", "FOR", "WHILE", "CREATE", "ALTER", "DROP", "COMMENT",
    "EXISTS", "IN", "IS", "LIKE", "ILIKE", "SIMILAR", "BETWEEN", "AND", "OR", "ANY",
    "SOME", "FETCH", "NEXT", "ONLY", "ASC", "DESC", "GRANT", "REVOKE", "ROLE",
    "USER", "CURRENT_DATE", "CURRENT_TIME", "CURRENT_TIMESTAMP",
})


async def replace_table_with_subquery(ctrSqlDict, oldStrSql):
    """
    Replace each ``schema.table`` after FROM/JOIN in *oldStrSql* with its
    controlled subquery, and rewrite ``schema.table.column`` references to
    ``alias.column``.

    If the SQL gives the table an alias, that alias is kept. Otherwise the
    bare table name becomes the alias (``(subquery) AS table``).

    :param ctrSqlDict: ``{"schema.table": subquery_sql}``
    :param oldStrSql: the caller's SQL
    :return: the rewritten SQL
    """
    table_alias_map = {}  # table as written in the SQL -> alias used for it
    for table_name, subquery in ctrSqlDict.items():
        # FROM/JOIN, optional extra qualifier, the table, then an optional alias.
        # (?!\w) stops "s.t" from also matching the prefix of "s.tt".
        from_join_pattern = (
            r'\b(FROM|JOIN)\s+'
            r'((?:[a-zA-Z_][a-zA-Z0-9_]*\.)?' + re.escape(table_name) + r')'
            r'(?!\w)'
            r'(\s+(?:AS\s+)?(\w+))?'
        )

        def from_join_replace(match):
            keyword, original_table, alias_part, alias_name = match.group(1, 2, 3, 4)
            has_real_alias = bool(alias_name) and alias_name.upper() not in _ALIAS_RESERVED_WORDS
            if has_real_alias and original_table in ctrSqlDict:
                return f"{keyword} ({ctrSqlDict[original_table]}) {alias_part}"
            if has_real_alias:
                table_alias_map[original_table] = alias_name
                return f"{keyword} ({subquery}) {alias_part}"
            # No alias (the captured word, if any, is a keyword): alias by table name.
            alias = original_table.split('.')[-1]
            table_alias_map[original_table] = alias
            return f"{keyword} ({subquery}) AS {alias}{alias_part or ''}"

        oldStrSql = re.sub(from_join_pattern, from_join_replace, oldStrSql, flags=re.IGNORECASE)

        # schema.table.column -> alias.column; (?<!\w) avoids matching inside longer identifiers.
        alias = table_alias_map.get(table_name, table_name.split('.')[-1])
        column_ref_pattern = r'(?<!\w)' + re.escape(table_name) + r'\.(\w+)'
        oldStrSql = re.sub(
            column_ref_pattern,
            lambda m: f"{alias}.{m.group(1)}",
            oldStrSql,
            flags=re.IGNORECASE,
        )
    return oldStrSql
async def get_data_source_tree(request: Request, current_user: MetaSecurityApiModel, timeout: float = 30.0):
    """
    Fetch the data source named ``current_user.dbRCode`` from the
    DolphinScheduler ``withpwdlist`` endpoint, authenticating with the
    caller's credentials.

    :param request: FastAPI request (not used)
    :param current_user: provides ``username``, ``password`` and ``dbRCode``
    :param timeout: HTTP timeout in seconds
    :return: the matching data-source item with ``connectionParams`` parsed into a dict
    :raises Exception: on a non-200 response or when no data source matches
    """
    # NOTE(review): only the first 100 data sources are requested.
    url = f'{AppConfig.ds_server_url}/dolphinscheduler/datasources/withpwdlist?pageNo=1&pageSize=100'
    headers = {'dashUserName': current_user.username, 'dashPassword': current_user.password}
    loop = asyncio.get_running_loop()
    # requests is blocking: run it in the default executor so the event loop is not stalled.
    # NOTE(review): verify=False disables TLS certificate verification.
    response = await loop.run_in_executor(
        None, lambda: requests.get(url, headers=headers, verify=False, timeout=timeout)
    )
    if response.status_code != 200:
        raise Exception(f'根据数据源ID:{current_user.dbRCode}获取数据源信息失败,状态: {response.reason}')
    data = json.loads(response.text)
    # "data" may be null when the server reports an error in the body.
    total_list = (data.get("data") or {}).get("totalList") or []
    for item in total_list:
        if item["name"] == current_user.dbRCode:
            if isinstance(item["connectionParams"], str):
                item["connectionParams"] = json.loads(item["connectionParams"])
            return item
    raise Exception(f'根据数据源ID:{current_user.dbRCode}获取数据源信息失败,状态: {response.reason}')
async def test_connection(db_content):
    """
    Probe an async engine by running ``SELECT 1`` on a fresh connection.

    :param db_content: AsyncEngine to test
    :raises Exception: "数据源连接失败", chained to the underlying error
    """
    try:
        async with db_content.connect() as connection:
            # SQLAlchemy only executes Executable objects; a bare str raises
            # ObjectNotExecutableError, so wrap it in text().
            await connection.scalar(text("SELECT 1"))
    except Exception as e:
        raise Exception("数据源连接失败") from e
def has_pagination(sql: str, db_type: str) -> bool:
    """
    Heuristically detect whether *sql* already paginates its result.

    :param sql: SQL text
    :param db_type: MYSQL / POSTGRESQL (looks for LIMIT), SQLSERVER (OFFSET+FETCH
        or ROW_NUMBER), ORACLE (ROWNUM); case-insensitive
    :return: True if a pagination construct is present; False for unknown types
    """
    sql_upper = sql.upper()
    # Normalise like generate_pagination_sql does, so "mysql" and "MYSQL" agree.
    kind = (db_type or "").upper()
    if kind in ("MYSQL", "POSTGRESQL"):
        return re.search(r"\bLIMIT\b", sql_upper) is not None
    if kind == "SQLSERVER":
        return ("OFFSET" in sql_upper and "FETCH" in sql_upper) or "ROW_NUMBER" in sql_upper
    if kind == "ORACLE":
        return "ROWNUM" in sql_upper
    return False
def generate_pagination_sql(page_object: "MetaSecurityApiModel", db_type: str) -> str:
    """
    Append dialect-specific pagination to ``page_object.sqlStr``.

    :param page_object: exposes ``sqlStr``, ``pageNum`` and ``pageSize``; a
        missing or non-positive page number defaults to 1, and a missing or
        non-positive page size defaults to 10
    :param db_type: MYSQL / POSTGRESQL / SQLSERVER / ORACLE (case-insensitive)
    :return: the paginated SQL
    :raises ValueError: for any other database type
    """
    page_num = page_object.pageNum if page_object.pageNum and page_object.pageNum > 0 else 1
    page_size = page_object.pageSize if page_object.pageSize and page_object.pageSize > 0 else 10
    offset = (page_num - 1) * page_size  # rows to skip
    # A trailing ";" would leave the appended clause after the statement terminator.
    base_sql = page_object.sqlStr.rstrip(" \t\r\n\f\v;")
    db_type = db_type.upper()
    if db_type in ("MYSQL", "POSTGRESQL"):
        return f"{base_sql} LIMIT {page_size} OFFSET {offset}"
    if db_type == "SQLSERVER":
        # NOTE(review): assumes an ``id`` column and no existing ORDER BY.
        return f"{base_sql} ORDER BY id OFFSET {offset} ROWS FETCH NEXT {page_size} ROWS ONLY"
    if db_type == "ORACLE":
        # NOTE(review): same ``id`` / ORDER BY assumption as SQL Server.
        return f"""
SELECT * FROM (
    SELECT a.*, ROWNUM rnum FROM (
        {base_sql} ORDER BY id
    ) a WHERE ROWNUM <= {offset + page_size}
) WHERE rnum > {offset}
"""
    raise ValueError(f"不支持的数据库类型: {db_type}")
def _extract_identifiers(token):
    """
    Return "schema.table" for a schema-qualified sqlparse Identifier.

    Any other token — including an IdentifierList, or an Identifier without a
    parent (schema) name — yields None.
    """
    if not isinstance(token, Identifier):
        return None
    table = token.get_real_name()
    schema = token.get_parent_name()
    return f"{schema}.{table}" if table and schema else None
# async def get_tables_from_sql(sql_query: str):
# """
# 使用 sqlparse 解析 SQL 并返回 schema.table 列表(去重)
# 支持嵌套子查询、函数、别名、JOIN、INTO、UPDATE 等。
# 只返回包含 schema 的标识符(即有点号的)。
# """
# parsed = sqlparse.parse(sql_query)
# tables = set()
# for stmt in parsed:
# # 遍历语句的 token 树,寻找顶层的 FROM/JOIN/INTO/UPDATE 子句
# for token in stmt.tokens:
# # 忽略函数、子查询整体(它们会在自己的 stmt 中被处理)
# # 但我们需要遍历整个树以捕获顶层的 Identifier/IdentifierList
# if token.is_group:
# # 递归遍历 group 内的 token
# for t in token.flatten():
# # 跳过在函数内部的 FROM(例如 EXTRACT(... FROM ...))
# # 方法:判断最近的父级 group 是否为 Function(我们用了 flatten,故这里检查 parent types is hard)
# # 简化办法:如果 token 的上层类型是 Function 的一部分,skip(handled by checking surrounding tokens)
# pass
# # 更好的做法:直接按 sqlparse 提供的机制遍历并查找 Identifier/IdentifierList
# for token in stmt.tokens:
# if token.ttype is DML and token.normalized.upper() in ("UPDATE",):
# # UPDATE table_name ...
# # 下一个非空白 token 往往是 Identifier
# nxt = stmt.token_next(stmt.token_index(token), skip_ws=True, skip_cm=True)
# if nxt:
# name = _extract_identifiers(nxt[1]) if isinstance(nxt[1], (Identifier, IdentifierList)) else None
# if name:
# tables.add(name)
# # 使用遍历获取所有 Identifier / IdentifierList 出现在 FROM 或 JOIN 后面的情况
# # 这里遍历 token 序列并在遇到 FROM/JOIN/INTO/UPDATE 时提取后续 identifier
# idx = 0
# tokens = list(stmt.tokens)
# while idx < len(tokens):
# t = tokens[idx]
# if t.is_whitespace:
# idx += 1
# continue
# if t.ttype is Keyword and t.normalized.upper() in ("FROM", "JOIN", "INTO"):
# # 找下一个有意义的 token(可能是 Identifier 或 IdentifierList 或 Parenthesis 表示子查询)
# nxt = stmt.token_next(idx, skip_ws=True, skip_cm=True)
# if nxt:
# tok = nxt[1]
# # 如果是 parenthesis -> 子查询,跳过
# if tok.is_group and isinstance(tok, Function):
# # 函数内的 FROM(如 EXTRACT(...))会被解析为 Function 的一部分 —— 跳过
# pass
# else:
# # 处理 Identifier 或 IdentifierList
# if isinstance(tok, Identifier):
# name = _extract_identifiers(tok)
# if name:
# tables.add(name)
# elif isinstance(tok, IdentifierList):
# for ident in tok.get_identifiers():
# name = _extract_identifiers(ident)
# if name:
# tables.add(name)
# else:
# # 可能是直接的 Name token 'schema.table'(未被识别为 Identifier)
# txt = tok.value
# if "." in txt:
# parts = txt.strip().strip('`"\'').split(".")
# if len(parts) == 2:
# tables.add(f"{parts[0]}.{parts[1]}")
# idx += 1
# continue
# idx += 1
# if not tables:
# raise ValueError("SQL 解析失败,未找到任何 schema.table 结构")
# return sorted(tables)
async def get_tables_from_sql(sql_query: str):
    """
    Parse *sql_query* with sqlglot and return the sorted, de-duplicated
    ``schema.table`` names it references, including those in subqueries and
    JOINs. Tables without a schema qualifier are ignored.

    :param sql_query: one or more SQL statements
    :return: sorted list of "schema.table" strings
    :raises ValueError: if parsing fails or no schema-qualified table is found
    """
    try:
        statements = sqlglot.parse(sql_query)
    except Exception as e:
        raise ValueError(f"SQL 解析失败: {e}") from e

    tables: Set[str] = set()
    for stmt in statements:
        # sqlglot.parse may yield None for empty statements (e.g. ";;").
        if stmt is None:
            continue
        # find_all visits every node of the tree, replacing a hand-written recursion.
        for table in stmt.find_all(Table):
            if table.args.get("db") and table.args.get("this"):
                tables.add(f"{table.args['db']}.{table.args['this']}")

    if not tables:
        raise ValueError("SQL 解析失败,未找到任何 schema.table 结构")
    return sorted(tables)