diff --git a/packages/opencode/src/cli/cmd/tui/context/sync.tsx b/packages/opencode/src/cli/cmd/tui/context/sync.tsx index 269ed7ae0bd..e4734f57bf5 100644 --- a/packages/opencode/src/cli/cmd/tui/context/sync.tsx +++ b/packages/opencode/src/cli/cmd/tui/context/sync.tsx @@ -413,6 +413,18 @@ export const { use: useSync, provider: SyncProvider } = createSimpleContext({ sdk.client.provider.auth().then((x) => setStore("provider_auth", reconcile(x.data ?? {}))), sdk.client.vcs.get().then((x) => setStore("vcs", reconcile(x.data))), sdk.client.path.get().then((x) => setStore("path", reconcile(x.data!))), + sdk.client.question.list().then((x) => { + if (x.data) { + const grouped: Record = {} + for (const q of x.data) { + if (!grouped[q.sessionID]) grouped[q.sessionID] = [] + grouped[q.sessionID].push(q) + } + for (const [id, qs] of Object.entries(grouped)) { + setStore("question", id, reconcile(qs.sort((a, b) => a.id.localeCompare(b.id)))) + } + } + }), ]).then(() => { setStore("status", "complete") }) diff --git a/packages/opencode/src/question/index.ts b/packages/opencode/src/question/index.ts index c93b74b9a40..dab300b7737 100644 --- a/packages/opencode/src/question/index.ts +++ b/packages/opencode/src/question/index.ts @@ -3,6 +3,10 @@ import { BusEvent } from "@/bus/bus-event" import { Identifier } from "@/id/id" import { Instance } from "@/project/instance" import { Log } from "@/util/log" +import { Database, eq } from "@/storage/db" +import { PartTable, SessionTable } from "../session/session.sql" +import { MessageV2 } from "../session/message-v2" +import { Filesystem } from "../util/filesystem" import z from "zod" export namespace Question { @@ -85,15 +89,102 @@ export namespace Question { { info: Request resolve: (answers: Answer[]) => void - reject: (e: any) => void + reject: (e: Error) => void + recovered?: { partID: string } } > = {} return { pending, + didRecover: false, } }) + async function recover() { + const s = await state() + if (s.didRecover) return + s.didRecover = true + + const allParts = Database.use((db) => + db + .select({ + id: PartTable.id, + messageID: PartTable.message_id, + sessionID: PartTable.session_id, + directory: SessionTable.directory, + data: PartTable.data, + }) + .from(PartTable) + .innerJoin(SessionTable, eq(PartTable.session_id, SessionTable.id)) + .all(), + ) + + const candidates = allParts.filter((p) => { + const data = p.data as MessageV2.Part + if (data.type !== "tool" || data.tool !== "question") return false + const isRunning = data.state.status === "running" + const isAborted = data.state.status === "error" && data.state.error === "Tool execution aborted" + return isRunning || isAborted + }) + + const dir = Instance.directory + const recoverable = candidates.filter((q) => { + const id = q.id.replace("prt_", "que_") + if (s.pending[id]) return false + return Filesystem.overlaps(q.directory, dir) + }) + + if (recoverable.length === 0) return + + log.info("recovering questions", { count: recoverable.length }) + for (const row of recoverable) { + const data = row.data as MessageV2.ToolPart + const id = row.id.replace("prt_", "que_") + + const info: Request = { + id, + sessionID: row.sessionID, + questions: (data.state.input as { questions: Info[] }).questions, + tool: { + messageID: row.messageID, + callID: data.callID, + }, + } + s.pending[id] = { + info, + resolve: () => {}, + reject: () => {}, + recovered: { partID: row.id }, + } + + if (data.state.status === "error") { + data.state = { + status: "running", + input: data.state.input, + time: { + start: data.state.time.start, + }, + } + Database.use((db) => { + db.update(PartTable).set({ data }).where(eq(PartTable.id, row.id)).run() + Bus.publish(MessageV2.Event.PartUpdated, { + part: { ...data, id: row.id, messageID: row.messageID, sessionID: row.sessionID } satisfies MessageV2.Part, + }) + }) + } + + await Bus.publish(Event.Asked, info) + } + } + + export function format(questions: Info[], answers: Answer[]) { + const fmt = (answer: Answer | undefined) => { + if (!answer?.length) return "Unanswered" + return answer.join(", ") + } + return questions.map((q, i) => `"${q.question}"="${fmt(answers[i])}"`).join(", ") + } + export async function ask(input: { sessionID: string questions: Info[] @@ -137,6 +228,43 @@ export namespace Question { answers: input.answers, }) + if (existing.recovered && existing.info.tool) { + const partID = existing.recovered.partID + const list = existing.info.questions + const formatted = format(list, input.answers) + const output = `User has answered your questions: ${formatted}. You can now continue with the user's answers in mind.` + const title = `Asked ${list.length} question${list.length > 1 ? "s" : ""}` + + Database.use((db) => { + const row = db.select().from(PartTable).where(eq(PartTable.id, partID)).get() + if (!row) return + const data = row.data as MessageV2.ToolPart + const state = data.state + const start = state.status === "pending" ? Date.now() : state.time.start + data.state = { + status: "completed", + input: state.input, + output, + title, + metadata: { + answers: input.answers, + }, + time: { + start, + end: Date.now(), + }, + } + const part = { + ...data, + id: row.id, + messageID: row.message_id, + sessionID: row.session_id, + } satisfies MessageV2.Part + db.update(PartTable).set({ data }).where(eq(PartTable.id, row.id)).run() + Bus.publish(MessageV2.Event.PartUpdated, { part }) + }) + } + existing.resolve(input.answers) } @@ -151,11 +279,39 @@ export namespace Question { log.info("rejected", { requestID }) - Bus.publish(Event.Rejected, { + await Bus.publish(Event.Rejected, { sessionID: existing.info.sessionID, requestID: existing.info.id, }) + if (existing.recovered && existing.info.tool) { + const partID = existing.recovered.partID + Database.use((db) => { + const row = db.select().from(PartTable).where(eq(PartTable.id, partID)).get() + if (!row) return + const data = row.data as MessageV2.ToolPart + const state = data.state + const start = state.status === "pending" ? Date.now() : state.time.start + data.state = { + status: "error", + input: state.input, + error: "The user dismissed this question", + time: { + start, + end: Date.now(), + }, + } + const part = { + ...data, + id: row.id, + messageID: row.message_id, + sessionID: row.session_id, + } satisfies MessageV2.Part + db.update(PartTable).set({ data }).where(eq(PartTable.id, row.id)).run() + Bus.publish(MessageV2.Event.PartUpdated, { part }) + }) + } + existing.reject(new RejectedError()) } @@ -166,6 +322,7 @@ export namespace Question { } export async function list() { + await recover() return state().then((x) => Object.values(x.pending).map((x) => x.info)) } } diff --git a/packages/opencode/src/tool/question.ts b/packages/opencode/src/tool/question.ts index a2887546d4b..23005e971df 100644 --- a/packages/opencode/src/tool/question.ts +++ b/packages/opencode/src/tool/question.ts @@ -15,12 +15,7 @@ export const QuestionTool = Tool.define("question", { tool: ctx.callID ? { messageID: ctx.messageID, callID: ctx.callID } : undefined, }) - function format(answer: Question.Answer | undefined) { - if (!answer?.length) return "Unanswered" - return answer.join(", ") - } - - const formatted = params.questions.map((q, i) => `"${q.question}"="${format(answers[i])}"`).join(", ") + const formatted = Question.format(params.questions, answers) return { title: `Asked ${params.questions.length} question${params.questions.length > 1 ? "s" : ""}`, diff --git a/packages/opencode/test/question/recover.test.ts b/packages/opencode/test/question/recover.test.ts new file mode 100644 index 00000000000..05c8f348b45 --- /dev/null +++ b/packages/opencode/test/question/recover.test.ts @@ -0,0 +1,302 @@ +import { test, expect, afterEach } from "bun:test" +import { Question } from "../../src/question" +import { Instance } from "../../src/project/instance" +import { Database, eq } from "../../src/storage/db" +import { Identifier } from "../../src/id/id" +import { SessionTable, MessageTable, PartTable } from "../../src/session/session.sql" +import { tmpdir } from "../fixture/fixture" +import { resetDatabase } from "../fixture/db" +import type { MessageV2 } from "../../src/session/message-v2" + +afterEach(async () => { + await resetDatabase() +}) + +function seed(input: { projectID: string; sessionID: string; directory: string }) { + const now = Date.now() + const messageID = Identifier.ascending("message") + + Database.use((db) => { + db.insert(SessionTable) + .values({ + id: input.sessionID, + project_id: input.projectID, + slug: "test", + directory: input.directory, + title: "test session", + version: "2", + time_created: now, + time_updated: now, + }) + .run() + + db.insert(MessageTable) + .values({ + id: messageID, + session_id: input.sessionID, + time_created: now, + time_updated: now, + data: { + role: "assistant" as const, + time: { created: now }, + parentID: "msg_fake", + modelID: "test-model", + providerID: "test-provider", + mode: "default", + agent: "coder", + path: { cwd: input.directory, root: input.directory }, + cost: 0, + tokens: { input: 0, output: 0, reasoning: 0, cache: { read: 0, write: 0 } }, + } as typeof MessageTable.$inferInsert.data, + }) + .run() + }) + + return messageID +} + +function insertQuestionPart(input: { + sessionID: string + messageID: string + status: "running" | "aborted" | "completed" +}) { + const partID = Identifier.ascending("part") + const now = Date.now() + + const questions = [ + { + question: "Pick something?", + header: "Choice", + options: [ + { label: "A", description: "Option A" }, + { label: "B", description: "Option B" }, + ], + }, + ] + + const state = + input.status === "running" + ? { + status: "running" as const, + input: { questions }, + time: { start: now }, + } + : input.status === "completed" + ? { + status: "completed" as const, + input: { questions }, + output: "User answered: A", + title: "Asked 1 question", + metadata: { answers: [["A"]] }, + time: { start: now, end: now }, + } + : { + status: "error" as const, + input: { questions }, + error: "Tool execution aborted", + time: { start: now, end: now }, + } + + Database.use((db) => { + db.insert(PartTable) + .values({ + id: partID, + message_id: input.messageID, + session_id: input.sessionID, + time_created: now, + time_updated: now, + data: { + type: "tool", + callID: `call_${partID}`, + tool: "question", + state, + } as typeof PartTable.$inferInsert.data, + }) + .run() + }) + + return partID +} + +test("recover - finds running question parts and lists them", async () => { + await using tmp = await tmpdir({ git: true }) + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const sessionID = Identifier.ascending("session") + const messageID = seed({ + projectID: Instance.project.id, + sessionID, + directory: tmp.path, + }) + insertQuestionPart({ sessionID, messageID, status: "running" }) + + const pending = await Question.list() + expect(pending.length).toBe(1) + expect(pending[0].sessionID).toBe(sessionID) + expect(pending[0].questions[0].question).toBe("Pick something?") + }, + }) +}) + +test("recover - finds aborted question parts and resets to running", async () => { + await using tmp = await tmpdir({ git: true }) + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const sessionID = Identifier.ascending("session") + const messageID = seed({ + projectID: Instance.project.id, + sessionID, + directory: tmp.path, + }) + const partID = insertQuestionPart({ sessionID, messageID, status: "aborted" }) + + const pending = await Question.list() + expect(pending.length).toBe(1) + + const row = Database.use((db) => db.select().from(PartTable).where(eq(PartTable.id, partID)).get()) + const data = row!.data as MessageV2.ToolPart + expect(data.state.status).toBe("running") + }, + }) +}) + +test("recover - runs only once per instance", async () => { + await using tmp = await tmpdir({ git: true }) + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const sessionID = Identifier.ascending("session") + const messageID = seed({ + projectID: Instance.project.id, + sessionID, + directory: tmp.path, + }) + insertQuestionPart({ sessionID, messageID, status: "running" }) + + const first = await Question.list() + expect(first.length).toBe(1) + + // Insert another part after recovery already ran + insertQuestionPart({ sessionID, messageID, status: "running" }) + + const second = await Question.list() + // Should still be 1 because recover() doesn't re-run + expect(second.length).toBe(1) + }, + }) +}) + +test("recover - ignores parts from different directories", async () => { + await using tmp = await tmpdir({ git: true }) + await using other = await tmpdir({ git: true }) + + // Seed a session in the "other" directory + await Instance.provide({ + directory: other.path, + fn: async () => { + const sessionID = Identifier.ascending("session") + const messageID = seed({ + projectID: Instance.project.id, + sessionID, + directory: other.path, + }) + insertQuestionPart({ sessionID, messageID, status: "running" }) + }, + }) + + // Now list from "tmp" — should not see the other directory's question + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const pending = await Question.list() + expect(pending.length).toBe(0) + }, + }) +}) + +test("recover - reply updates DB part to completed", async () => { + await using tmp = await tmpdir({ git: true }) + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const sessionID = Identifier.ascending("session") + const messageID = seed({ + projectID: Instance.project.id, + sessionID, + directory: tmp.path, + }) + const partID = insertQuestionPart({ sessionID, messageID, status: "running" }) + + const pending = await Question.list() + expect(pending.length).toBe(1) + + await Question.reply({ + requestID: pending[0].id, + answers: [["A"]], + }) + + const row = Database.use((db) => db.select().from(PartTable).where(eq(PartTable.id, partID)).get()) + const data = row!.data as MessageV2.ToolPart + expect(data.state.status).toBe("completed") + if (data.state.status === "completed") { + expect(data.state.output).toContain("A") + expect(data.state.metadata.answers).toEqual([["A"]]) + } + + const after = await Question.list() + expect(after.length).toBe(0) + }, + }) +}) + +test("recover - reject updates DB part to error", async () => { + await using tmp = await tmpdir({ git: true }) + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const sessionID = Identifier.ascending("session") + const messageID = seed({ + projectID: Instance.project.id, + sessionID, + directory: tmp.path, + }) + const partID = insertQuestionPart({ sessionID, messageID, status: "running" }) + + const pending = await Question.list() + expect(pending.length).toBe(1) + + await Question.reject(pending[0].id) + + const row = Database.use((db) => db.select().from(PartTable).where(eq(PartTable.id, partID)).get()) + const data = row!.data as MessageV2.ToolPart + expect(data.state.status).toBe("error") + if (data.state.status === "error") { + expect(data.state.error).toBe("The user dismissed this question") + } + + const after = await Question.list() + expect(after.length).toBe(0) + }, + }) +}) + +test("recover - ignores already completed question parts", async () => { + await using tmp = await tmpdir({ git: true }) + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const sessionID = Identifier.ascending("session") + const messageID = seed({ + projectID: Instance.project.id, + sessionID, + directory: tmp.path, + }) + insertQuestionPart({ sessionID, messageID, status: "completed" }) + + const pending = await Question.list() + expect(pending.length).toBe(0) + }, + }) +})