Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions packages/opencode/src/cli/cmd/tui/context/sync.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,18 @@ export const { use: useSync, provider: SyncProvider } = createSimpleContext({
sdk.client.provider.auth().then((x) => setStore("provider_auth", reconcile(x.data ?? {}))),
sdk.client.vcs.get().then((x) => setStore("vcs", reconcile(x.data))),
sdk.client.path.get().then((x) => setStore("path", reconcile(x.data!))),
sdk.client.question.list().then((x) => {
if (x.data) {
const grouped: Record<string, QuestionRequest[]> = {}
for (const q of x.data) {
if (!grouped[q.sessionID]) grouped[q.sessionID] = []
grouped[q.sessionID].push(q)
}
for (const [id, qs] of Object.entries(grouped)) {
setStore("question", id, reconcile(qs.sort((a, b) => a.id.localeCompare(b.id))))
}
}
}),
]).then(() => {
setStore("status", "complete")
})
Expand Down
161 changes: 159 additions & 2 deletions packages/opencode/src/question/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ import { BusEvent } from "@/bus/bus-event"
import { Identifier } from "@/id/id"
import { Instance } from "@/project/instance"
import { Log } from "@/util/log"
import { Database, eq } from "@/storage/db"
import { PartTable, SessionTable } from "../session/session.sql"
import { MessageV2 } from "../session/message-v2"
import { Filesystem } from "../util/filesystem"
import z from "zod"

export namespace Question {
Expand Down Expand Up @@ -85,15 +89,102 @@ export namespace Question {
{
info: Request
resolve: (answers: Answer[]) => void
reject: (e: any) => void
reject: (e: Error) => void
recovered?: { partID: string }
}
> = {}

return {
pending,
didRecover: false,
}
})

async function recover() {
const s = await state()
if (s.didRecover) return
s.didRecover = true

const allParts = Database.use((db) =>
db
.select({
id: PartTable.id,
messageID: PartTable.message_id,
sessionID: PartTable.session_id,
directory: SessionTable.directory,
data: PartTable.data,
})
.from(PartTable)
.innerJoin(SessionTable, eq(PartTable.session_id, SessionTable.id))
.all(),
)

const candidates = allParts.filter((p) => {
const data = p.data as MessageV2.Part
if (data.type !== "tool" || data.tool !== "question") return false
const isRunning = data.state.status === "running"
const isAborted = data.state.status === "error" && data.state.error === "Tool execution aborted"
return isRunning || isAborted
})

const dir = Instance.directory
const recoverable = candidates.filter((q) => {
const id = q.id.replace("prt_", "que_")
if (s.pending[id]) return false
return Filesystem.overlaps(q.directory, dir)
})

if (recoverable.length === 0) return

log.info("recovering questions", { count: recoverable.length })
for (const row of recoverable) {
const data = row.data as MessageV2.ToolPart
const id = row.id.replace("prt_", "que_")

const info: Request = {
id,
sessionID: row.sessionID,
questions: (data.state.input as { questions: Info[] }).questions,
tool: {
messageID: row.messageID,
callID: data.callID,
},
}
s.pending[id] = {
info,
resolve: () => {},
reject: () => {},
recovered: { partID: row.id },
}

if (data.state.status === "error") {
data.state = {
status: "running",
input: data.state.input,
time: {
start: data.state.time.start,
},
}
Database.use((db) => {
db.update(PartTable).set({ data }).where(eq(PartTable.id, row.id)).run()
Bus.publish(MessageV2.Event.PartUpdated, {
part: { ...data, id: row.id, messageID: row.messageID, sessionID: row.sessionID } satisfies MessageV2.Part,
})
})
}

await Bus.publish(Event.Asked, info)
}
}

export function format(questions: Info[], answers: Answer[]) {
const fmt = (answer: Answer | undefined) => {
if (!answer?.length) return "Unanswered"
return answer.join(", ")
}
return questions.map((q, i) => `"${q.question}"="${fmt(answers[i])}"`).join(", ")
}

export async function ask(input: {
sessionID: string
questions: Info[]
Expand Down Expand Up @@ -137,6 +228,43 @@ export namespace Question {
answers: input.answers,
})

if (existing.recovered && existing.info.tool) {
const partID = existing.recovered.partID
const list = existing.info.questions
const formatted = format(list, input.answers)
const output = `User has answered your questions: ${formatted}. You can now continue with the user's answers in mind.`
const title = `Asked ${list.length} question${list.length > 1 ? "s" : ""}`

Database.use((db) => {
const row = db.select().from(PartTable).where(eq(PartTable.id, partID)).get()
if (!row) return
const data = row.data as MessageV2.ToolPart
const state = data.state
const start = state.status === "pending" ? Date.now() : state.time.start
data.state = {
status: "completed",
input: state.input,
output,
title,
metadata: {
answers: input.answers,
},
time: {
start,
end: Date.now(),
},
}
const part = {
...data,
id: row.id,
messageID: row.message_id,
sessionID: row.session_id,
} satisfies MessageV2.Part
db.update(PartTable).set({ data }).where(eq(PartTable.id, row.id)).run()
Bus.publish(MessageV2.Event.PartUpdated, { part })
})
}

existing.resolve(input.answers)
}

Expand All @@ -151,11 +279,39 @@ export namespace Question {

log.info("rejected", { requestID })

Bus.publish(Event.Rejected, {
await Bus.publish(Event.Rejected, {
sessionID: existing.info.sessionID,
requestID: existing.info.id,
})

if (existing.recovered && existing.info.tool) {
const partID = existing.recovered.partID
Database.use((db) => {
const row = db.select().from(PartTable).where(eq(PartTable.id, partID)).get()
if (!row) return
const data = row.data as MessageV2.ToolPart
const state = data.state
const start = state.status === "pending" ? Date.now() : state.time.start
data.state = {
status: "error",
input: state.input,
error: "The user dismissed this question",
time: {
start,
end: Date.now(),
},
}
const part = {
...data,
id: row.id,
messageID: row.message_id,
sessionID: row.session_id,
} satisfies MessageV2.Part
db.update(PartTable).set({ data }).where(eq(PartTable.id, row.id)).run()
Bus.publish(MessageV2.Event.PartUpdated, { part })
})
}

existing.reject(new RejectedError())
}

Expand All @@ -166,6 +322,7 @@ export namespace Question {
}

export async function list() {
await recover()
return state().then((x) => Object.values(x.pending).map((x) => x.info))
}
}
7 changes: 1 addition & 6 deletions packages/opencode/src/tool/question.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,7 @@ export const QuestionTool = Tool.define("question", {
tool: ctx.callID ? { messageID: ctx.messageID, callID: ctx.callID } : undefined,
})

function format(answer: Question.Answer | undefined) {
if (!answer?.length) return "Unanswered"
return answer.join(", ")
}

const formatted = params.questions.map((q, i) => `"${q.question}"="${format(answers[i])}"`).join(", ")
const formatted = Question.format(params.questions, answers)

return {
title: `Asked ${params.questions.length} question${params.questions.length > 1 ? "s" : ""}`,
Expand Down
Loading
Loading