|
|
|
@ -23,6 +23,7 @@ from sqlglot.expressions import Table |
|
|
|
from sqlglot import exp ,parse_one |
|
|
|
from typing import Set |
|
|
|
from sqlparse.tokens import Keyword, DML |
|
|
|
from urllib.parse import quote_plus |
|
|
|
class MetaSecurityService: |
|
|
|
""" |
|
|
|
数据源安全管理模块服务层 |
|
|
|
@ -379,7 +380,9 @@ class MetaSecurityService: |
|
|
|
# 1️⃣ 去掉 jdbc 前缀 |
|
|
|
jdbc_prefixes = { |
|
|
|
"jdbc:mysql://": len("jdbc:mysql://"), |
|
|
|
"jdbc:postgresql://": len("jdbc:postgresql://") |
|
|
|
"jdbc:postgresql://": len("jdbc:postgresql://"), |
|
|
|
"jdbc:oracle:thin:@//": len("jdbc:oracle:thin:@//"), |
|
|
|
"jdbc:oracle:thin:@": len("jdbc:oracle:thin:@") |
|
|
|
} |
|
|
|
for prefix, length in jdbc_prefixes.items(): |
|
|
|
if address.startswith(prefix): |
|
|
|
@ -410,6 +413,21 @@ class MetaSecurityService: |
|
|
|
pool_pre_ping=True, |
|
|
|
connect_args={"timeout": 5} # ⭐ 关键 |
|
|
|
) |
|
|
|
elif db_type.lower() == "oracle": |
|
|
|
address = db_params["address"].lstrip("/") |
|
|
|
user = quote_plus(db_params["user"]) |
|
|
|
password = quote_plus(db_params["password"]) |
|
|
|
connect_type = (db_params.get("connectType") or "").upper() |
|
|
|
service_or_sid = quote_plus(db_params.get("database", "")) |
|
|
|
if connect_type == "ORACLE_SID": |
|
|
|
conn_str = f"oracle+oracledb://{user}:{password}@{address}/?sid={service_or_sid}" |
|
|
|
else: |
|
|
|
conn_str = f"oracle+oracledb://{user}:{password}@{address}/?service_name={service_or_sid}" |
|
|
|
engine = create_async_engine( |
|
|
|
conn_str, |
|
|
|
pool_pre_ping=True, |
|
|
|
connect_args={"transport_connect_timeout": 5} |
|
|
|
) |
|
|
|
|
|
|
|
else: |
|
|
|
raise ValueError("不支持的数据库类型") |
|
|
|
@ -438,6 +456,7 @@ class MetaSecurityService: |
|
|
|
|
|
|
|
try: |
|
|
|
async with async_session() as session: |
|
|
|
if (db_type or "").upper() == "POSTGRESQL": |
|
|
|
await session.execute(text("SET statement_timeout = 30000")) |
|
|
|
# ⭐ 原始数量 |
|
|
|
if sql_type == "原始结果": |
|
|
|
@ -474,6 +493,7 @@ class MetaSecurityService: |
|
|
|
if select: |
|
|
|
select.set("order", None) |
|
|
|
select.set("limit", None) |
|
|
|
select.set("offset", None) |
|
|
|
|
|
|
|
cleaned_sql = parsed.sql(dialect=dialect) |
|
|
|
return f"SELECT COUNT(*) AS cnt FROM ({cleaned_sql}) t" |
|
|
|
@ -495,10 +515,18 @@ class MetaSecurityService: |
|
|
|
elif db_type.lower() == "mysql": |
|
|
|
# MySQL: 直接查询表字段(MySQL 没有 schema 的概念) |
|
|
|
query = f""" |
|
|
|
SELECT COLUMN_NAME |
|
|
|
SELECT COLUMN_NAME AS column_name |
|
|
|
FROM INFORMATION_SCHEMA.COLUMNS |
|
|
|
WHERE TABLE_NAME = '{table}' |
|
|
|
""" |
|
|
|
elif db_type.lower() == "oracle": |
|
|
|
query = f""" |
|
|
|
SELECT COLUMN_NAME AS "column_name" |
|
|
|
FROM ALL_TAB_COLUMNS |
|
|
|
WHERE OWNER = UPPER('{schema}') |
|
|
|
AND TABLE_NAME = UPPER('{table}') |
|
|
|
ORDER BY COLUMN_ID |
|
|
|
""" |
|
|
|
else: |
|
|
|
raise ValueError(f"暂不支持数据库类型: {db_type}") |
|
|
|
|
|
|
|
|