You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
416 lines
16 KiB
416 lines
16 KiB
from utils.request_xinf import api_request_xinf
|
|
from utils.request import api_request
|
|
import json
|
|
import re
|
|
from collections import defaultdict
|
|
|
|
|
|
def get_model_tree_data(extracted_data):
    """Build the tree structure for the front-end model browser.

    Groups flat model records by ``model_type`` and then by
    ``model_family``, producing one top-level node per type whose
    children are the family nodes (keyed ``"<type>/<family>"``).
    """
    # Group records as {model_type: {model_family: [records]}}.
    grouped = {}
    for record in extracted_data:
        families = grouped.setdefault(record['model_type'], {})
        families.setdefault(record['model_family'], []).append(record)

    # Assemble the nested node dicts in insertion order.
    tree_data = []
    for type_name, families in grouped.items():
        tree_data.append({
            'title': type_name,
            'key': type_name,
            'children': [
                {'title': family_name, 'key': f"{type_name}/{family_name}"}
                for family_name in families
            ],
        })
    return tree_data
def get_model_flag(item, element):
    """Extract the category-specific display attributes for a model record.

    The returned dict's keys depend on the model category *element*;
    missing source fields surface as ``None`` values.  Audio models
    only expose a language flag (and only when multilingual); unknown
    categories yield an empty dict.
    """
    # Audio is the one category with conditional, not positional, output.
    if element == "audio":
        return {'model_lang': "mul"} if item.get('multilingual') is True else {}

    # (output_key, source_key) pairs per category.
    field_map = {
        "LLM": (
            ('model_lang', 'model_lang'),
            ('model_ability', 'model_ability'),
            ('context_length', 'context_length'),
        ),
        "embedding": (
            ('model_lang', 'language'),
            ('model_dims', 'dimensions'),
            ('model_mtks', 'max_tokens'),
        ),
        "rerank": (
            ('model_lang', 'language'),
            ('model_tp', 'type'),
        ),
        "image": (
            ('model_ability', 'ability'),
            ('model_tp', 'type'),
        ),
    }
    pairs = field_map.get(element)
    if pairs is None:
        return {}
    return {out_key: item.get(src_key) for out_key, src_key in pairs}
def clean_model_family(family):
    """Normalize a raw model-family string to a canonical lowercase name.

    Keeps only the first path segment, strips a trailing version number
    (e.g. ``"Qwen2"`` -> ``"qwen"``), lowercases, then maps known
    aliases onto their canonical family name.
    """
    family = family.split('/')[0]
    family = re.sub(r'\d+$', '', family)  # drop trailing version digits
    family = family.lower()

    # Map vendor / variant aliases onto one canonical family name.
    # The branches are mutually exclusive, so elif is equivalent to
    # the original sequence of independent ifs.
    if family in {'glm', 'codegeex', 'chatglm'}:
        family = 'zhipu'
    elif family == 'starcoderplus':
        family = 'starcoder'
    elif family in {'wizardlm', 'wizardmath', 'wizardcoder'}:
        family = 'wizard'

    return family
def extract_data(element, data):
    """Flatten raw model registrations into uniform display records.

    For each raw item, derives a cleaned ``model_family`` (from the
    explicit field, the ``model_id`` path prefix, or the first
    ``model_name`` token), collects the category-specific flags via
    :func:`get_model_flag`, and returns a list of record dicts with
    ``model_type`` set to *element*.
    """
    extracted_data = []
    for item in data:
        model_flag = get_model_flag(item, element)

        # Prefer the explicit family; otherwise fall back to the
        # model_id path prefix, then to the first model_name token.
        model_family = item.get('model_family')
        if not model_family:
            model_id = item.get('model_id', '')
            model_name = item.get('model_name', '')
            if model_id:
                # Bug fix: was split('\/') — an invalid escape sequence;
                # '/' is the intended separator.
                model_family = model_id.split('/')[0]
            elif model_name:
                model_family = re.split('[-_]', model_name)[0]

        if model_family:
            model_family = clean_model_family(model_family)
        else:
            # Guard: without this, a record lacking family/id/name left
            # model_family as None and crashed on the trimming below.
            model_family = ''

        # Well-known names override whatever the heuristics produced.
        model_name = item.get('model_name', '').lower()
        if 'llama' in model_name:
            model_family = 'llama'
        elif 'qwen' in model_name:
            model_family = 'qwen'

        # Strip any trailing run of digits/punctuation left after
        # cleaning (equivalent to the original character-by-character
        # loop over isdigit() / [^\w\s]).
        model_family = re.sub(r'(?:\d|[^\w\s])+$', '', model_family)

        # Fall back to model_id when no description is provided.
        model_description = item.get('model_description', '') or item.get('model_id', '')

        extracted_data.append({
            'model_type': element,
            'model_family': model_family,
            'model_name': item.get('model_name'),
            'model_flag': model_flag,
            'model_description': model_description,
        })
    return extracted_data
def get_models_detail_api(model_type_list: list):
    """Fetch detailed registrations for the requested model types.

    Unknown type names are skipped.  Returns a ``(tree_data,
    extracted_data)`` pair: the grouped tree for the UI plus the flat
    per-model records.
    """
    supported = ("LLM", "embedding", "rerank", "image", "audio")
    extracted_data = []
    for element in model_type_list:
        if element not in supported:
            continue
        data = api_request_xinf(
            method='get',
            url=f'model_registrations/{element}?detailed=true',
            is_headers=False,
        )
        extracted_data.extend(extract_data(element, data))

    tree_data = get_model_tree_data(extracted_data)
    return tree_data, extracted_data
def get_model_info_api(model_name: str):
    """Assemble the info card for *model_name*.

    The prompt template is looked up from xinference's
    ``models/prompts`` endpoint; the remaining fields are currently
    hard-coded placeholder metadata.
    """
    # Fetch the full prompt catalogue and pick this model's entry.
    re_data = api_request_xinf(method='get', url='models/prompts', is_headers=False)
    print("model_name信息",model_name)
    prompt_data = re_data.get(model_name, "模型prompt不存在")
    print("prompt_data1",prompt_data)

    # NOTE(review): everything except model_name_en / model_prmopt is a
    # hard-coded placeholder pending the backend service integration.
    model_info = {
        "model_name_cn": "通义千问2-开源版",
        "model_name_en": model_name,
        "model_note": "通义千问2对外开源的模型,模型支持 131,072 tokens上下文,为了保障正常使用和正常输出,API限定用户输入为 128,000 ,输出最大 6,144。",
        "model_fune_flag": "Y",
        "model_is_open_source": "是",
        "model_release_date": "2024-08-08",
        "model_website_url": "https://modelscope.cn/organization/qwen?tab=model",
        "model_prmopt":str(prompt_data)
    }
    return model_info
def get_model_info_extend_api(model_name: str):
    """Return extended model info: sample client code plus metadata.

    Builds three hard-coded code samples (curl / Python SDK / Java SDK)
    showing how to call the DashScope text-generation API, and bundles
    them with placeholder metadata for *model_name*.

    NOTE(review): the samples target the fixed "qwen-max" model (the
    Java callback sample uses a "${modelCode}" template placeholder) —
    they are not derived from *model_name*.
    """
    # Raw curl invocation of the DashScope generation endpoint (SSE enabled).
    curl_code = '''
curl --location 'https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation' \
--header 'Authorization: Bearer <your-dashscope-api-key>' \
--header 'Content-Type: application/json' \
--header 'X-DashScope-SSE: enable' \
--data '{
    "model": "qwen-max",
    "input":{
        "messages":[
            {
                "role": "system",
                "content": "You are a helpful assistant."
            },
            {
                "role": "user",
                "content": "Introduce the capital of China"
            }
        ]
    },
    "parameters": {
    }
}'''
    # Python SDK sample: streaming call with incremental message output.
    python_code = '''#
from http import HTTPStatus
import dashscope


def call_with_stream():
    messages = [
        {'role': 'user', 'content': 'Introduce the capital of China'}]
    responses = dashscope.Generation.call("qwen-max",
                                          messages=messages,
                                          result_format='message',  # set the result to be "message" format.
                                          stream=True,  # set streaming output
                                          incremental_output=True  # get streaming output incrementally
                                          )
    for response in responses:
        if response.status_code == HTTPStatus.OK:
            print(response.output.choices[0]['message']['content'],end='')
        else:
            print('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
                response.request_id, response.status_code,
                response.code, response.message
            ))


if __name__ == '__main__':
    call_with_stream()'''
    # Java SDK sample: blocking-stream variant plus callback-based variant.
    java_code = '''
// Copyright (c) Alibaba, Inc. and its affiliates.

import java.util.Arrays;
import java.util.concurrent.Semaphore;
import com.alibaba.dashscope.aigc.generation.Generation;
import com.alibaba.dashscope.aigc.generation.GenerationParam;
import com.alibaba.dashscope.aigc.generation.GenerationResult;
import com.alibaba.dashscope.common.Message;
import com.alibaba.dashscope.common.ResultCallback;
import com.alibaba.dashscope.common.Role;
import com.alibaba.dashscope.exception.ApiException;
import com.alibaba.dashscope.exception.InputRequiredException;
import com.alibaba.dashscope.exception.NoApiKeyException;
import com.alibaba.dashscope.protocol.Protocol;
import com.alibaba.dashscope.utils.JsonUtils;
import io.reactivex.Flowable;

public class Main {
    public static void streamCallWithMessage()
        throws NoApiKeyException, ApiException, InputRequiredException {
        Generation gen = new Generation();
        Message userMsg =
            Message.builder().role(Role.USER.getValue()).content("Introduce the capital of China").build();
        GenerationParam param = GenerationParam.builder()
            .model("qwen-max")
            .messages(Arrays.asList(userMsg))
            .resultFormat(GenerationParam.ResultFormat.MESSAGE) // the result if message format.
            .topP(0.8).enableSearch(true) // set streaming output
            .incrementalOutput(true) // get streaming output incrementally
            .build();
        Flowable<GenerationResult> result = gen.streamCall(param);
        StringBuilder fullContent = new StringBuilder();
        result.blockingForEach(message -> {
            fullContent.append(message.getOutput().getChoices().get(0).getMessage().getContent());
            System.out.println(JsonUtils.toJson(message));
        });
        System.out.println("Full content: \n" + fullContent.toString());
    }

    public static void streamCallWithCallback()
        throws NoApiKeyException, ApiException, InputRequiredException, InterruptedException {
        Generation gen = new Generation();
        Message userMsg =
            Message.builder().role(Role.USER.getValue()).content("Introduce the capital of China").build();
        GenerationParam param = GenerationParam.builder()
            .model("${modelCode}")
            .resultFormat(GenerationParam.ResultFormat.MESSAGE) //set result format message
            .messages(Arrays.asList(userMsg)) // set messages
            .topP(0.8)
            .incrementalOutput(true) // set streaming output incrementally
            .build();
        Semaphore semaphore = new Semaphore(0);
        StringBuilder fullContent = new StringBuilder();
        gen.streamCall(param, new ResultCallback<GenerationResult>() {

            @Override
            public void onEvent(GenerationResult message) {
                fullContent
                    .append(message.getOutput().getChoices().get(0).getMessage().getContent());
                System.out.println(message);
            }

            @Override
            public void onError(Exception err) {
                System.out.println(String.format("Exception: %s", err.getMessage()));
                semaphore.release();
            }

            @Override
            public void onComplete() {
                System.out.println("Completed");
                semaphore.release();
            }

        });
        semaphore.acquire();
        System.out.println("Full content: \n" + fullContent.toString());
    }

    public static void main(String[] args) {
        try {
            streamCallWithMessage();
        } catch (ApiException | NoApiKeyException | InputRequiredException e) {
            System.out.println(e.getMessage());
        }
        try {
            streamCallWithCallback();
        } catch (ApiException | NoApiKeyException | InputRequiredException
            | InterruptedException e) {
            System.out.println(e.getMessage());
        }
        System.exit(0);
    }
}
'''
    # Placeholder metadata; model_oneid / model_update_date are hard-coded.
    mode_info_extend = {
        # "model_name_cn": "通义千问2-开源版",
        "model_oneid":"001",
        "model_name_en": model_name,
        "curl_code": curl_code,
        "python_code": python_code,
        "java_code": java_code,
        "model_update_date": "2024-08-08",
    }
    return mode_info_extend
    # return api_request(method='get', url=f'/modmag/model/{model_name}', is_headers=True)
# Get most of the model_info from the backend service.
def get_model_engine_api(model_name: str):
    """Query xinference for the inference engines available for *model_name*."""
    engines = api_request_xinf(
        method='get', url=f'engines/{model_name}', is_headers=False)
    print("model_name信息",model_name)
    return engines
# Get most of the cluster info from the backend service.
def get_cluster_info_api():
    """Fetch detailed cluster information from xinference."""
    cluster_info = api_request_xinf(
        method='get', url='cluster/info?detailed=true', is_headers=False)
    print("cluster信息",cluster_info)
    return cluster_info
def launch_model_api(params: dict):
    """Launch a model instance via xinference's ``models`` endpoint.

    *params* is forwarded as the JSON request body; returns the raw
    API response.
    """
    # Renamed from ``re``: the old local shadowed the module-level
    # ``re`` (regex) import within this function.
    response = api_request_xinf(method='post', url='models', is_headers=True, json=params)
    return response
def get_model_run_api():
    """List running model instances grouped by model type.

    Queries ``models/instances`` for running UIDs, fetches each
    instance's detail, and returns three lists of summary dicts:
    ``(llm, embedding, rerank)``.  Instances of other model types are
    ignored.
    """
    data = api_request_xinf(method='get', url='models/instances', is_headers=False)
    print("data:", data)

    mode_llm_run_data = []
    mode_embed_run_data = []
    mode_rerank_run_data = []
    for item in data:
        print(item['model_uid'])
        detail_data = api_request_xinf(
            method='get', url=f'models/{item["model_uid"]}', is_headers=False)
        print("detail_data为:", detail_data)

        # Fields shared by every model type; was duplicated across the
        # three original branches.
        entry = {
            'model_uid': item["model_uid"],
            'model_name': detail_data["model_name"],
            'worker_address': detail_data["address"],
            'GPU_indexs': str(detail_data["accelerators"]),
        }
        model_type = detail_data["model_type"]
        if model_type == "LLM":
            # LLMs additionally report size and quantization.
            entry['model_size'] = detail_data["model_size_in_billions"]
            entry['model_quant'] = detail_data["quantization"]
            entry['model_replica'] = detail_data["replica"]
            mode_llm_run_data.append(entry)
        elif model_type == "embedding":
            entry['model_replica'] = detail_data["replica"]
            mode_embed_run_data.append(entry)
        elif model_type == "rerank":
            entry['model_replica'] = detail_data["replica"]
            mode_rerank_run_data.append(entry)

    print("mode_llm_run_data:", mode_llm_run_data)
    print("mode_embed_run_data:", mode_embed_run_data)
    print("mode_rerank_run_data:", mode_rerank_run_data)

    return mode_llm_run_data, mode_embed_run_data, mode_rerank_run_data