diff --git a/src/components/chat/AgentWorkspace.test.tsx b/src/components/chat/AgentWorkspace.test.tsx new file mode 100644 index 0000000..04b4c4f --- /dev/null +++ b/src/components/chat/AgentWorkspace.test.tsx @@ -0,0 +1,97 @@ +/* eslint-disable @next/next/no-img-element */ +import "@testing-library/jest-dom"; +import React from "react"; +import { render } from "@testing-library/react"; + +import { AgentWorkspace } from "./AgentWorkspace"; +import type { Message } from "./GlobalChatbox.types"; + +const renderCounts = new Map(); + +jest.mock("next/image", () => ({ + __esModule: true, + default: (props: React.ImgHTMLAttributes) => {props.alt, +})); + +jest.mock("framer-motion", () => ({ + AnimatePresence: ({ children }: { children: React.ReactNode }) => <>{children}, + motion: { + div: ({ children, ...props }: React.HTMLAttributes) =>
{children}
, + }, +})); + +jest.mock("./GlobalChatbox.parts", () => ({ + TypingIndicator: () =>
typing
, +})); + +jest.mock("./AgentTurn", () => ({ + AgentTurn: ({ message }: { message: Message }) => { + renderCounts.set(message.id, (renderCounts.get(message.id) ?? 0) + 1); + return
{message.content}
; + }, +})); + +describe("AgentWorkspace", () => { + const defaultProps = { + branchGroups: [], + branchTransition: null, + bottomRef: { current: null }, + speakingMessageId: null, + speechState: "idle" as const, + onSpeak: jest.fn(), + onPauseSpeech: jest.fn(), + onResumeSpeech: jest.fn(), + onStopSpeech: jest.fn(), + isTtsSupported: false, + onRegenerate: jest.fn(), + onEditResubmit: jest.fn(), + onCycleBranch: jest.fn(), + }; + + beforeEach(() => { + renderCounts.clear(); + }); + + it("keeps stable history turns from re-rendering while the last assistant message streams", () => { + const userMessage: Message = { + id: "user-1", + role: "user", + content: "question", + }; + const assistantHistoryMessage: Message = { + id: "assistant-1", + role: "assistant", + content: "stable answer", + }; + const streamingMessage: Message = { + id: "assistant-2", + role: "assistant", + content: "partial", + }; + + const { rerender } = render( + , + ); + + const updatedStreamingMessage: Message = { + ...streamingMessage, + content: "partial with more tokens", + }; + + rerender( + , + ); + + expect(renderCounts.get("user-1")).toBe(1); + expect(renderCounts.get("assistant-1")).toBe(1); + expect(renderCounts.get("assistant-2")).toBe(2); + }); +}); diff --git a/src/components/chat/AgentWorkspace.tsx b/src/components/chat/AgentWorkspace.tsx index 06eafb9..5f3e03f 100644 --- a/src/components/chat/AgentWorkspace.tsx +++ b/src/components/chat/AgentWorkspace.tsx @@ -13,6 +13,7 @@ import { AgentTurn } from "./AgentTurn"; import { TypingIndicator } from "./GlobalChatbox.parts"; import type { BranchGroup, + BranchState, BranchTransition, Message, SpeechState, @@ -36,6 +37,96 @@ type AgentWorkspaceProps = { onCycleBranch: (rootMessageId: string, direction: -1 | 1) => void; }; +type TurnListProps = { + messages: Message[]; + branchGroups: BranchGroup[]; + speakingMessageId: string | null; + speechState: SpeechState; + onSpeak: (messageId: string, text: string) => void; + onPauseSpeech: () => void; + onResumeSpeech: () => void; + onStopSpeech: () => void; + isTtsSupported: boolean; + onRegenerate: () => void; + onEditResubmit: (messageId: string, newContent: string) => void; + onCycleBranch: (rootMessageId: string, direction: -1 | 1) => void; +}; + +const sameMessages = (left: Message[], right: Message[]) => + left.length === right.length && + left.every((message, index) => message === right[index]); + +const TurnListInner = ({ + messages, + branchGroups, + speakingMessageId, + speechState, + onSpeak, + onPauseSpeech, + onResumeSpeech, + onStopSpeech, + isTtsSupported, + onRegenerate, + onEditResubmit, + onCycleBranch, +}: TurnListProps) => { + const branchStateByRootId = React.useMemo(() => { + const next = new Map(); + branchGroups.forEach((group) => { + if (group.branches.length > 1) { + next.set(group.rootMessageId, { + activeIndex: group.activeIndex, + total: group.branches.length, + }); + } + }); + return next; + }, [branchGroups]); + + return ( + <> + {messages.map((message) => { + const rootMessageId = message.branchRootId ?? message.id; + return ( + + ); + })} + + ); +}; + +const TurnList = React.memo( + TurnListInner, + (prevProps, nextProps) => + sameMessages(prevProps.messages, nextProps.messages) && + prevProps.branchGroups === nextProps.branchGroups && + prevProps.speakingMessageId === nextProps.speakingMessageId && + prevProps.speechState === nextProps.speechState && + prevProps.onSpeak === nextProps.onSpeak && + prevProps.onPauseSpeech === nextProps.onPauseSpeech && + prevProps.onResumeSpeech === nextProps.onResumeSpeech && + prevProps.onStopSpeech === nextProps.onStopSpeech && + prevProps.isTtsSupported === nextProps.isTtsSupported && + prevProps.onRegenerate === nextProps.onRegenerate && + prevProps.onEditResubmit === nextProps.onEditResubmit && + prevProps.onCycleBranch === nextProps.onCycleBranch, +); + +TurnList.displayName = "TurnList"; + const EmptyState = () => { const theme = useTheme(); const capabilities = [ @@ -182,37 +273,12 @@ export const AgentWorkspace = ({ const transitionMessages = branchTransition ? messages.slice(branchTransition.parentCount) : []; - - const renderTurn = (message: Message) => { - const rootMessageId = message.branchRootId ?? message.id; - const branchGroup = branchGroups.find( - (group) => group.rootMessageId === rootMessageId, - ); - - return ( - 1 - ? { - activeIndex: branchGroup.activeIndex, - total: branchGroup.branches.length, - } - : undefined - } - messageSpeechState={speakingMessageId === message.id ? speechState : "idle"} - onSpeak={onSpeak} - onPause={onPauseSpeech} - onResume={onResumeSpeech} - onStopSpeech={onStopSpeech} - isTtsSupported={isTtsSupported} - onRegenerate={onRegenerate} - onEditResubmit={onEditResubmit} - onCycleBranch={onCycleBranch} - /> - ); - }; + const streamingMessage = + !branchTransition && isStreaming && messages.at(-1)?.role === "assistant" + ? messages.at(-1) + : undefined; + const historyMessages = + streamingMessage !== undefined ? messages.slice(0, -1) : stableMessages; return ( 0 ? ( - {stableMessages.map(renderTurn)} + + + {streamingMessage ? ( + + ) : null} {branchTransition ? ( @@ -244,7 +340,20 @@ export const AgentWorkspace = ({ transition={{ duration: 0.18, ease: "easeOut" }} style={{ display: "flex", flexDirection: "column", gap: 16 }} > - {transitionMessages.map(renderTurn)} + ) : null}