You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

416 lines
16 KiB

1 month ago
from utils.request_xinf import api_request_xinf
from utils.request import api_request
import json
import re
from collections import defaultdict
def get_model_tree_data(extracted_data):
    """Build a two-level tree (model_type -> model_family) for the frontend.

    Each top-level node is one model type; its children are the distinct
    model families seen for that type, keyed as "<type>/<family>".
    """
    # Bucket the flat records by type first, then by family.
    buckets = defaultdict(lambda: defaultdict(list))
    for record in extracted_data:
        buckets[record['model_type']][record['model_family']].append(record)

    tree_data = []
    for type_name, families in buckets.items():
        family_nodes = [
            {'title': family_name, 'key': f"{type_name}/{family_name}"}
            for family_name in families
        ]
        tree_data.append({
            'title': type_name,
            'key': type_name,
            'children': family_nodes,
        })
    return tree_data
def get_model_flag(item, element):
    """Extract the type-specific display attributes ("flags") for one model.

    The returned keys depend on *element* (the model type); missing source
    fields come back as None.  Unknown model types yield an empty dict.
    """
    # output key -> source key in the raw item, per model type
    field_map = {
        'LLM': {
            'model_lang': 'model_lang',
            'model_ability': 'model_ability',
            'context_length': 'context_length',
        },
        'embedding': {
            'model_lang': 'language',
            'model_dims': 'dimensions',
            'model_mtks': 'max_tokens',
        },
        'rerank': {
            'model_lang': 'language',
            'model_tp': 'type',
        },
        'image': {
            'model_ability': 'ability',
            'model_tp': 'type',
        },
    }
    if element in field_map:
        return {out_key: item.get(src_key)
                for out_key, src_key in field_map[element].items()}
    if element == 'audio':
        # Audio models only carry a language flag, and only when multilingual.
        return {'model_lang': "mul"} if item.get('multilingual') is True else {}
    return {}
def clean_model_family(family):
    """Normalize a raw model-family string to a canonical lowercase name.

    Keeps only the first path segment, drops a trailing version number,
    lowercases, then collapses known brand aliases onto one family name
    (e.g. glm/codegeex/chatglm -> zhipu).
    """
    base = family.split('/')[0]
    base = re.sub(r'\d+$', '', base).lower()
    # Known aliases that should be grouped under a single family.
    aliases = {
        'glm': 'zhipu',
        'codegeex': 'zhipu',
        'chatglm': 'zhipu',
        'starcoderplus': 'starcoder',
        'wizardlm': 'wizard',
        'wizardmath': 'wizard',
        'wizardcoder': 'wizard',
    }
    return aliases.get(base, base)
def extract_data(element, data):
    """Flatten raw model entries of one model type into display records.

    Parameters:
        element: the model type string ("LLM", "embedding", ...); stored on
            each output record as model_type.
        data: list of raw model dicts as returned by the xinference
            registration endpoint.

    Returns:
        list of dicts with keys model_type / model_family / model_name /
        model_flag / model_description.
    """
    extracted_data = []
    for item in data:
        model_flag = get_model_flag(item, element)

        # Derive a family name: prefer the explicit field, else fall back to
        # the first path segment of model_id, else the first token of
        # model_name.
        model_family = item.get('model_family')
        if not model_family:
            model_id = item.get('model_id', '')
            model_name = item.get('model_name', '')
            if model_id:
                # BUG FIX: the original split on '\/' — an invalid escape that
                # Python keeps as the two characters backslash+slash, which
                # never occur in a model id, so no split happened (it was only
                # masked because clean_model_family re-splits on '/').
                model_family = model_id.split('/')[0]
            elif model_name:
                model_family = re.split('[-_]', model_name)[0]
        if model_family:
            model_family = clean_model_family(model_family)

        # Name-based overrides for families hidden inside longer names.
        model_name = item.get('model_name', '').lower()
        if 'llama' in model_name:
            model_family = 'llama'
        elif 'qwen' in model_name:
            model_family = 'qwen'

        # Strip trailing digits and punctuation (underscore and whitespace are
        # kept, matching the previous character-by-character trim).  Guarded so
        # an item with no derivable family no longer crashes on len(None).
        if model_family:
            model_family = re.sub(r'(?:\d|[^\w\s])+$', '', model_family)

        # Fall back to model_id when there is no description (get() keeps this
        # safe even when neither key exists).
        model_description = item.get('model_description', '') or item.get('model_id', '')

        extracted_data.append({
            'model_type': element,
            'model_family': model_family,
            'model_name': item.get('model_name'),
            'model_flag': model_flag,
            'model_description': model_description,
        })
    return extracted_data
def get_models_detail_api(model_type_list: list):
    """Fetch detailed registrations for each requested model type.

    Unsupported type names are silently skipped.  Returns the pair
    (tree_data, extracted_data) used by the frontend tree view.
    """
    supported = ("LLM", "embedding", "rerank", "image", "audio")
    extracted_data = []
    for model_type in model_type_list:
        if model_type not in supported:
            continue
        raw = api_request_xinf(
            method='get',
            url=f'model_registrations/{model_type}?detailed=true',
            is_headers=False,
        )
        extracted_data.extend(extract_data(model_type, raw))
    return get_model_tree_data(extracted_data), extracted_data
def get_model_info_api(model_name: str):
    """Return display info for *model_name*.

    Only the prompt template is fetched live (from the xinference
    models/prompts endpoint); every other field is currently a hard-coded
    placeholder — presumably until a real backend endpoint exists.
    """
    prompts = api_request_xinf(method='get', url='models/prompts', is_headers=False)
    print("model_name信息", model_name)
    prompt_data = prompts.get(model_name, "模型prompt不存在")
    print("prompt_data1", prompt_data)
    return {
        "model_name_cn": "通义千问2-开源版",
        "model_name_en": model_name,
        "model_note": "通义千问2对外开源的模型,模型支持 131,072 tokens上下文,为了保障正常使用和正常输出,API限定用户输入为 128,000 ,输出最大 6,144。",
        "model_fune_flag": "Y",
        "model_is_open_source": "",
        "model_release_date": "2024-08-08",
        "model_website_url": "https://modelscope.cn/organization/qwen?tab=model",
        "model_prmopt": str(prompt_data),
    }
def get_model_info_extend_api(model_name: str):
    """Return extended integration info (sample client code) for *model_name*.

    NOTE(review): everything except ``model_name_en`` is hard-coded sample
    content (DashScope / qwen-max snippets) — presumably placeholder data
    until a real backend endpoint supplies per-model code samples.
    """
    # Sample cURL call for a streaming (SSE) chat completion.
    curl_code = '''
curl --location 'https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation' \
--header 'Authorization: Bearer <your-dashscope-api-key>' \
--header 'Content-Type: application/json' \
--header 'X-DashScope-SSE: enable' \
--data '{
"model": "qwen-max",
"input":{
"messages":[
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "Introduce the capital of China"
}
]
},
"parameters": {
}
}'''
    # Sample Python client using the dashscope SDK with incremental streaming.
    python_code = '''#
from http import HTTPStatus
import dashscope
def call_with_stream():
messages = [
{'role': 'user', 'content': 'Introduce the capital of China'}]
responses = dashscope.Generation.call("qwen-max",
messages=messages,
result_format='message', # set the result to be "message" format.
stream=True, # set streaming output
incremental_output=True # get streaming output incrementally
)
for response in responses:
if response.status_code == HTTPStatus.OK:
print(response.output.choices[0]['message']['content'],end='')
else:
print('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
response.request_id, response.status_code,
response.code, response.message
))
if __name__ == '__main__':
call_with_stream()'''
    # Sample Java client: blocking stream and callback-based stream variants.
    java_code = '''
// Copyright (c) Alibaba, Inc. and its affiliates.
import java.util.Arrays;
import java.util.concurrent.Semaphore;
import com.alibaba.dashscope.aigc.generation.Generation;
import com.alibaba.dashscope.aigc.generation.GenerationParam;
import com.alibaba.dashscope.aigc.generation.GenerationResult;
import com.alibaba.dashscope.common.Message;
import com.alibaba.dashscope.common.ResultCallback;
import com.alibaba.dashscope.common.Role;
import com.alibaba.dashscope.exception.ApiException;
import com.alibaba.dashscope.exception.InputRequiredException;
import com.alibaba.dashscope.exception.NoApiKeyException;
import com.alibaba.dashscope.protocol.Protocol;
import com.alibaba.dashscope.utils.JsonUtils;
import io.reactivex.Flowable;
public class Main {
public static void streamCallWithMessage()
throws NoApiKeyException, ApiException, InputRequiredException {
Generation gen = new Generation();
Message userMsg =
Message.builder().role(Role.USER.getValue()).content("Introduce the capital of China").build();
GenerationParam param = GenerationParam.builder()
.model("qwen-max")
.messages(Arrays.asList(userMsg))
.resultFormat(GenerationParam.ResultFormat.MESSAGE) // the result if message format.
.topP(0.8).enableSearch(true) // set streaming output
.incrementalOutput(true) // get streaming output incrementally
.build();
Flowable<GenerationResult> result = gen.streamCall(param);
StringBuilder fullContent = new StringBuilder();
result.blockingForEach(message -> {
fullContent.append(message.getOutput().getChoices().get(0).getMessage().getContent());
System.out.println(JsonUtils.toJson(message));
});
System.out.println("Full content: \n" + fullContent.toString());
}
public static void streamCallWithCallback()
throws NoApiKeyException, ApiException, InputRequiredException, InterruptedException {
Generation gen = new Generation();
Message userMsg =
Message.builder().role(Role.USER.getValue()).content("Introduce the capital of China").build();
GenerationParam param = GenerationParam.builder()
.model("${modelCode}")
.resultFormat(GenerationParam.ResultFormat.MESSAGE) //set result format message
.messages(Arrays.asList(userMsg)) // set messages
.topP(0.8)
.incrementalOutput(true) // set streaming output incrementally
.build();
Semaphore semaphore = new Semaphore(0);
StringBuilder fullContent = new StringBuilder();
gen.streamCall(param, new ResultCallback<GenerationResult>() {
@Override
public void onEvent(GenerationResult message) {
fullContent
.append(message.getOutput().getChoices().get(0).getMessage().getContent());
System.out.println(message);
}
@Override
public void onError(Exception err) {
System.out.println(String.format("Exception: %s", err.getMessage()));
semaphore.release();
}
@Override
public void onComplete() {
System.out.println("Completed");
semaphore.release();
}
});
semaphore.acquire();
System.out.println("Full content: \n" + fullContent.toString());
}
public static void main(String[] args) {
try {
streamCallWithMessage();
} catch (ApiException | NoApiKeyException | InputRequiredException e) {
System.out.println(e.getMessage());
}
try {
streamCallWithCallback();
} catch (ApiException | NoApiKeyException | InputRequiredException
| InterruptedException e) {
System.out.println(e.getMessage());
}
System.exit(0);
}
}
'''
    # Assemble the payload; only model_name_en varies per call.
    mode_info_extend = {
        # "model_name_cn": "通义千问2-开源版",
        "model_oneid":"001",
        "model_name_en": model_name,
        "curl_code": curl_code,
        "python_code": python_code,
        "java_code": java_code,
        "model_update_date": "2024-08-08",
    }
    return mode_info_extend
# Most model info comes straight from xinference rather than the backend service.
def get_model_engine_api(model_name: str):
    """Fetch the available inference engines for *model_name* from xinference."""
    engines = api_request_xinf(method='get', url=f'engines/{model_name}', is_headers=False)
    print("model_name信息", model_name)
    return engines
# Cluster details also come from xinference.
def get_cluster_info_api( ):
    """Return detailed cluster information from xinference."""
    cluster_info = api_request_xinf(
        method='get', url='cluster/info?detailed=true', is_headers=False
    )
    print("cluster信息", cluster_info)
    return cluster_info
def launch_model_api(params: dict):
    """Ask xinference to launch a model; *params* is sent as the JSON body."""
    return api_request_xinf(method='post', url='models', is_headers=True, json=params)
def get_model_run_api( ):
    """List running model instances grouped by model type.

    Lists all running instances from xinference, fetches each instance's
    detail, and returns three lists of summary dicts in order:
    (LLM instances, embedding instances, rerank instances).
    Instances of other model types (image/audio) are ignored.
    """
    instances = api_request_xinf(method='get', url='models/instances', is_headers=False)
    print("data:", instances)

    # One bucket per displayed model type; previously three copy-pasted
    # branches built near-identical dicts — deduplicated here.
    grouped = {"LLM": [], "embedding": [], "rerank": []}
    for item in instances:
        model_uid = item["model_uid"]
        print(model_uid)
        detail_data = api_request_xinf(
            method='get', url=f'models/{model_uid}', is_headers=False
        )
        print("detail_data为:", detail_data)

        bucket = grouped.get(detail_data["model_type"])
        if bucket is None:
            continue  # model type not shown in the run view

        entry = {
            'model_uid': model_uid,
            'model_name': detail_data["model_name"],
            'worker_address': detail_data["address"],
            'GPU_indexs': str(detail_data["accelerators"]),
        }
        if detail_data["model_type"] == "LLM":
            # Size/quantization are only reported for LLM instances; inserted
            # here to preserve the original key order of the LLM entries.
            entry['model_size'] = detail_data["model_size_in_billions"]
            entry['model_quant'] = detail_data["quantization"]
        entry['model_replica'] = detail_data["replica"]
        bucket.append(entry)

    print("mode_llm_run_data:", grouped["LLM"])
    print("mode_embed_run_data:", grouped["embedding"])
    print("mode_rerank_run_data:", grouped["rerank"])
    return grouped["LLM"], grouped["embedding"], grouped["rerank"]