diff --git a/src/components/chat/AgentComposer.tsx b/src/components/chat/AgentComposer.tsx index 1f4b558..4dd3fbf 100644 --- a/src/components/chat/AgentComposer.tsx +++ b/src/components/chat/AgentComposer.tsx @@ -7,8 +7,11 @@ import { Box, Chip, Collapse, + FormControl, IconButton, + MenuItem, Paper, + Select, Stack, TextField, Typography, @@ -21,6 +24,9 @@ import MicRounded from "@mui/icons-material/MicRounded"; import KeyboardArrowDownRounded from "@mui/icons-material/KeyboardArrowDownRounded"; import KeyboardArrowUpRounded from "@mui/icons-material/KeyboardArrowUpRounded"; import AttachFileRounded from "@mui/icons-material/AttachFileRounded"; +import BoltRounded from "@mui/icons-material/BoltRounded"; +import AutoAwesomeRounded from "@mui/icons-material/AutoAwesomeRounded"; +import type { AgentModel } from "@/lib/chatStream"; type AgentComposerProps = { input: string; @@ -36,6 +42,8 @@ type AgentComposerProps = { onStartListening: () => void; onStopListening: () => void; onPresetSelect: (prompt: string) => void; + selectedModel: AgentModel; + onModelChange: (model: AgentModel) => void; }; export const AgentComposer = ({ @@ -52,6 +60,8 @@ export const AgentComposer = ({ onStartListening, onStopListening, onPresetSelect, + selectedModel, + onModelChange, }: AgentComposerProps) => { const theme = useTheme(); const canSend = input.trim().length > 0 && !isStreaming && !isHydrating; @@ -213,46 +223,163 @@ export const AgentComposer = ({ ) : null} - - {isStreaming ? ( - - - - - - ) : ( - - - - - - )} - + + + + + + + {isStreaming ? ( + + + + + + ) : ( + + + + + + )} + + diff --git a/src/components/chat/GlobalChatbox.tsx b/src/components/chat/GlobalChatbox.tsx index 21b87ae..fd0c745 100644 --- a/src/components/chat/GlobalChatbox.tsx +++ b/src/components/chat/GlobalChatbox.tsx @@ -3,6 +3,7 @@ import React, { useCallback, useEffect, useRef, useState } from "react"; import { Box, Drawer, alpha, useTheme } from "@mui/material"; +import type { AgentModel } from "@/lib/chatStream"; import { AgentComposer } from "./AgentComposer"; import { AgentHeader } from "./AgentHeader"; import { AgentHistoryPanel } from "./AgentHistoryPanel"; @@ -19,6 +20,9 @@ export const GlobalChatbox: React.FC = ({ open, onClose }) => { const [width, setWidth] = useState(520); const [isResizing, setIsResizing] = useState(false); const [isHistoryOpen, setIsHistoryOpen] = useState(false); + const [selectedModel, setSelectedModel] = useState( + "deepseek/deepseek-v4-pro", + ); const bottomRef = useRef(null); const inputRef = useRef(null); @@ -65,6 +69,7 @@ export const GlobalChatbox: React.FC = ({ open, onClose }) => { } = useAgentChatSession({ onToolCall: handleToolCall, onBeforeSend: stopListening, + getModel: () => selectedModel, }); useEffect(() => { @@ -298,6 +303,8 @@ export const GlobalChatbox: React.FC = ({ open, onClose }) => { onStartListening={startListening} onStopListening={stopListening} onPresetSelect={handlePresetPromptSelect} + selectedModel={selectedModel} + onModelChange={setSelectedModel} /> diff --git a/src/components/chat/hooks/useAgentChatSession.ts b/src/components/chat/hooks/useAgentChatSession.ts index e614d57..a2ff7cf 100644 --- a/src/components/chat/hooks/useAgentChatSession.ts +++ b/src/components/chat/hooks/useAgentChatSession.ts @@ -3,7 +3,7 @@ import { useCallback, useEffect, useRef, useState } from "react"; import { abortAgentChat, forkAgentChat, streamAgentChat } from "@/lib/chatStream"; -import type { StreamEvent } from "@/lib/chatStream"; +import type { AgentModel, StreamEvent } from "@/lib/chatStream"; import type { AgentArtifact, BranchGroup, @@ -37,6 +37,7 @@ type UseAgentChatSessionOptions = { }, ) => void; onBeforeSend?: () => void; + getModel?: () => AgentModel; }; type PromptRunOptions = { @@ -137,6 +138,7 @@ const messagesEqual = (left: Message[], right: Message[]) => export const useAgentChatSession = ({ onToolCall, onBeforeSend, + getModel, }: UseAgentChatSessionOptions) => { const storageSessionIdRef = useRef(undefined); const hydrationCompletedRef = useRef(false); @@ -317,6 +319,7 @@ export const useAgentChatSession = ({ await streamAgentChat({ message: prompt, sessionId: sessionIdOverride ?? sessionIdRef.current, + model: getModel?.(), signal: controller.signal, onEvent: (event) => { if ("sessionId" in event && event.sessionId && event.sessionId !== sessionIdRef.current) { @@ -448,7 +451,7 @@ export const useAgentChatSession = ({ setIsStreaming(false); } }, - [appendArtifact, isHydrating, isStreaming, messages, onBeforeSend, onToolCall], + [appendArtifact, getModel, isHydrating, isStreaming, messages, onBeforeSend, onToolCall], ); const abort = useCallback(() => { diff --git a/src/lib/chatStream.test.ts b/src/lib/chatStream.test.ts index 6cf3f16..064ead6 100644 --- a/src/lib/chatStream.test.ts +++ b/src/lib/chatStream.test.ts @@ -51,6 +51,7 @@ describe("streamAgentChat", () => { await streamAgentChat({ message: "hi", + model: "deepseek/deepseek-v4-pro", onEvent: (event) => events.push(event), }); @@ -60,6 +61,11 @@ describe("streamAgentChat", () => { method: "POST", projectHeaderMode: "include", skipAuthRedirect: true, + body: JSON.stringify({ + message: "hi", + session_id: undefined, + model: "deepseek/deepseek-v4-pro", + }), }), ); diff --git a/src/lib/chatStream.ts b/src/lib/chatStream.ts index eec1856..9a6d981 100644 --- a/src/lib/chatStream.ts +++ b/src/lib/chatStream.ts @@ -1,6 +1,10 @@ import { apiFetch } from "@/lib/apiFetch"; import { config } from "@config/config"; +export type AgentModel = + | "deepseek/deepseek-v4-flash" + | "deepseek/deepseek-v4-pro"; + export type StreamEvent = | { type: "token"; sessionId: string; content: string } | { type: "done"; sessionId: string; totalDurationMs?: number } @@ -35,6 +39,7 @@ export type StreamEvent = type StreamOptions = { message: string; sessionId?: string; + model?: AgentModel; signal?: AbortSignal; onEvent: (event: StreamEvent) => void; }; @@ -85,6 +90,7 @@ const resolveToolParams = ( export const streamAgentChat = async ({ message, sessionId, + model, signal, onEvent, }: StreamOptions) => { @@ -102,6 +108,7 @@ export const streamAgentChat = async ({ body: JSON.stringify({ message, session_id: sessionId, + model, }), projectHeaderMode: "include", userHeaderMode: "include",