Custom LLM
In conversational AI scenarios, you may want to use a custom large language model (Custom LLM) to meet more specific requirements. This page describes how to connect a custom LLM to the 声网 Conversational AI Engine.
How it works
The 声网 Conversational AI Engine interacts with the LLM service using the OpenAI API protocol. The core of integrating a custom LLM is therefore to provide an HTTP service that is compatible with the OpenAI API, that is, a service that can receive requests and return responses in the OpenAI API format (see the sample streaming response after the list below).
On top of this, you can implement additional custom capabilities, including but not limited to:
- Retrieval-Augmented Generation (RAG): let the model retrieve information from a specific knowledge base
- Multimodal output: let the model respond in both text and audio
- Tool calling: let the model invoke external tools
- Function Calling: let the model return structured data in the form of function calls
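For reference, the streaming response your service returns is a sequence of Server-Sent Events. Each event carries an OpenAI-style chat.completion.chunk JSON object, and the stream ends with a [DONE] event. The field values below are illustrative only:
data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"your-model","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}
data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"your-model","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}
data: [DONE]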
Prerequisites
Before you begin, make sure you have:
- Followed the Implement a conversational agent guide and implemented the basic logic of conversational interaction with an AI agent
- A custom large language model service that is reachable over the network
- A vector database or retrieval system ready, if you plan to use retrieval-augmented generation (RAG)
Implementation
Create an OpenAI API-compatible service
To connect to the 声网 Conversational AI Engine, your custom LLM service must expose an interface compatible with the OpenAI Chat Completions API. The key points are:
- Endpoint: provide an endpoint that receives requests, for example https://your-custom-llm-service/chat/completions.
- Request format: accept request parameters compatible with the OpenAI API protocol.
- Response format: return a streaming response that is compatible with the OpenAI API protocol and follows the SSE specification.
The following sample code shows how to implement an OpenAI API-compatible endpoint:
- Python
- Go
# Requires fastapi, pydantic, and the openai Python SDK; run with an ASGI server such as uvicorn.
import asyncio
import json
import logging
import os
import traceback
from typing import Dict, List, Optional, Union

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from openai import AsyncOpenAI
from pydantic import BaseModel, HttpUrl

logger = logging.getLogger(__name__)
app = FastAPI()

class TextContent(BaseModel):
type: str = "text"
text: str
class ImageContent(BaseModel):
type: str = "image"
image_url: HttpUrl
class AudioContent(BaseModel):
type: str = "input_audio"
input_audio: Dict[str, str]
class ToolFunction(BaseModel):
name: str
description: Optional[str]
parameters: Optional[Dict]
strict: bool = False
class Tool(BaseModel):
type: str = "function"
function: ToolFunction
class ToolChoice(BaseModel):
type: str = "function"
function: Optional[Dict]
class ResponseFormat(BaseModel):
type: str = "json_schema"
json_schema: Optional[Dict[str, str]]
class SystemMessage(BaseModel):
role: str = "system"
content: Union[str, List[str]]
class UserMessage(BaseModel):
role: str = "user"
content: Union[str, List[Union[TextContent, ImageContent, AudioContent]]]
class AssistantMessage(BaseModel):
role: str = "assistant"
    content: Optional[Union[str, List[TextContent]]] = None
audio: Optional[Dict[str, str]] = None
tool_calls: Optional[List[Dict]] = None
class ToolMessage(BaseModel):
role: str = "tool"
content: Union[str, List[str]]
tool_call_id: str
# Define the complete request format
class ChatCompletionRequest(BaseModel):
    context: Optional[Dict] = None  # Context information
    model: Optional[str] = None  # Name of the model to use
    messages: List[Union[SystemMessage, UserMessage, AssistantMessage, ToolMessage]]  # Message list
    response_format: Optional[ResponseFormat] = None  # Response format
    modalities: List[str] = ["text"]  # Text modality by default
    audio: Optional[Dict[str, str]] = None  # Audio reply of the assistant
    tools: Optional[List[Tool]] = None  # Tool list
    tool_choice: Optional[Union[str, ToolChoice]] = "auto"  # Tool choice
    parallel_tool_calls: bool = True  # Whether to call tools in parallel
    stream: bool = True  # Streaming response by default
    stream_options: Optional[Dict] = None  # Streaming options
@app.post("/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    try:
        logger.info(f"Received request: {request.model_dump_json()}")

        if not request.stream:
            raise HTTPException(
                status_code=400, detail="chat completions require streaming"
            )

        client = AsyncOpenAI(api_key=os.getenv("YOUR_LLM_API_KEY"))
        response = await client.chat.completions.create(
            model=request.model,
            messages=request.messages,  # Pass the request messages through directly
            tool_choice=(
                request.tool_choice if request.tools and request.tool_choice else None
            ),
            tools=request.tools if request.tools else None,
            modalities=request.modalities,
            audio=request.audio,
            response_format=request.response_format,
            stream=request.stream,
            stream_options=request.stream_options,
        )
async def generate():
try:
async for chunk in response:
logger.debug(f"Received chunk: {chunk}")
yield f"data: {json.dumps(chunk.to_dict())}\n\n"
yield "data: [DONE]\n\n"
except asyncio.CancelledError:
logger.info("Request was cancelled")
raise
return StreamingResponse(generate(), media_type="text/event-stream")
except asyncio.CancelledError:
logger.info("Request was cancelled")
raise HTTPException(status_code=499, detail="Request was cancelled")
except Exception as e:
traceback_str = "".join(traceback.format_tb(e.__traceback__))
error_message = f"{str(e)}\n{traceback_str}"
logger.error(error_message)
raise HTTPException(status_code=500, detail=error_message)
// Imports assume the Gin framework and the community go-openai SDK (github.com/sashabaranov/go-openai), which matches the client calls below.
package main

import (
    "encoding/json"
    "fmt"
    "io"
    "log/slog"
    "net/http"
    "os"

    "github.com/gin-gonic/gin"
    openai "github.com/sashabaranov/go-openai"
)

type (
AudioContent struct {
InputAudio map[string]string `json:"input_audio"`
Type string `json:"type"`
}
    // Complete request format
    ChatCompletionRequest struct {
        // Audio reply of the assistant
        Audio map[string]string `json:"audio,omitempty"`
        // Context information
        Context map[string]any `json:"context,omitempty"`
        // Message list
        Messages []Message `json:"messages"`
        // Text modality by default
        Modalities []string `json:"modalities"`
        // Name of the model to use
        Model string `json:"model,omitempty"`
        // Whether to call tools in parallel
        ParallelToolCalls bool `json:"parallel_tool_calls"`
        // Response format
        ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
        // Whether to use streaming responses
        Stream bool `json:"stream"`
        // Streaming options
        StreamOptions map[string]any `json:"stream_options,omitempty"`
        // Tool choice strategy, defaults to "auto"
        ToolChoice any `json:"tool_choice,omitempty"`
        // Tool list
        Tools []Tool `json:"tools,omitempty"`
    }
}
ImageContent struct {
ImageURL string `json:"image_url"`
Type string `json:"type"`
}
Message struct {
Audio map[string]string `json:"audio,omitempty"`
Content any `json:"content"`
Role string `json:"role"`
ToolCallID string `json:"tool_call_id,omitempty"`
ToolCalls []map[string]any `json:"tool_calls,omitempty"`
}
ResponseFormat struct {
JSONSchema map[string]string `json:"json_schema,omitempty"`
Type string `json:"type"`
}
TextContent struct {
Text string `json:"text"`
Type string `json:"type"`
}
Tool struct {
Function ToolFunction `json:"function"`
Type string `json:"type"`
}
ToolChoice struct {
Function map[string]any `json:"function,omitempty"`
Type string `json:"type"`
}
ToolFunction struct {
Description string `json:"description,omitempty"`
Name string `json:"name"`
Parameters map[string]any `json:"parameters,omitempty"`
Strict bool `json:"strict"`
}
)
var waitingMessages = []string{
"Just a moment, I'm thinking...",
"Let me think about that for a second...",
"Good question, let me find out...",
}
// Chat completion server
type Server struct {
client *openai.Client
logger *slog.Logger
}
// NewServer creates a new server instance
func NewServer(apiKey string) *Server {
return &Server{
client: openai.NewClient(apiKey),
logger: slog.New(slog.NewJSONHandler(os.Stdout, nil)),
}
}
// handleChatCompletion handles the chat completion endpoint
func (s *Server) handleChatCompletion(c *gin.Context) {
var request ChatCompletionRequest
if err := c.ShouldBindJSON(&request); err != nil {
s.sendError(c, http.StatusBadRequest, err)
return
}
if !request.Stream {
s.sendError(c, http.StatusBadRequest, fmt.Errorf("chat completions require streaming"))
return
}
    // Set the SSE headers
c.Header("Content-Type", "text/event-stream")
responseChan := make(chan any, 100)
errorChan := make(chan error, 1)
go func() {
messages := make([]openai.ChatCompletionMessage, len(request.Messages))
for i, msg := range request.Messages {
if strContent, ok := msg.Content.(string); ok {
messages[i] = openai.ChatCompletionMessage{
Role: msg.Role,
Content: strContent,
}
}
}
req := openai.ChatCompletionRequest{
Model: request.Model,
Messages: messages,
Stream: true,
}
if len(request.Tools) > 0 {
tools := make([]openai.Tool, len(request.Tools))
for i, tool := range request.Tools {
tools[i] = openai.Tool{
Type: openai.ToolTypeFunction,
Function: &openai.FunctionDefinition{
Name: tool.Function.Name,
Description: tool.Function.Description,
Parameters: tool.Function.Parameters,
},
}
}
req.Tools = tools
}
stream, err := s.client.CreateChatCompletionStream(c.Request.Context(), req)
if err != nil {
errorChan <- err
return
}
defer stream.Close()
for {
response, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
errorChan <- err
return
}
responseChan <- response
}
close(responseChan)
}()
for {
select {
case chunk, ok := <-responseChan:
if !ok {
c.SSEvent("data", "[DONE]")
return
}
data, _ := json.Marshal(chunk)
c.SSEvent("data", string(data))
case err := <-errorChan:
s.logger.Error("Error in chat completion stream", "err", err)
s.sendError(c, http.StatusInternalServerError, err)
return
}
}
}
// sendError sends an error response to the client
func (s *Server) sendError(c *gin.Context, status int, err error) {
c.JSON(status, gin.H{"detail": err.Error()})
}
func main() {
    // Initialize the server
    server := NewServer(os.Getenv("YOUR_LLM_API_KEY"))
    // Initialize the Gin router
    r := gin.Default()
    // Register the route
    r.POST("/chat/completions", server.handleChatCompletion)
    // Start the server
    r.Run(":8000")
}
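Because the service speaks the OpenAI protocol, you can smoke-test it locally with any OpenAI-compatible client before configuring the engine. Below is a minimal sketch that targets the Python service above running on localhost:8000; the model name is only an example and is forwarded unchanged to your upstream LLM:
from openai import OpenAI

# Point the official OpenAI SDK at the local service; the api_key value is not checked in this example.
client = OpenAI(base_url="http://localhost:8000", api_key="unused")
stream = client.chat.completions.create(
    model="gpt-4o-mini",  # example model name; replace with the model your service expects
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
for chunk in stream:
    print(chunk.choices[0].delta)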
Configure the 声网 Conversational AI Engine
When you call the POST method to create a conversational agent, point the LLM configuration to your custom service:
If your custom LLM service requires authentication, pass your credentials in the api_key field.
{
"llm": {
"url": "https://your-custom-llm-service/chat/completions",
"api_key": "",
"system_messages": [
{
"role": "system",
"content": "You are a helpful assistant."
}
]
}
}
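If you do pass a value in api_key, your service should verify it on each request. A common convention for OpenAI-compatible services is to carry the key as a Bearer token in the Authorization header; confirm how the engine forwards credentials in your environment before relying on this. A minimal FastAPI sketch under that assumption (EXPECTED_LLM_API_KEY is a hypothetical environment variable):
import os
from fastapi import Depends, Header, HTTPException

# Assumption: the configured api_key arrives as "Authorization: Bearer <key>".
async def verify_api_key(authorization: str = Header(default="")):
    expected = os.getenv("EXPECTED_LLM_API_KEY", "")
    if expected and authorization != f"Bearer {expected}":
        raise HTTPException(status_code=401, detail="Invalid API key")

# Attach it to the endpoint:
# @app.post("/chat/completions", dependencies=[Depends(verify_api_key)])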
Advanced features
Implement retrieval-augmented generation
To improve the accuracy and relevance of the agent's responses, you can use retrieval-augmented generation (RAG): have your custom LLM service retrieve information from a specific knowledge base first, and then provide the retrieval results to the LLM as context for generating the answer.
The following sample code simulates retrieving content from a knowledge base and adds a /rag/chat/completions endpoint that calls the LLM with the retrieval results to generate the answer:
- Python
- Go
import random  # additionally required by this endpoint

# The endpoints below extend the service defined above and reuse app, logger, and ChatCompletionRequest.

async def perform_rag_retrieval(messages: List) -> str:
    """
    Retrieve relevant content from the knowledge base for the given message list.

    Args:
        messages: The original message list

    Returns:
        str: The retrieved text content
    """
    # TODO: Implement the actual RAG retrieval logic.
    # For example, pick the first or last message as the query, send it to your
    # retrieval system, and return the retrieved content.
    return "This is relevant content retrieved from the knowledge base."
def refact_messages(context: str, messages: List) -> List:
    """
    Adjust the message list by adding the retrieved context to the original messages.

    Args:
        context: The retrieved context
        messages: The original message list

    Returns:
        List: The adjusted message list
    """
    # TODO: Implement the actual message adjustment logic, e.g. merge the
    # retrieved context into the original message list.
    return messages
# Waiting messages to pick from at random
waiting_messages = [
"Just a moment, I'm thinking...",
"Let me think about that for a second...",
"Good question, let me find out...",
]
@app.post("/rag/chat/completions")
async def create_rag_chat_completion(request: ChatCompletionRequest):
try:
logger.info(f"Received RAG request: {request.model_dump_json()}")
if not request.stream:
raise HTTPException(
status_code=400, detail="chat completions require streaming"
)
async def generate():
# First send a "please wait" prompt
waiting_message = {
"id": "waiting_msg",
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
"content": random.choice(waiting_messages),
},
"finish_reason": None,
}
],
}
yield f"data: {json.dumps(waiting_message)}\n\n"
# Perform RAG retrieval
retrieved_context = await perform_rag_retrieval(request.messages)
# Adjust messages
refacted_messages = refact_messages(retrieved_context, request.messages)
# Request LLM completion
            client = AsyncOpenAI(api_key=os.getenv("YOUR_LLM_API_KEY"))
response = await client.chat.completions.create(
model=request.model,
messages=refacted_messages,
tool_choice=(
request.tool_choice
if request.tools and request.tool_choice
else None
),
tools=request.tools if request.tools else None,
modalities=request.modalities,
audio=request.audio,
response_format=request.response_format,
stream=True, # Force streaming
stream_options=request.stream_options,
)
try:
async for chunk in response:
logger.debug(f"Received RAG chunk: {chunk}")
yield f"data: {json.dumps(chunk.to_dict())}\n\n"
yield "data: [DONE]\n\n"
except asyncio.CancelledError:
logger.info("RAG stream was cancelled")
raise
return StreamingResponse(generate(), media_type="text/event-stream")
except asyncio.CancelledError:
logger.info("RAG request was cancelled")
raise HTTPException(status_code=499, detail="Request was cancelled")
except Exception as e:
traceback_str = "".join(traceback.format_tb(e.__traceback__))
error_message = f"{str(e)}\n{traceback_str}"
logger.error(error_message)
raise HTTPException(status_code=500, detail=error_message)
// handleRAGChatCompletion handles the RAG chat completion endpoint (also requires the math/rand import)
func (s *Server) handleRAGChatCompletion(c *gin.Context) {
var request ChatCompletionRequest
if err := c.ShouldBindJSON(&request); err != nil {
s.sendError(c, http.StatusBadRequest, err)
return
}
if !request.Stream {
s.sendError(c, http.StatusBadRequest, fmt.Errorf("chat completions require streaming"))
return
}
    // Set the SSE headers
    c.Header("Content-Type", "text/event-stream")
    // First send a "please wait" message
waitingMsg := map[string]any{
"id": "waiting_msg",
"choices": []map[string]any{
{
"index": 0,
"delta": map[string]any{
"role": "assistant",
"content": waitingMessages[rand.Intn(len(waitingMessages))],
},
"finish_reason": nil,
},
},
}
data, _ := json.Marshal(waitingMsg)
c.SSEvent("data", string(data))
    // Perform RAG retrieval
retrievedContext, err := s.performRAGRetrieval(request.Messages)
if err != nil {
s.logger.Error("Failed to perform RAG retrieval", "err", err)
s.sendError(c, http.StatusInternalServerError, err)
return
}
    // Adjust the messages
    refactedMessages := s.refactMessages(retrievedContext, request.Messages)
    // Convert the messages to the OpenAI format
messages := make([]openai.ChatCompletionMessage, len(refactedMessages))
for i, msg := range refactedMessages {
if strContent, ok := msg.Content.(string); ok {
messages[i] = openai.ChatCompletionMessage{
Role: msg.Role,
Content: strContent,
}
}
}
req := openai.ChatCompletionRequest{
Model: request.Model,
Messages: messages,
Stream: true,
}
stream, err := s.client.CreateChatCompletionStream(c.Request.Context(), req)
if err != nil {
s.sendError(c, http.StatusInternalServerError, err)
return
}
defer stream.Close()
for {
response, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
s.sendError(c, http.StatusInternalServerError, err)
return
}
data, _ := json.Marshal(response)
c.SSEvent("data", string(data))
}
c.SSEvent("data", "[DONE]")
}
// performRAGRetrieval retrieves relevant content from the knowledge base for the given message list.
//
// messages: the original message list.
//
// Returns the retrieved text content and any error that occurred during retrieval.
func (s *Server) performRAGRetrieval(messages []Message) (string, error) {
    // TODO: Implement the actual RAG retrieval logic.
    // For example, pick the first or last message as the query, send it to your
    // retrieval system, and return the retrieved content.
return "This is relevant content retrieved from the knowledge base.", nil
}
// refactMessages adjusts the message list by adding the retrieved context to the original messages.
//
// context: the retrieved context.
// messages: the original message list.
//
// Returns the adjusted message list.
func (s *Server) refactMessages(context string, messages []Message) []Message {
    // TODO: Implement the actual message adjustment logic, e.g. merge the
    // retrieved context into the original message list.
    // For now, simply return the original messages.
return messages
}
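As one possible way to fill in the TODO stubs above, the following hedged Python sketch uses the most recent user message as the retrieval query and prepends the retrieved context as an extra system message; retrieve_from_vector_store is a hypothetical function standing in for your own retrieval system:
async def perform_rag_retrieval(messages: List) -> str:
    # Use the content of the most recent user message as the retrieval query
    query = ""
    for msg in reversed(messages):
        role = msg.role if hasattr(msg, "role") else msg.get("role")
        if role == "user":
            content = msg.content if hasattr(msg, "content") else msg.get("content")
            query = content if isinstance(content, str) else str(content)
            break
    return await retrieve_from_vector_store(query)  # hypothetical retrieval call

def refact_messages(context: str, messages: List) -> List:
    # Prepend the retrieved context as an additional system message
    context_message = {
        "role": "system",
        "content": f"Answer using the following retrieved information:\n{context}",
    }
    return [context_message] + list(messages)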
When you call the POST method to create a conversational agent, simply point the LLM URL to your RAG endpoint:
If your custom LLM service requires authentication, pass your credentials in the api_key field.
{
"llm": {
"url": "http://your-custom-llm-service/rag/chat/completions",
"api_key": ""
"system_messages": [
{
"role": "system",
"content": "Please answer the user's question based on the following retrieved information: ..."
}
]
}
}
Implement multimodal output
The 声网 Conversational AI Engine supports multimodal (text and audio) output from the LLM. You can create a dedicated multimodal endpoint to implement more personalized behavior.
For more information about output in the audio modality, see Output using the audio modality.
The following sample code simulates a multimodal reply by reading a text file and a PCM audio file and streaming them back to the caller as the model's audio response:
- Python
- Go
import base64  # additionally required by this endpoint
import uuid

import aiofiles

async def read_text_file(file_path: str) -> str:
    """
    Read a text file and return its content.

    Args:
        file_path: Path to the text file

    Returns:
        str: The content of the text file
    """
    async with aiofiles.open(file_path, "r") as file:
        content = await file.read()
        return content
async def read_pcm_file(
    file_path: str, sample_rate: int, duration_ms: int
) -> List[bytes]:
    """
    Read a PCM file and return a list of audio chunks.

    Args:
        file_path: Path to the PCM file
        sample_rate: Sample rate of the audio
        duration_ms: Duration of each audio chunk, in milliseconds

    Returns:
        List: The list of audio chunks
    """
    async with aiofiles.open(file_path, "rb") as file:
        content = await file.read()
        chunk_size = int(sample_rate * 2 * (duration_ms / 1000))  # 16-bit (2-byte) mono samples
return [content[i : i + chunk_size] for i in range(0, len(content), chunk_size)]
@app.post("/audio/chat/completions")
async def create_audio_chat_completion(request: ChatCompletionRequest):
try:
logger.info(f"Received audio request: {request.model_dump_json()}")
if not request.stream:
raise HTTPException(
status_code=400, detail="chat completions require streaming"
)
        # Example usage: read a text file and a PCM audio file.
        # Replace this with your actual logic.
        text_file_path = "./file.txt"
        pcm_file_path = "./file.pcm"
        sample_rate = 16000  # Example sample rate
        duration_ms = 40  # 40 ms audio chunks
text_content = await read_text_file(text_file_path)
audio_chunks = await read_pcm_file(pcm_file_path, sample_rate, duration_ms)
async def generate():
try:
                # Send the text content (transcript)
audio_id = uuid.uuid4().hex
text_message = {
"id": uuid.uuid4().hex,
"choices": [
{
"index": 0,
"delta": {
"audio": {
"id": audio_id,
"transcript": text_content,
},
},
"finish_reason": None,
}
],
}
yield f"data: {json.dumps(text_message)}\n\n"
                # Send the audio chunks
for chunk in audio_chunks:
audio_message = {
"id": uuid.uuid4().hex,
"choices": [
{
"index": 0,
"delta": {
"audio": {
"id": audio_id,
"data": base64.b64encode(chunk).decode("utf-8"),
},
},
"finish_reason": None,
}
],
}
yield f"data: {json.dumps(audio_message)}\n\n"
yield "data: [DONE]\n\n"
except asyncio.CancelledError:
logger.info("Audio stream was cancelled")
raise
return StreamingResponse(generate(), media_type="text/event-stream")
except asyncio.CancelledError:
logger.info("Audio request was cancelled")
raise HTTPException(status_code=499, detail="Request was cancelled")
except Exception as e:
traceback_str = "".join(traceback.format_tb(e.__traceback__))
error_message = f"{str(e)}\n{traceback_str}"
logger.error(error_message)
raise HTTPException(status_code=500, detail=error_message)
// handleAudioChatCompletion handles the audio chat completion endpoint (also requires encoding/base64 and a UUID package such as github.com/google/uuid)
func (s *Server) handleAudioChatCompletion(c *gin.Context) {
var request ChatCompletionRequest
if err := c.ShouldBindJSON(&request); err != nil {
s.sendError(c, http.StatusBadRequest, err)
return
}
if !request.Stream {
s.sendError(c, http.StatusBadRequest, fmt.Errorf("chat completions require streaming"))
return
}
    // Set the SSE headers
    c.Header("Content-Type", "text/event-stream")
    // Read the text and audio files
textContent, err := s.readTextFile("./file.txt")
if err != nil {
s.logger.Error("Failed to read text file", "err", err)
s.sendError(c, http.StatusInternalServerError, err)
return
}
sampleRate := 16000 // Example sample rate
durationMs := 40 // 40ms chunks
audioChunks, err := s.readPCMFile("./file.pcm", sampleRate, durationMs)
if err != nil {
s.logger.Error("Failed to read PCM file", "err", err)
s.sendError(c, http.StatusInternalServerError, err)
return
}
    // Send the text content (transcript)
audioID := uuid.New().String()
textMessage := map[string]any{
"id": uuid.New().String(),
"choices": []map[string]any{
{
"index": 0,
"delta": map[string]any{
"audio": map[string]any{
"id": audioID,
"transcript": textContent,
},
},
"finish_reason": nil,
},
},
}
data, _ := json.Marshal(textMessage)
c.SSEvent("data", string(data))
    // Send the audio chunks
for _, chunk := range audioChunks {
audioMessage := map[string]any{
"id": uuid.New().String(),
"choices": []map[string]any{
{
"index": 0,
"delta": map[string]any{
"audio": map[string]any{
"id": audioID,
"data": base64.StdEncoding.EncodeToString(chunk),
},
},
"finish_reason": nil,
},
},
}
data, _ := json.Marshal(audioMessage)
c.SSEvent("data", string(data))
}
c.SSEvent("data", "[DONE]")
}
// readPCMFile reads a PCM file and returns the audio chunks.
//
// filePath: path to the PCM file.
// sampleRate: sample rate of the audio.
// durationMs: duration of each audio chunk, in milliseconds.
//
// Returns the list of audio chunks and any error that occurred while reading.
func (s *Server) readPCMFile(filePath string, sampleRate int, durationMs int) ([][]byte, error) {
data, err := os.ReadFile(filePath)
if err != nil {
return nil, fmt.Errorf("failed to read PCM file: %w", err)
}
chunkSize := int(float64(sampleRate) * 2 * float64(durationMs) / 1000.0)
if chunkSize == 0 {
return nil, fmt.Errorf("invalid chunk size: sample rate %d, duration %dms", sampleRate, durationMs)
}
chunks := make([][]byte, 0, len(data)/chunkSize+1)
for i := 0; i < len(data); i += chunkSize {
end := i + chunkSize
if end > len(data) {
end = len(data)
}
chunks = append(chunks, data[i:end])
}
return chunks, nil
}
// readTextFile reads a text file and returns its content.
//
// filePath: path to the text file.
//
// Returns the content of the text file and any error that occurred while reading.
func (s *Server) readTextFile(filePath string) (string, error) {
data, err := os.ReadFile(filePath)
if err != nil {
return "", fmt.Errorf("failed to read text file: %w", err)
}
return string(data), nil
}
When you call the POST method to create a conversational agent, refer to the following configuration:
{
"llm": {
"url": "http://your-custom-llm-service/audio/chat/completions",
"api_key": "your_api_key",
"input_modalities": ["text"],
"output_modalities": ["text", "audio"]
"system_messages": [
{
"role": "system",
"content": "You are a helpful assistant."
}
]
}
}
Reference
Sample project
声网 provides an open-source sample project for reference. You can download it or browse its source code.
Interface standard
Your custom LLM service must be compatible with the OpenAI Chat Completions API interface standard:
- Request format: includes parameters such as the model, messages, and tool-calling configuration
- Response format: includes the model-generated reply, metadata, and other information
- Streaming responses: conform to the SSE (Server-Sent Events) specification
For the detailed interface standard, see the OpenAI Chat Completions API reference.