diff --git a/lua/avante/ui/acp_confirm_adapter.lua b/lua/avante/acp/acp_confirm_adapter.lua similarity index 100% rename from lua/avante/ui/acp_confirm_adapter.lua rename to lua/avante/acp/acp_confirm_adapter.lua diff --git a/lua/avante/acp/acp_diff_handler.lua b/lua/avante/acp/acp_diff_handler.lua new file mode 100644 index 000000000..1f449e26f --- /dev/null +++ b/lua/avante/acp/acp_diff_handler.lua @@ -0,0 +1,370 @@ +local P = {} + +---@class avante.ACPDiffHandler +local M = {} + +local Utils = require("avante.utils") +local Config = require("avante.config") + +---ACP handler to check if tool call contains diff content and display them in the buffer +---@param tool_call avante.acp.ToolCallUpdate +---@return boolean has_diff +function M.has_diff_content(tool_call) + for _, content_item in ipairs(tool_call.content or {}) do + if content_item.type == "diff" then return true end + end + + local raw = tool_call.rawInput + if not raw then return false end + + local has_new = (raw.new_string ~= nil and raw.new_string ~= vim.NIL) + return has_new +end + +--- Extract diff blocks from ACP tool call content +--- +--- IMPORTANT ASSUMPTION: rawInput and content always reference the same file(s). +--- If rawInput exists with a file path, the content array will reference the same file(s). +--- This means we can safely skip processing the content array when rawInput.replace_all=true, +--- as they represent the same operation on the same file(s). +--- +--- @param tool_call avante.acp.ToolCallUpdate +--- @return table diff_blocks_by_file Maps file path to list of diff blocks +function M.extract_diff_blocks(tool_call) + --- @type table + local diff_blocks_by_file = {} + + -- PRIORITY: If rawInput exists with replace_all=true, process it even if content exists, + -- because the content array cannot express replace_all semantics. + local raw = tool_call.rawInput + local should_use_raw_input = raw and raw.replace_all == true + -- Note: rawInput and content array reference the same file(s), so skipping content array is safe. + + -- `content` doesn't support replace_all semantics, it could generate false-positives when replacing the same string multiple times. + if not should_use_raw_input then + -- Handle content array (standard) + for _, content_item in ipairs(tool_call.content or {}) do + if content_item.type == "diff" then + local path = content_item.path + local oldText = content_item.oldText + local newText = content_item.newText + + if oldText == "" or oldText == vim.NIL or oldText == nil then + -- New file case + local new_lines = P._normalize_text_to_lines(newText) + local diff_block = P._create_new_file_diff_block(new_lines) + P._add_diff_block(diff_blocks_by_file, path, diff_block) + else + -- Existing file case + local old_lines = P._normalize_text_to_lines(oldText) + local new_lines = P._normalize_text_to_lines(newText) + + local abs_path = Utils.to_absolute_path(path) + local file_lines = Utils.read_file_from_buf_or_disk(abs_path) or {} + local start_line, end_line = Utils.fuzzy_match(file_lines, old_lines) + + if not start_line or not end_line then + -- Fallback: if oldText is a single word/line and exact match failed, + -- try substring matching within lines (but NOT replace_all - that requires rawInput) + -- This handles cases where the text is part of a longer line + -- NOTE: content array represents a SINGLE replacement, not replace_all + if #old_lines == 1 and #new_lines == 1 then + local search_text = old_lines[1] + local replace_text = new_lines[1] + local found_blocks = P._find_substring_replacements(file_lines, search_text, replace_text, false) + + if #found_blocks > 0 then + for _, block in ipairs(found_blocks) do + P._add_diff_block(diff_blocks_by_file, path, block) + end + else + Utils.debug( + "[ACP diff content] Failed to find location for diff in file (tried substring matching): ", + { + path = path, + oldText = oldText, + newText = newText, + i = _, + content_item = content_item, + tool_call = tool_call, + } + ) + end + else + Utils.debug("[ACP diff content] Failed to find location for diff in file: ", { + path = path, + oldText = oldText, + newText = newText, + i = _, + content_item = content_item, + tool_call = tool_call, + }) + end + else + local diff_block = { + start_line = start_line, + end_line = end_line, + old_lines = old_lines, + new_lines = new_lines, + } + P._add_diff_block(diff_blocks_by_file, path, diff_block) + end + end + end + end + end + + local has_diff_blocks = not P._is_table_empty(diff_blocks_by_file) + + -- Use rawInput if no diff blocks found from content array OR replace_all is true + if raw and (should_use_raw_input or not has_diff_blocks) then + Utils.debug("[ACP diff] Processing rawInput", { + tool_call = tool_call, + reason = raw.replace_all and "replace_all semantics" or "fallback", + }) + + local file_path = raw.file_path + local old_string = raw.old_string == vim.NIL and nil or raw.old_string + local new_string = raw.new_string == vim.NIL and nil or raw.new_string + + if file_path and new_string then + local old_lines = P._normalize_text_to_lines(old_string) + local new_lines = P._normalize_text_to_lines(new_string) + + local abs_path = Utils.to_absolute_path(file_path) + local file_lines = Utils.read_file_from_buf_or_disk(abs_path) or {} + + if #old_lines == 0 or (#old_lines == 1 and old_lines[1] == "") then + -- New file case + local diff_block = P._create_new_file_diff_block(new_lines) + diff_blocks_by_file[file_path] = { diff_block } + else + local replace_all = raw.replace_all + + if replace_all then + if #old_lines == 1 and #new_lines == 1 then + local search_text = old_lines[1] + local replace_text = new_lines[1] + local found_blocks = P._find_substring_replacements(file_lines, search_text, replace_text, true) + + if #found_blocks > 0 then + diff_blocks_by_file[file_path] = found_blocks + else + Utils.debug("[ACP diff rawInput] [replace_all] Failed to find substring", { + file_path = file_path, + old_string = old_string, + new_string = new_string, + raw = raw, + }) + end + else + -- Multi-line replace_all: use line matching + local matches = Utils.find_all_matches(file_lines, old_lines) + + if #matches == 0 then + Utils.debug("[ACP diff rawInput] [replace_all] Failed to find any matches for replace_all in file: ", { + file_path = file_path, + old_string = old_string, + new_string = new_string, + raw = raw, + }) + else + diff_blocks_by_file[file_path] = {} + + for _, match in ipairs(matches) do + P._add_diff_block(diff_blocks_by_file, file_path, { + start_line = match.start_line, + end_line = match.end_line, + old_lines = old_lines, + new_lines = new_lines, + }) + end + end + end + else + local start_line, end_line = Utils.fuzzy_match(file_lines, old_lines) + + if not start_line or not end_line then + -- Fallback: try substring replacement for single-line case + if #old_lines == 1 and #new_lines == 1 then + local search_text = old_lines[1] + local replace_text = new_lines[1] + local found_blocks = P._find_substring_replacements(file_lines, search_text, replace_text, false) + + if #found_blocks > 0 then + diff_blocks_by_file[file_path] = found_blocks + else + Utils.debug("[ACP diff rawInput] Failed to find location for diff in file: ", { + file_path = file_path, + old_string = old_string, + new_string = new_string, + raw = raw, + }) + end + else + Utils.debug("[ACP diff rawInput] Failed to find location for diff in file: ", { + file_path = file_path, + old_string = old_string, + new_string = new_string, + raw = raw, + }) + end + else + local diff_block = { + start_line = start_line, + end_line = end_line, + old_lines = old_lines, + new_lines = new_lines, + } + diff_blocks_by_file[file_path] = { diff_block } + end + end + end + end + end + + for path, diff_blocks in pairs(diff_blocks_by_file) do + -- Sort by start_line to handle multiple diffs correctly + table.sort(diff_blocks, function(a, b) return a.start_line < b.start_line end) + + -- Apply minimize_diff if enabled + if Config.behaviour.minimize_diff then + diff_blocks = P.minimize_diff_blocks(diff_blocks) + diff_blocks_by_file[path] = diff_blocks + end + end + + if P._is_table_empty(diff_blocks_by_file) then + Utils.debug("[ACP diff] No diff blocks extracted from tool call", { + tool_call = tool_call, + }) + end + + return diff_blocks_by_file +end + +---Minimize diff blocks by removing unchanged lines (similar to replace_in_file.lua) +---@param diff_blocks avante.DiffBlock[] +---@return avante.DiffBlock[] +function P.minimize_diff_blocks(diff_blocks) + local minimized = {} + for _, diff_block in ipairs(diff_blocks) do + local old_string = table.concat(diff_block.old_lines, "\n") + local new_string = table.concat(diff_block.new_lines, "\n") + + ---@type integer[][] + ---@diagnostic disable-next-line: assign-type-mismatch + local patch = vim.diff(old_string, new_string, { + algorithm = "histogram", + result_type = "indices", + ctxlen = 0, + }) + + if #patch > 0 then + for _, hunk in ipairs(patch) do + local start_a, count_a, start_b, count_b = unpack(hunk) + local minimized_block = {} + if count_a > 0 then + local end_a = math.min(start_a + count_a - 1, #diff_block.old_lines) + minimized_block.old_lines = vim.list_slice(diff_block.old_lines, start_a, end_a) + else + minimized_block.old_lines = {} + end + if count_b > 0 then + local end_b = math.min(start_b + count_b - 1, #diff_block.new_lines) + minimized_block.new_lines = vim.list_slice(diff_block.new_lines, start_b, end_b) + else + minimized_block.new_lines = {} + end + if count_a > 0 then + minimized_block.start_line = diff_block.start_line + start_a - 1 + minimized_block.end_line = minimized_block.start_line + count_a - 1 + else + -- For insertions, start_line is the position before which to insert + minimized_block.start_line = diff_block.start_line + start_a + minimized_block.end_line = minimized_block.start_line - 1 + end + table.insert(minimized, minimized_block) + end + end + end + + table.sort(minimized, function(a, b) return a.start_line < b.start_line end) + + return minimized +end + +---Create a diff block for a new file +---@param new_lines string[] +---@return avante.DiffBlock +function P._create_new_file_diff_block(new_lines) + return { + start_line = 1, + end_line = 0, + old_lines = {}, + new_lines = new_lines, + } +end + +---Normalize text to lines array, handling nil and vim.NIL +---@param text string|nil +---@return string[] +function P._normalize_text_to_lines(text) + if not text or text == vim.NIL or text == "" then return {} end + return type(text) == "string" and vim.split(text, "\n") or {} +end + +---Add a diff block to the collection, ensuring the path array exists +---@param diff_blocks_by_file table +---@param path string +---@param diff_block avante.DiffBlock +function P._add_diff_block(diff_blocks_by_file, path, diff_block) + diff_blocks_by_file[path] = diff_blocks_by_file[path] or {} + table.insert(diff_blocks_by_file[path], diff_block) +end + +---Find and replace substring occurrences in file lines +---@param file_lines string[] File content lines +---@param search_text string Text to search for +---@param replace_text string Text to replace with +---@param replace_all boolean If true, replace all occurrences; if false, only first match +---@return avante.DiffBlock[] Array of diff blocks created +function P._find_substring_replacements(file_lines, search_text, replace_text, replace_all) + local diff_blocks = {} + local escaped_search = search_text:gsub("[%(%)%.%%%+%-%*%?%[%]%^%$]", "%%%1") + + for line_idx, line_content in ipairs(file_lines) do + if line_content:find(search_text, 1, true) then + local modified_line + if replace_all then + -- Replace all occurrences in this line + -- Use function replacement to avoid pattern interpretation of replace_text + -- This ensures literal replacement (e.g., "result%1" stays as "result%1", not backreference) + modified_line = line_content:gsub(escaped_search, function() return replace_text end) + else + -- Replace first occurrence only + -- Use function replacement to ensure literal text (no pattern interpretation) + modified_line = line_content:gsub(escaped_search, function() return replace_text end, 1) + end + + table.insert(diff_blocks, { + start_line = line_idx, + end_line = line_idx, + old_lines = { line_content }, + new_lines = { modified_line }, + }) + + -- For single replacement mode, stop after first match + if not replace_all then break end + end + end + + return diff_blocks +end + +---Check if a table is empty (has no keys) +---@param tbl table +---@return boolean +function P._is_table_empty(tbl) return next(tbl) == nil end + +return M diff --git a/lua/avante/acp/acp_diff_preview.lua b/lua/avante/acp/acp_diff_preview.lua new file mode 100644 index 000000000..99a62eb14 --- /dev/null +++ b/lua/avante/acp/acp_diff_preview.lua @@ -0,0 +1,91 @@ +local api = vim.api + +local Utils = require("avante.utils") +local Config = require("avante.config") +local DiffDisplay = require("avante.utils.diff_display") +local ACPDiffHandler = require("avante.acp.acp_diff_handler") +local LLMToolHelpers = require("avante.llm_tools.helpers") + +---@class avante.ACPDiffPreviewState +---@field bufnr integer +---@field path string +---@field lines string[] Original buffer lines +---@field changedtick integer Original changedtick +---@field modified boolean Original modified flag +---@field modifiable boolean Original modifiable flag +---@field diff_display avante.DiffDisplayInstance + +---@class avante.ACPDiffPreviewOpts +---@field tool_call avante.acp.ToolCallUpdate The ACP tool call containing diff content +---@field session_ctx? table Session context (for auto-approval checks) + +---@class avante.ui.acp_diff_preview +local M = {} + +---Show diff preview for ACP tool call +---Returns a cleanup function that is safe to call in all cases (accept/reject/disabled) +---@param opts avante.ACPDiffPreviewOpts +---@return fun() cleanup Cleanup function - safe to call multiple times +function M.show_acp_diff(opts) + local should_skip = not Config.behaviour.acp_show_diff_in_buffer + or LLMToolHelpers.is_auto_approved(opts.session_ctx, opts.tool_call.kind) + or not ACPDiffHandler.has_diff_content(opts.tool_call) + + if should_skip then + return function() end + end + + local diffs = ACPDiffHandler.extract_diff_blocks(opts.tool_call) + + ---@type avante.ACPDiffPreviewState[] + local preview_states = {} + + for path, diff_blocks in pairs(diffs) do + local abs_path = Utils.to_absolute_path(path) + local bufnr = vim.fn.bufnr(abs_path) + if bufnr == -1 then bufnr = vim.fn.bufnr(abs_path, true) end + + local diff_display = DiffDisplay.new({ + bufnr = bufnr, + diff_blocks = diff_blocks, + }) + + local ok_changedtick, changedtick = pcall(function() return vim.b[bufnr].changedtick end) + + local state = { + bufnr = bufnr, + path = path, + lines = api.nvim_buf_get_lines(bufnr, 0, -1, false), + changedtick = ok_changedtick and changedtick or 0, + modified = vim.bo[bufnr].modified, + modifiable = vim.bo[bufnr].modifiable, + diff_display = diff_display, + } + + diff_display:highlight() + diff_display:scroll_to_first_diff() + diff_display:register_cursor_move_events() + diff_display:register_navigation_keybindings() + + vim.bo[bufnr].modifiable = false + + table.insert(preview_states, state) + end + + -- Cleanup function to clear diff display and restore buffer flags + return function() + if not preview_states or #preview_states == 0 then return end + + for _, state in ipairs(preview_states) do + if state.diff_display then state.diff_display:clear() end + + -- Restore buffer flags if buffer is still valid + if api.nvim_buf_is_valid(state.bufnr) then vim.bo[state.bufnr].modifiable = state.modifiable end + end + + -- Clear references to help garbage collection + preview_states = {} + end +end + +return M diff --git a/lua/avante/config.lua b/lua/avante/config.lua index aa4336337..ef1b0e99f 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -552,9 +552,7 @@ M._defaults = { --- inline_buttons is the new inline buttons in the sidebar ---@type "popup" | "inline_buttons" confirmation_ui_style = "inline_buttons", - --- Whether to automatically open files and navigate to lines when ACP agent makes edits - ---@type boolean - acp_follow_agent_locations = true, + acp_show_diff_in_buffer = true, }, prompt_logger = { -- logs prompts to disk (timestamped, for replay/debugging) enabled = true, -- toggle logging entirely diff --git a/lua/avante/highlights.lua b/lua/avante/highlights.lua index b061f6c3d..b3dc59567 100644 --- a/lua/avante/highlights.lua +++ b/lua/avante/highlights.lua @@ -75,6 +75,11 @@ Highlights.conflict = { INCOMING_LABEL = { name = "AvanteConflictIncomingLabel", shade_link = "AvanteConflictIncoming", shade = 30 }, } +Highlights.DIFF_INCOMING = { name = "AvanteDiffIncoming", bg = "#2d4a2c", bold = true } -- Green for additions +Highlights.DIFF_INCOMING_WORD = { name = "AvanteDiffIncomingWord", bg = "#0d7a4d", bold = true } -- Much darker/brighter green for changed words +Highlights.DIFF_DELETED = { name = "AvanteDiffDeleted", bg = "#562C30" } -- Red for deletions (virtual text) +Highlights.DIFF_DELETED_WORD = { name = "AvanteDiffDeletedWord", bg = "#9a3c3c", bold = true } -- Darker/brighter red for deleted words + --- helper local H = {} diff --git a/lua/avante/history/helpers.lua b/lua/avante/history/helpers.lua index 48aaa00a4..c5c3f49ec 100644 --- a/lua/avante/history/helpers.lua +++ b/lua/avante/history/helpers.lua @@ -1,5 +1,3 @@ -local Utils = require("avante.utils") - local M = {} ---If message is a text message return the text. diff --git a/lua/avante/history/message.lua b/lua/avante/history/message.lua index eea5c79b9..c6340111c 100644 --- a/lua/avante/history/message.lua +++ b/lua/avante/history/message.lua @@ -1,6 +1,12 @@ local Utils = require("avante.utils") ----@class avante.HistoryMessage +---@class avante.history.Message +---@field message AvanteLLMMessage +---@field uuid string +---@field state string +---@field timestamp number +---@field is_user_submission boolean +---@field visible boolean local M = {} M.__index = M @@ -21,7 +27,7 @@ M.__index = M ---@param role "user" | "assistant" ---@param content AvanteLLMMessageContentItem ---@param opts? avante.HistoryMessage.Opts ----@return avante.HistoryMessage +---@return avante.history.Message function M:new(role, content, opts) ---@type AvanteLLMMessage local message = { role = role, content = type(content) == "string" and content or { content } } @@ -40,17 +46,17 @@ end ---Creates a new instance of synthetic (dummy) history message ---@param role "assistant" | "user" ---@param item AvanteLLMMessageContentItem ----@return avante.HistoryMessage +---@return avante.history.Message function M:new_synthetic(role, item) return M:new(role, item, { is_dummy = true }) end ---Creates a new instance of synthetic (dummy) history message attributed to the assistant ---@param item AvanteLLMMessageContentItem ----@return avante.HistoryMessage +---@return avante.history.Message function M:new_assistant_synthetic(item) return M:new_synthetic("assistant", item) end ---Creates a new instance of synthetic (dummy) history message attributed to the user ---@param item AvanteLLMMessageContentItem ----@return avante.HistoryMessage +---@return avante.history.Message function M:new_user_synthetic(item) return M:new_synthetic("user", item) end ---Updates content of a message as long as it is a simple text (or empty). diff --git a/lua/avante/history/render.lua b/lua/avante/history/render.lua index 89615f630..aebac7a01 100644 --- a/lua/avante/history/render.lua +++ b/lua/avante/history/render.lua @@ -138,7 +138,7 @@ function M.get_diff_lines(old_str, new_str, decoration, truncate) local patch = vim.diff(old_str, new_str, { ---@type integer[][] algorithm = "histogram", result_type = "indices", - ctxlen = vim.o.scrolloff, + ctxlen = 0, }) local prev_start_a = 0 local truncated_lines = 0 @@ -314,20 +314,10 @@ function M.get_content_lines(content, decoration, truncate) table.insert(lines, line) end end - elseif - content_item.type == "diff" - and content_item.oldText ~= nil - and content_item.newText ~= nil - and content_item.oldText ~= vim.NIL - and content_item.newText ~= vim.NIL - then - local relative_path = Utils.relative_path(content_item.path) - table.insert(lines, Line:new({ { decoration }, { "Path: " .. relative_path } })) - local lines_ = M.get_diff_lines(content_item.oldText, content_item.newText, decoration, truncate) - lines = vim.list_extend(lines, lines_) end end end + return lines end @@ -342,16 +332,17 @@ function M.get_tool_display_name(message) if not islist(content) then return "", "expected message content to be a list" end - local item = message.message.content[1] + local item = content[1] + + local native_tool_name = item and item.name - local native_tool_name = item.name if native_tool_name == "other" and message.acp_tool_call then - native_tool_name = message.acp_tool_call.title or "Other" + native_tool_name = message.acp_tool_call.kind or "other" end - if message.acp_tool_call and message.acp_tool_call.title then native_tool_name = message.acp_tool_call.title end - local tool_name = native_tool_name + + local tool_name if message.displayed_tool_name then - tool_name = message.displayed_tool_name + tool_name = message.displayed_tool_name or "" else local param if item.input and type(item.input) == "table" then @@ -361,20 +352,24 @@ function M.get_tool_display_name(message) if type(item.input.filepath) == "string" then path = item.input.filepath end if type(item.input.file_path) == "string" then path = item.input.file_path end if type(item.input.query) == "string" then param = item.input.query end + if type(item.input.url) == "string" then param = item.input.url end if type(item.input.pattern) == "string" then param = item.input.pattern end if type(item.input.command) == "string" then param = item.input.command local pieces = vim.split(param, "\n") if #pieces > 1 then param = pieces[1] .. "..." end end + if native_tool_name == "execute" and not param then if message.acp_tool_call and message.acp_tool_call.title then param = message.acp_tool_call.title end end + if not param and path then local relative_path = Utils.relative_path(path) param = relative_path end end + if not param and message.acp_tool_call then if message.acp_tool_call.locations then for _, location in ipairs(message.acp_tool_call.locations) do @@ -386,22 +381,24 @@ function M.get_tool_display_name(message) end end end + if not param and message.acp_tool_call and message.acp_tool_call.rawInput and message.acp_tool_call.rawInput.command then - param = message.acp_tool_call.rawInput.command + param = message.acp_tool_call.rawInput.command or "" pcall(function() local project_root = Utils.root.get() param = param:gsub(project_root .. "/?", "") end) end - if param then tool_name = native_tool_name .. "(" .. vim.inspect(param) .. ")" end - end - ---@cast tool_name string + ---@diagnostic disable-next-line: param-type-mismatch + param = type(param) == "string" and param or table.concat(param or {}, " ") + tool_name = native_tool_name .. "(" .. param .. ")" + end return tool_name, nil end @@ -477,11 +474,11 @@ local function tool_to_lines(item, message, messages, expanded) not add_diff_lines and message.acp_tool_call and message.acp_tool_call.rawInput - and message.acp_tool_call.rawInput.oldString + and message.acp_tool_call.rawInput.old_string then local diff_lines = M.get_diff_lines( - message.acp_tool_call.rawInput.oldString, - message.acp_tool_call.rawInput.newString, + message.acp_tool_call.rawInput.old_string, + message.acp_tool_call.rawInput.new_string, decoration, not expanded ) diff --git a/lua/avante/libs/acp_client.lua b/lua/avante/libs/acp_client.lua index f842c5bd4..8bbe43a38 100644 --- a/lua/avante/libs/acp_client.lua +++ b/lua/avante/libs/acp_client.lua @@ -3,6 +3,7 @@ local Utils = require("avante.utils") ---@class avante.acp.ClientCapabilities ---@field fs avante.acp.FileSystemCapability +---@field terminal boolean ---@class avante.acp.FileSystemCapability ---@field readTextFile boolean @@ -87,15 +88,16 @@ local Utils = require("avante.utils") ---@alias ACPContent avante.acp.TextContent | avante.acp.ImageContent | avante.acp.AudioContent | avante.acp.ResourceLinkContent | avante.acp.ResourceContent ----@class avante.acp.ToolCall ----@field toolCallId string ----@field title string ----@field kind ACPToolKind ----@field status ACPToolCallStatus ----@field content ACPToolCallContent[] ----@field locations avante.acp.ToolCallLocation[] ----@field rawInput table ----@field rawOutput table +---@class ACPRawInput +---@field file_path string +---@field new_string? string +---@field old_string? string +---@field replace_all? boolean +---@field description? string +---@field command? string +---@field url? string Usually from the fetch tool +---@field query? string Usually from the web_search tool +---@field timeout? number ---@class avante.acp.BaseToolCallContent ---@field type "content" | "diff" @@ -107,7 +109,7 @@ local Utils = require("avante.utils") ---@class avante.acp.ToolCallDiffContent : avante.acp.BaseToolCallContent ---@field type "diff" ---@field path string ----@field oldText string|nil +---@field oldText string ---@field newText string ---@alias ACPToolCallContent avante.acp.ToolCallRegularContent | avante.acp.ToolCallDiffContent @@ -144,16 +146,16 @@ local Utils = require("avante.utils") ---@field sessionUpdate "agent_thought_chunk" ---@field content ACPContent ----@class avante.acp.ToolCallUpdate : avante.acp.BaseSessionUpdate ----@field sessionUpdate "tool_call" | "tool_call_update" +---@class avante.acp.ToolCallUpdate +---@field sessionUpdate? "tool_call" | "tool_call_update" ---@field toolCallId string ----@field title string|nil ----@field kind ACPToolKind|nil ----@field status ACPToolCallStatus|nil ----@field content ACPToolCallContent[]|nil ----@field locations avante.acp.ToolCallLocation[]|nil ----@field rawInput table|nil ----@field rawOutput table|nil +---@field title? string +---@field kind? ACPToolKind +---@field status? ACPToolCallStatus +---@field content? ACPToolCallContent[] +---@field locations? avante.acp.ToolCallLocation[] +---@field rawInput? ACPRawInput +---@field rawOutput? table ---@class avante.acp.PlanUpdate : avante.acp.BaseSessionUpdate ---@field sessionUpdate "plan" @@ -168,6 +170,11 @@ local Utils = require("avante.utils") ---@field name string ---@field kind "allow_once" | "allow_always" | "reject_once" | "reject_always" +---@class avante.acp.RequestPermission +---@field options avante.acp.PermissionOption[] +---@field sessionId string +---@field toolCall { toolCallId: string, rawInput: ACPRawInput|nil } + ---@class avante.acp.RequestPermissionOutcome ---@field outcome "cancelled" | "selected" ---@field optionId string|nil @@ -206,7 +213,7 @@ ACPClient.ERROR_CODES = { ---@class ACPHandlers ---@field on_session_update? fun(update: avante.acp.UserMessageChunk | avante.acp.AgentMessageChunk | avante.acp.AgentThoughtChunk | avante.acp.ToolCallUpdate | avante.acp.PlanUpdate | avante.acp.AvailableCommandsUpdate) ----@field on_request_permission? fun(tool_call: table, options: table[], callback: fun(option_id: string | nil)): nil +---@field on_request_permission? fun(request: avante.acp.RequestPermission, callback: fun(option_id: string | nil)): nil ---@field on_read_file? fun(path: string, line: integer | nil, limit: integer | nil, callback: fun(content: string)): nil ---@field on_write_file? fun(path: string, content: string, callback: fun(error: string|nil)): nil ---@field on_error? fun(error: table) @@ -240,6 +247,7 @@ function ACPClient:new(config) readTextFile = true, writeTextFile = true, }, + terminal = false, }, debug_log_file = debug_log_file, pending_responses = {}, @@ -625,19 +633,14 @@ end ---Handle permission request notification ---@param message_id number ----@param params table -function ACPClient:_handle_request_permission(message_id, params) - local session_id = params.sessionId - local tool_call = params.toolCall - local options = params.options - - if not session_id or not tool_call then return end +---@param request avante.acp.RequestPermission +function ACPClient:_handle_request_permission(message_id, request) + if not request.sessionId or not request.toolCall then return end if self.config.handlers and self.config.handlers.on_request_permission then vim.schedule(function() self.config.handlers.on_request_permission( - tool_call, - options, + request, function(option_id) self:_send_result(message_id, { outcome = { diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index 34b6c495d..cf5d7bb7f 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -15,7 +15,8 @@ local LLMToolHelpers = require("avante.llm_tools.helpers") local LLMTools = require("avante.llm_tools") local History = require("avante.history") local HistoryRender = require("avante.history.render") -local ACPConfirmAdapter = require("avante.ui.acp_confirm_adapter") +local ACPConfirmAdapter = require("avante.acp.acp_confirm_adapter") +local ACPDiffPreview = require("avante.acp.acp_diff_preview") ---@class avante.LLM local M = {} @@ -910,6 +911,7 @@ local function truncate_history_for_recovery(history_messages) return truncated end + ---@param opts AvanteLLMStreamOptions function M._stream_acp(opts) Utils.debug("use ACP", Config.provider) @@ -920,10 +922,12 @@ function M._stream_acp(opts) local acp_provider = Config.acp_providers[Config.provider] local prev_text_message_content = "" local history_messages = {} + local get_history_messages = function() if opts.get_history_messages then return opts.get_history_messages() end return history_messages end + local on_messages_add = function(messages) if opts.on_chunk then for _, message in ipairs(messages) do @@ -953,19 +957,25 @@ function M._stream_acp(opts) end end end + + ---@param update avante.acp.ToolCallUpdate local function add_tool_call_message(update) + local id = update.toolCallId + local message = History.Message:new("assistant", { + id = id, type = "tool_use", - id = update.toolCallId, - name = update.kind or update.title, + name = update.kind or update.title or "other", input = update.rawInput or {}, }, { - uuid = update.toolCallId, + uuid = id, }) + last_tool_call_message = message message.acp_tool_call = update if update.status == "pending" or update.status == "in_progress" then message.is_calling = true end - tool_call_messages[update.toolCallId] = message + tool_call_messages[id] = message + if update.rawInput then local description = update.rawInput.description if description then @@ -973,10 +983,13 @@ function M._stream_acp(opts) table.insert(message.tool_use_logs, description) end end + on_messages_add({ message }) return message end + local acp_client = opts.acp_client + if not acp_client then local acp_config = vim.tbl_deep_extend("force", acp_provider, { ---@type ACPHandlers @@ -1063,73 +1076,7 @@ function M._stream_acp(opts) end end - if update.sessionUpdate == "tool_call" then - add_tool_call_message(update) - - local sidebar = require("avante").get() - - if - Config.behaviour.acp_follow_agent_locations - and sidebar - and not sidebar.is_in_full_view -- don't follow when in Zen mode - and update.kind == "edit" -- to avoid entering more than once - and update.locations - and #update.locations > 0 - then - vim.schedule(function() - if not sidebar:is_open() then return end - - -- Find a valid code window (non-sidebar window) - local code_winid = nil - if sidebar.code.winid and sidebar.code.winid ~= 0 and api.nvim_win_is_valid(sidebar.code.winid) then - code_winid = sidebar.code.winid - else - -- Find first non-sidebar window in the current tab - local all_wins = api.nvim_tabpage_list_wins(0) - for _, winid in ipairs(all_wins) do - if api.nvim_win_is_valid(winid) and not sidebar:is_sidebar_winid(winid) then - code_winid = winid - break - end - end - end - - if not code_winid then return end - - local now = uv.now() - local last_auto_nav = vim.g.avante_last_auto_nav or 0 - local grace_period = 2000 - - -- Check if user navigated manually recently - if now - last_auto_nav < grace_period then return end - - -- Only follow first location to avoid rapid jumping - local location = update.locations[1] - if not location or not location.path then return end - - local abs_path = Utils.join_paths(Utils.get_project_root(), location.path) - local bufnr = vim.fn.bufnr(abs_path, true) - - if not bufnr or bufnr == -1 then return end - - if not api.nvim_buf_is_loaded(bufnr) then pcall(vim.fn.bufload, bufnr) end - - local ok = pcall(api.nvim_win_set_buf, code_winid, bufnr) - if not ok then return end - - local line = location.line or 1 - local line_count = api.nvim_buf_line_count(bufnr) - local target_line = math.min(line, line_count) - - pcall(api.nvim_win_set_cursor, code_winid, { target_line, 0 }) - pcall(api.nvim_win_call, code_winid, function() - vim.cmd("normal! zz") -- Center line in viewport - end) - - vim.g.avante_last_auto_nav = now - end) - end - end + if update.sessionUpdate == "tool_call" then add_tool_call_message(update) end if update.sessionUpdate == "tool_call_update" then local tool_call_message = tool_call_messages[update.toolCallId] @@ -1196,30 +1143,35 @@ function M._stream_acp(opts) end end, - on_request_permission = function(tool_call, options, callback) + on_request_permission = function(request, callback) local sidebar = require("avante").get() if not sidebar then Utils.error("Avante sidebar not found") return end - ---@cast tool_call avante.acp.ToolCall - - local message = tool_call_messages[tool_call.toolCallId] + local message = tool_call_messages[request.toolCall.toolCallId] if not message then - message = add_tool_call_message(tool_call) + message = add_tool_call_message(request.toolCall) else if message.acp_tool_call then - if tool_call.content and next(tool_call.content) == nil then tool_call.content = nil end - message.acp_tool_call = vim.tbl_deep_extend("force", message.acp_tool_call, tool_call) + -- Merge updates into existing tool call message + message.acp_tool_call = vim.tbl_deep_extend("force", message.acp_tool_call, request.toolCall) end end on_messages_add({ message }) local description = HistoryRender.get_tool_display_name(message) + + local clear_diff_preview_safelly = ACPDiffPreview.show_acp_diff({ + tool_call = message.acp_tool_call, + session_ctx = opts.session_ctx, + }) + LLMToolHelpers.confirm(description, function(ok) - local acp_mapped_options = ACPConfirmAdapter.map_acp_options(options) + local acp_mapped_options = ACPConfirmAdapter.map_acp_options(request.options) + clear_diff_preview_safelly() if ok and opts.session_ctx and opts.session_ctx.always_yes then callback(acp_mapped_options.all) @@ -1235,9 +1187,10 @@ function M._stream_acp(opts) end, { focus = true, skip_reject_prompt = true, - permission_options = options, - }, opts.session_ctx, tool_call.kind) + permission_options = request.options, + }, opts.session_ctx, message.acp_tool_call.kind) end, + on_read_file = function(path, line, limit, callback) local abs_path = Utils.to_absolute_path(path) local lines = Utils.read_file_from_buf_or_disk(abs_path) @@ -1266,6 +1219,7 @@ function M._stream_acp(opts) end callback(content) end, + on_write_file = function(path, content, callback) local abs_path = Utils.to_absolute_path(path) local file = io.open(abs_path, "w") @@ -1280,8 +1234,13 @@ function M._stream_acp(opts) end, vim.api.nvim_list_bufs() ) + for _, buf in ipairs(buffers) do - vim.api.nvim_buf_call(buf, function() vim.cmd("edit") end) + vim.api.nvim_buf_call(buf, function() + local view = vim.fn.winsaveview() + vim.cmd("checktime") + vim.fn.winrestview(view) + end) end callback(nil) return @@ -1303,6 +1262,7 @@ function M._stream_acp(opts) if not acp_client.agent_capabilities.loadSession then opts.acp_session_id = nil end if opts.on_save_acp_client then opts.on_save_acp_client(acp_client) end end + local session_id = opts.acp_session_id if not session_id then local project_root = Utils.root.get() @@ -1318,9 +1278,11 @@ function M._stream_acp(opts) session_id = session_id_ if opts.on_save_acp_session_id then opts.on_save_acp_session_id(session_id) end end + if opts.just_connect_acp_client then return end local prompt = {} local donot_use_builtin_system_prompt = opts.history_messages ~= nil and #opts.history_messages > 0 + if donot_use_builtin_system_prompt then if opts.selected_filepaths then for _, filepath in ipairs(opts.selected_filepaths) do @@ -1342,11 +1304,12 @@ function M._stream_acp(opts) table.insert(prompt, prompt_item) end end - local history_messages = opts.history_messages or {} + + local messages_from_opt = opts.history_messages or {} -- DEBUG: Log history message details - Utils.debug("ACP history messages count: " .. #history_messages) - for i, msg in ipairs(history_messages) do + Utils.debug("ACP history messages count: " .. #messages_from_opt) + for i, msg in ipairs(messages_from_opt) do if msg and msg.message then Utils.debug( "History msg " @@ -1381,11 +1344,11 @@ function M._stream_acp(opts) local include_history_count = recovery_config.include_history_count or 15 -- Default to 15 for better context -- Get recent messages from truncated history - local start_idx = math.max(1, #history_messages - include_history_count + 1) - Utils.debug("Including history from index " .. start_idx .. " to " .. #history_messages) + local start_idx = math.max(1, #messages_from_opt - include_history_count + 1) + Utils.debug("Including history from index " .. start_idx .. " to " .. #messages_from_opt) - for i = start_idx, #history_messages do - local message = history_messages[i] + for i = start_idx, #messages_from_opt do + local message = messages_from_opt[i] if message and message.message then table.insert(recent_messages, message) Utils.debug("Adding message " .. i .. " to recent_messages: role=" .. (message.message.role or "unknown")) @@ -1453,8 +1416,8 @@ function M._stream_acp(opts) local include_history_count = recovery_config.include_history_count or 5 local user_messages_added = 0 - for i = #history_messages, 1, -1 do - local message = history_messages[i] + for i = #messages_from_opt, 1, -1 do + local message = messages_from_opt[i] if message.message.role == "user" and user_messages_added < include_history_count then local content = message.message.content if type(content) == "table" then @@ -1493,7 +1456,7 @@ function M._stream_acp(opts) else if donot_use_builtin_system_prompt then -- Include all user messages for better context preservation - for _, message in ipairs(history_messages) do + for _, message in ipairs(messages_from_opt) do if message.message.role == "user" then local content = message.message.content if type(content) == "table" then @@ -1534,6 +1497,7 @@ function M._stream_acp(opts) end end end + acp_client:send_prompt(session_id, prompt, function(_, err_) if err_ then -- ACP-specific session recovery: Check for session not found error @@ -1658,9 +1622,11 @@ function M._stream_acp(opts) -- CRITICAL: Return immediately to prevent further processing in fast event context return end + opts.on_stop({ reason = "error", error = err_ }) return end + opts.on_stop({ reason = "complete" }) end) end diff --git a/lua/avante/llm_tools/helpers.lua b/lua/avante/llm_tools/helpers.lua index aeffd2a95..3352d90f0 100644 --- a/lua/avante/llm_tools/helpers.lua +++ b/lua/avante/llm_tools/helpers.lua @@ -1,7 +1,7 @@ local Utils = require("avante.utils") local Path = require("plenary.path") local Config = require("avante.config") -local ACPConfirmAdapter = require("avante.ui.acp_confirm_adapter") +local ACPConfirmAdapter = require("avante.acp.acp_confirm_adapter") local M = {} @@ -46,30 +46,35 @@ function M.confirm_inline(callback, confirm_opts) end end ----@param message string ----@param callback fun(response: boolean, reason?: string) ----@param confirm_opts? avante.ui.ConfirmOptions ---@param session_ctx? table ---@param tool_name? string -- Optional tool name to check against tool_permissions config ----@return avante.ui.Confirm | nil -function M.confirm(message, callback, confirm_opts, session_ctx, tool_name) - callback = vim.schedule_wrap(callback) - if session_ctx and session_ctx.always_yes then - callback(true) - return - end +---@return boolean +function M.is_auto_approved(session_ctx, tool_name) + -- Check if session has always_yes flag set + if session_ctx and session_ctx.always_yes then return true end -- Check behaviour.auto_approve_tool_permissions config for auto-approval local auto_approve = Config.behaviour.auto_approve_tool_permissions -- If auto_approve is true, auto-approve all tools - if auto_approve == true then - callback(true) - return - end + if auto_approve == true then return true end -- If auto_approve is a table (array of tool names), check if this tool is in the list - if type(auto_approve) == "table" and vim.tbl_contains(auto_approve, tool_name) then + if tool_name and type(auto_approve) == "table" and vim.tbl_contains(auto_approve, tool_name) then return true end + + return false +end + +---@param description string +---@param callback fun(response: boolean, reason?: string) +---@param confirm_opts? avante.ui.ConfirmOptions +---@param session_ctx? table +---@param tool_name? string -- Optional tool name to check against tool_permissions config +---@return avante.ui.Confirm | nil +function M.confirm(description, callback, confirm_opts, session_ctx, tool_name) + callback = vim.schedule_wrap(callback) + + if M.is_auto_approved(session_ctx, tool_name) then callback(true) return end @@ -96,7 +101,7 @@ function M.confirm(message, callback, confirm_opts, session_ctx, tool_name) end confirm_opts = vim.tbl_deep_extend("force", { container_winid = sidebar.containers.input.winid }, confirm_opts or {}) if M.confirm_popup then M.confirm_popup:close() end - M.confirm_popup = Confirm:new(message, function(type, reason) + M.confirm_popup = Confirm:new(description, function(type, reason) if type == "yes" then callback(true) elseif type == "all" then diff --git a/lua/avante/llm_tools/replace_in_file.lua b/lua/avante/llm_tools/replace_in_file.lua index d1802ff79..aa534ceb1 100644 --- a/lua/avante/llm_tools/replace_in_file.lua +++ b/lua/avante/llm_tools/replace_in_file.lua @@ -3,11 +3,14 @@ local Helpers = require("avante.llm_tools.helpers") local Utils = require("avante.utils") local Highlights = require("avante.highlights") local Config = require("avante.config") - -local PRIORITY = (vim.hl or vim.highlight).priorities.user -local NAMESPACE = vim.api.nvim_create_namespace("avante-diff") -local KEYBINDING_NAMESPACE = vim.api.nvim_create_namespace("avante-diff-keybinding") - +local DiffDisplay = require("avante.utils.diff_display") + +--- LLM tool for applying targeted file changes using SEARCH/REPLACE blocks. +--- Processes streaming diffs, displays changes with highlighting, and handles +--- user confirmation before applying modifications. +--- +--- IMPORTANT: This tool is ONLY used by API-based providers (claude.lua, openai.lua, etc.). +--- ACP providers (gemini-cli, claude-code, etc) NEVER invoke this tool. ---@class AvanteLLMTool local M = setmetatable({}, Base) @@ -256,7 +259,7 @@ Please make sure the diff is formatted correctly, and that the SEARCH/REPLACE bl patch = vim.diff(old_string, new_string, { ---@type integer[][] algorithm = "histogram", result_type = "indices", - ctxlen = vim.o.scrolloff, + ctxlen = 0, }) else patch = { { 1, #old_lines, 1, #new_lines } } @@ -317,24 +320,22 @@ Please make sure the diff is formatted correctly, and that the SEARCH/REPLACE bl table.sort(diff_blocks, function(a, b) return a.start_line < b.start_line end) - local base_line = 0 - for _, diff_block in ipairs(diff_blocks) do - diff_block.new_start_line = diff_block.start_line + base_line - diff_block.new_end_line = diff_block.new_start_line + #diff_block.new_lines - 1 - base_line = base_line + #diff_block.new_lines - #diff_block.old_lines - end + -- Save initial hunk count to detect if user made manual accept/reject choices + local initial_hunk_count = #diff_blocks local function remove_diff_block(removed_idx, use_new_lines) local new_diff_blocks = {} local distance = 0 for idx, diff_block in ipairs(diff_blocks) do if idx == removed_idx then - if not use_new_lines then distance = #diff_block.old_lines - #diff_block.new_lines end + -- Virtual-first: accepting applies NEW lines (buffer changes), rejecting keeps OLD (no change) + if use_new_lines then distance = #diff_block.new_lines - #diff_block.old_lines end goto continue end if idx > removed_idx then - diff_block.new_start_line = diff_block.new_start_line + distance - diff_block.new_end_line = diff_block.new_end_line + distance + -- Adjust subsequent block positions based on actual buffer changes + diff_block.start_line = diff_block.start_line + distance + diff_block.end_line = diff_block.end_line + distance end table.insert(new_diff_blocks, diff_block) ::continue:: @@ -343,114 +344,34 @@ Please make sure the diff is formatted correctly, and that the SEARCH/REPLACE bl diff_blocks = new_diff_blocks end - local function get_current_diff_block() - local winid = Utils.get_winid(bufnr) - local cursor_line = Utils.get_cursor_pos(winid) - for idx, diff_block in ipairs(diff_blocks) do - if cursor_line >= diff_block.new_start_line and cursor_line <= diff_block.new_end_line then - return diff_block, idx - end - end - return nil, nil - end - - local function get_prev_diff_block() - local winid = Utils.get_winid(bufnr) - local cursor_line = Utils.get_cursor_pos(winid) - local distance = nil - local idx = nil - for i, diff_block in ipairs(diff_blocks) do - if cursor_line >= diff_block.new_start_line and cursor_line <= diff_block.new_end_line then - local new_i = i - 1 - if new_i < 1 then return diff_blocks[#diff_blocks] end - return diff_blocks[new_i] - end - if diff_block.new_start_line < cursor_line then - local distance_ = cursor_line - diff_block.new_start_line - if distance == nil or distance_ < distance then - distance = distance_ - idx = i - end - end - end - if idx ~= nil then return diff_blocks[idx] end - if #diff_blocks > 0 then return diff_blocks[#diff_blocks] end - return nil - end - - local function get_next_diff_block() - local winid = Utils.get_winid(bufnr) - local cursor_line = Utils.get_cursor_pos(winid) - local distance = nil - local idx = nil - for i, diff_block in ipairs(diff_blocks) do - if cursor_line >= diff_block.new_start_line and cursor_line <= diff_block.new_end_line then - local new_i = i + 1 - if new_i > #diff_blocks then return diff_blocks[1] end - return diff_blocks[new_i] - end - if diff_block.new_start_line > cursor_line then - local distance_ = diff_block.new_start_line - cursor_line - if distance == nil or distance_ < distance then - distance = distance_ - idx = i - end - end - end - if idx ~= nil then return diff_blocks[idx] end - if #diff_blocks > 0 then return diff_blocks[1] end - return nil - end - - local show_keybinding_hint_extmark_id = nil - local augroup = vim.api.nvim_create_augroup("avante_replace_in_file", { clear = true }) - local function register_cursor_move_events() - local function show_keybinding_hint(lnum) - if show_keybinding_hint_extmark_id then - vim.api.nvim_buf_del_extmark(bufnr, KEYBINDING_NAMESPACE, show_keybinding_hint_extmark_id) - end - - local hint = string.format( - "[<%s>: OURS, <%s>: THEIRS, <%s>: PREV, <%s>: NEXT]", - Config.mappings.diff.ours, - Config.mappings.diff.theirs, - Config.mappings.diff.prev, - Config.mappings.diff.next - ) - - show_keybinding_hint_extmark_id = vim.api.nvim_buf_set_extmark(bufnr, KEYBINDING_NAMESPACE, lnum - 1, -1, { - hl_group = "AvanteInlineHint", - virt_text = { { hint, "AvanteInlineHint" } }, - virt_text_pos = "right_align", - priority = PRIORITY, - }) - end + --- @type avante.ui.Confirm|nil + local confirm + local has_rejected = false - vim.api.nvim_create_autocmd({ "CursorMoved", "CursorMovedI", "WinLeave" }, { - buffer = bufnr, - group = augroup, - callback = function(event) - local diff_block = get_current_diff_block() - if (event.event == "CursorMoved" or event.event == "CursorMovedI") and diff_block then - show_keybinding_hint(diff_block.new_start_line) - else - vim.api.nvim_buf_clear_namespace(bufnr, KEYBINDING_NAMESPACE, 0, -1) - end - end, - }) + if not vim.api.nvim_buf_is_valid(bufnr) then + on_complete(false, "Buffer is no longer valid") + return end - local confirm - local has_rejected = false + local diff_display = DiffDisplay.new({ + bufnr = bufnr, + diff_blocks = diff_blocks, + }) local function register_buf_write_events() + local write_augroup = vim.api.nvim_create_augroup("avante_replace_in_file_write", { clear = true }) + vim.api.nvim_create_autocmd({ "BufWritePost" }, { buffer = bufnr, - group = augroup, + group = write_augroup, callback = function() + pcall(vim.api.nvim_del_augroup_by_id, write_augroup) + diff_display:clear() + if #diff_blocks ~= 0 then return end - pcall(vim.api.nvim_del_augroup_by_id, augroup) + if confirm then confirm:close() end + if has_rejected then on_complete(false, "User canceled") return @@ -461,138 +382,24 @@ Please make sure the diff is formatted correctly, and that the SEARCH/REPLACE bl }) end - local function register_keybinding_events() - local keymap_opts = { buffer = bufnr } - vim.keymap.set({ "n", "v" }, Config.mappings.diff.ours, function() - if show_keybinding_hint_extmark_id then - vim.api.nvim_buf_del_extmark(bufnr, KEYBINDING_NAMESPACE, show_keybinding_hint_extmark_id) - end - local diff_block, idx = get_current_diff_block() - if not diff_block then return end - pcall(vim.api.nvim_buf_del_extmark, bufnr, NAMESPACE, diff_block.delete_extmark_id) - pcall(vim.api.nvim_buf_del_extmark, bufnr, NAMESPACE, diff_block.incoming_extmark_id) - vim.api.nvim_buf_set_lines( - bufnr, - diff_block.new_start_line - 1, - diff_block.new_end_line, - false, - diff_block.old_lines - ) - diff_block.incoming_extmark_id = nil - diff_block.delete_extmark_id = nil - remove_diff_block(idx, false) - local next_diff_block = get_next_diff_block() - if next_diff_block then - local winnr = Utils.get_winid(bufnr) - vim.api.nvim_win_set_cursor(winnr, { next_diff_block.new_start_line, 0 }) - vim.api.nvim_win_call(winnr, function() vim.cmd("normal! zz") end) - end - has_rejected = true - end, keymap_opts) - - vim.keymap.set({ "n", "v" }, Config.mappings.diff.theirs, function() - if show_keybinding_hint_extmark_id then - vim.api.nvim_buf_del_extmark(bufnr, KEYBINDING_NAMESPACE, show_keybinding_hint_extmark_id) - end - local diff_block, idx = get_current_diff_block() - if not diff_block then return end - pcall(vim.api.nvim_buf_del_extmark, bufnr, NAMESPACE, diff_block.incoming_extmark_id) - pcall(vim.api.nvim_buf_del_extmark, bufnr, NAMESPACE, diff_block.delete_extmark_id) - diff_block.incoming_extmark_id = nil - diff_block.delete_extmark_id = nil - remove_diff_block(idx, true) - local next_diff_block = get_next_diff_block() - if next_diff_block then - local winnr = Utils.get_winid(bufnr) - vim.api.nvim_win_set_cursor(winnr, { next_diff_block.new_start_line, 0 }) - vim.api.nvim_win_call(winnr, function() vim.cmd("normal! zz") end) - end - end, keymap_opts) - - vim.keymap.set({ "n", "v" }, Config.mappings.diff.next, function() - if show_keybinding_hint_extmark_id then - vim.api.nvim_buf_del_extmark(bufnr, KEYBINDING_NAMESPACE, show_keybinding_hint_extmark_id) - end - local diff_block = get_next_diff_block() - if not diff_block then return end - local winnr = Utils.get_winid(bufnr) - vim.api.nvim_win_set_cursor(winnr, { diff_block.new_start_line, 0 }) - vim.api.nvim_win_call(winnr, function() vim.cmd("normal! zz") end) - end, keymap_opts) + local function on_reject_diff_block(idx) + remove_diff_block(idx, false) + has_rejected = true - vim.keymap.set({ "n", "v" }, Config.mappings.diff.prev, function() - if show_keybinding_hint_extmark_id then - vim.api.nvim_buf_del_extmark(bufnr, KEYBINDING_NAMESPACE, show_keybinding_hint_extmark_id) - end - local diff_block = get_prev_diff_block() - if not diff_block then return end - local winnr = Utils.get_winid(bufnr) - vim.api.nvim_win_set_cursor(winnr, { diff_block.new_start_line, 0 }) - vim.api.nvim_win_call(winnr, function() vim.cmd("normal! zz") end) - end, keymap_opts) - end - - local function unregister_keybinding_events() - pcall(vim.api.nvim_buf_del_keymap, bufnr, "n", Config.mappings.diff.ours) - pcall(vim.api.nvim_buf_del_keymap, bufnr, "n", Config.mappings.diff.theirs) - pcall(vim.api.nvim_buf_del_keymap, bufnr, "n", Config.mappings.diff.next) - pcall(vim.api.nvim_buf_del_keymap, bufnr, "n", Config.mappings.diff.prev) - pcall(vim.api.nvim_buf_del_keymap, bufnr, "v", Config.mappings.diff.ours) - pcall(vim.api.nvim_buf_del_keymap, bufnr, "v", Config.mappings.diff.theirs) - pcall(vim.api.nvim_buf_del_keymap, bufnr, "v", Config.mappings.diff.next) - pcall(vim.api.nvim_buf_del_keymap, bufnr, "v", Config.mappings.diff.prev) - end - - local function clear() - if bufnr and not vim.api.nvim_buf_is_valid(bufnr) then return end - vim.api.nvim_buf_clear_namespace(bufnr, NAMESPACE, 0, -1) - vim.api.nvim_buf_clear_namespace(bufnr, KEYBINDING_NAMESPACE, 0, -1) - unregister_keybinding_events() - pcall(vim.api.nvim_del_augroup_by_id, augroup) - end - - local function insert_diff_blocks_new_lines() - local base_line_ = 0 - for _, diff_block in ipairs(diff_blocks) do - local start_line = diff_block.start_line + base_line_ - local end_line = diff_block.end_line + base_line_ - base_line_ = base_line_ + #diff_block.new_lines - #diff_block.old_lines - vim.api.nvim_buf_set_lines(bufnr, start_line - 1, end_line, false, diff_block.new_lines) + if #diff_blocks == 0 and confirm and confirm.cancel then + confirm:cancel("All suggestions processed with rejections") end end - local function highlight_diff_blocks() - local line_count = vim.api.nvim_buf_line_count(bufnr) - vim.api.nvim_buf_clear_namespace(bufnr, NAMESPACE, 0, -1) - local base_line_ = 0 - local max_col = vim.o.columns - for _, diff_block in ipairs(diff_blocks) do - local start_line = diff_block.start_line + base_line_ - base_line_ = base_line_ + #diff_block.new_lines - #diff_block.old_lines - local deleted_virt_lines = vim - .iter(diff_block.old_lines) - :map(function(line) - --- append spaces to the end of the line - local line_ = line .. string.rep(" ", max_col - #line) - return { { line_, Highlights.TO_BE_DELETED_WITHOUT_STRIKETHROUGH } } - end) - :totable() - local end_row = start_line + #diff_block.new_lines - 1 - local delete_extmark_id = - vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, math.min(math.max(end_row - 1, 0), line_count - 1), 0, { - virt_lines = deleted_virt_lines, - hl_eol = true, - hl_mode = "combine", - }) - local incoming_extmark_id = - vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, math.min(math.max(start_line - 1, 0), line_count - 1), 0, { - hl_group = Highlights.INCOMING, - hl_eol = true, - hl_mode = "combine", - end_row = end_row, - }) - diff_block.delete_extmark_id = delete_extmark_id - diff_block.incoming_extmark_id = incoming_extmark_id + local function on_accept_diff_block(idx) + remove_diff_block(idx, true) + + if #diff_blocks == 0 and confirm then + if has_rejected then + confirm:cancel("All suggestions processed with mixed accept/reject") + else + confirm:confirm("All suggestions accepted") + end end end @@ -643,6 +450,7 @@ Please make sure the diff is formatted correctly, and that the SEARCH/REPLACE bl return new_diff_blocks end + -- API providers can stream multiple times with partial diffs, so we need to highlight them as they come local function highlight_streaming_diff_blocks() local unstable_diff_blocks = get_unstable_diff_blocks(diff_blocks) session_ctx.prev_diff_blocks_map[opts.tool_use_id] = diff_blocks @@ -651,7 +459,7 @@ Please make sure the diff is formatted correctly, and that the SEARCH/REPLACE bl local new_lines = diff_block.new_lines local start_line = diff_block.start_line if #diff_block.old_lines > 0 then - vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, start_line - 1, 0, { + vim.api.nvim_buf_set_extmark(bufnr, DiffDisplay.NAMESPACE, start_line - 1, 0, { hl_group = Highlights.TO_BE_DELETED_WITHOUT_STRIKETHROUGH, hl_eol = true, hl_mode = "combine", @@ -673,10 +481,9 @@ Please make sure the diff is formatted correctly, and that the SEARCH/REPLACE bl else extmark_line = math.max(0, start_line - 1 + #diff_block.old_lines) end - -- Utils.debug("extmark_line", extmark_line, "idx", idx, "start_line", diff_block.start_line, "old_lines", table.concat(diff_block.old_lines, "\n")) local old_extmark_id = extmark_id_map[start_line] - if old_extmark_id then vim.api.nvim_buf_del_extmark(bufnr, NAMESPACE, old_extmark_id) end - local extmark_id = vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, extmark_line, 0, { + if old_extmark_id then vim.api.nvim_buf_del_extmark(bufnr, DiffDisplay.NAMESPACE, old_extmark_id) end + local extmark_id = vim.api.nvim_buf_set_extmark(bufnr, DiffDisplay.NAMESPACE, extmark_line, 0, { virt_lines = virt_lines, hl_eol = true, hl_mode = "combine", @@ -687,10 +494,11 @@ Please make sure the diff is formatted correctly, and that the SEARCH/REPLACE bl end if not is_streaming then - insert_diff_blocks_new_lines() - highlight_diff_blocks() - register_cursor_move_events() - register_keybinding_events() + diff_display:highlight() + diff_display:scroll_to_first_diff() + diff_display:register_cursor_move_events() + diff_display:register_navigation_keybindings() + diff_display:register_accept_reject_keybindings(on_accept_diff_block, on_reject_diff_block) register_buf_write_events() else highlight_streaming_diff_blocks() @@ -707,7 +515,7 @@ Please make sure the diff is formatted correctly, and that the SEARCH/REPLACE bl vim.api.nvim_win_call(winnr, function() vim.cmd("normal! zz") end) else -- In normal mode, focus on the first diff block - vim.api.nvim_win_set_cursor(winnr, { math.min(diff_blocks[1].new_start_line, line_count), 0 }) + vim.api.nvim_win_set_cursor(winnr, { math.min(diff_blocks[1].start_line, line_count), 0 }) vim.api.nvim_win_call(winnr, function() vim.cmd("normal! zz") end) end end @@ -720,20 +528,50 @@ Please make sure the diff is formatted correctly, and that the SEARCH/REPLACE bl pcall(vim.cmd.undojoin) confirm = Helpers.confirm("Are you sure you want to apply this modification?", function(ok, reason) - clear() + diff_display:clear() + + if not vim.api.nvim_buf_is_valid(bufnr) then + on_complete(false, "Code buffer is not valid") + return + end + if not ok then - vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, original_lines) on_complete(false, "User declined, reason: " .. (reason or "unknown")) return end + + -- Handle user approval - check if user made manual accept/reject choices + local remaining_hunks = #diff_blocks + + -- If user made NO manual choices (counts match), apply all remaining hunks + if remaining_hunks == initial_hunk_count and remaining_hunks > 0 then + local offset = 0 + for _, diff_block in ipairs(diff_blocks) do + local adjusted_start = diff_block.start_line - 1 + offset + local adjusted_end = diff_block.end_line + offset + + local ok_apply = + pcall(vim.api.nvim_buf_set_lines, bufnr, adjusted_start, adjusted_end, false, diff_block.new_lines) + + if ok_apply then + offset = offset + (#diff_block.new_lines - #diff_block.old_lines) + else + on_complete(false, "Failed to apply changes to buffer") + return + end + end + end + -- If user made manual choices (counts differ), remaining hunks are intentionally skipped + -- The buffer already has the accepted changes from co/ct operations + local parent_dir = vim.fn.fnamemodify(abs_path, ":h") + --- check if the parent dir is exists, if not, create it if vim.fn.isdirectory(parent_dir) == 0 then vim.fn.mkdir(parent_dir, "p") end - if not vim.api.nvim_buf_is_valid(bufnr) then - on_complete(false, "Code buffer is not valid") - return - end - vim.api.nvim_buf_call(bufnr, function() vim.cmd("noautocmd write!") end) + + -- Write the file with current buffer state + vim.api.nvim_buf_call(bufnr, function() vim.cmd("silent noautocmd write!") end) + if session_ctx then Helpers.mark_as_not_viewed(input.path, session_ctx) end on_complete(true, nil) end, { focus = not Config.behaviour.auto_focus_on_diff_view }, session_ctx, M.name) diff --git a/lua/avante/types.lua b/lua/avante/types.lua index 4c684996b..2676e061c 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -78,8 +78,15 @@ vim.g.avante_login = vim.g.avante_login ---@field on_messages_add? fun(messages: avante.HistoryMessage[]): nil ---@field on_state_change? fun(state: avante.GenerateState): nil ---@field update_tokens_usage? fun(usage: avante.LLMTokenUsage): nil ---- ----@alias AvanteLLMMessageContentItem string | { type: "text", text: string, cache_control: { type: string } | nil } | { type: "image", source: { type: "base64", media_type: string, data: string } } | { type: "tool_use", name: string, id: string, input: any } | { type: "tool_result", tool_use_id: string, content: string, is_error?: boolean, is_user_declined?: boolean } | { type: "thinking", thinking: string, signature: string } | { type: "redacted_thinking", data: string } + +---@alias AvanteLLMMessageContentItem +---| string +---| { type: "text", text: string, cache_control: { type: string } | nil } +---| { type: "image", source: { type: "base64", media_type: string, data: string } } +---| { type: "tool_use", name: string, id: string, input: any } +---| { type: "tool_result", tool_use_id: string, content: string, is_error?: boolean, is_user_declined?: boolean } +---| { type: "thinking", thinking: string, signature: string } +---| { type: "redacted_thinking", data: string } ---@alias AvanteLLMMessageContent AvanteLLMMessageContentItem[] | string @@ -93,6 +100,16 @@ vim.g.avante_login = vim.g.avante_login ---@field status "todo" | "doing" | "done" | "cancelled" ---@field priority "low" | "medium" | "high" +---@class avante.DiffBlock +---@field start_line integer Original position in file +---@field end_line integer Original position in file +---@field new_start_line? integer Adjusted position after previous diffs (populated after construction) +---@field new_end_line? integer Adjusted position after previous diffs (populated after construction) +---@field old_lines string[] Content to be replaced +---@field new_lines string[] New content +---@field delete_extmark_id? integer Extmark ID for deleted lines display +---@field incoming_extmark_id? integer Extmark ID for incoming lines display + ---@class avante.HistoryMessage ---@field message AvanteLLMMessage ---@field timestamp string @@ -117,7 +134,7 @@ vim.g.avante_login = vim.g.avante_login ---@field turn_id string | nil ---@field is_calling boolean | nil ---@field original_content AvanteLLMMessageContent | nil ----@field acp_tool_call? avante.acp.ToolCall +---@field acp_tool_call? avante.acp.ToolCallUpdate ---@class AvanteLLMToolResult ---@field tool_name string diff --git a/lua/avante/ui/confirm.lua b/lua/avante/ui/confirm.lua index edd45de73..fe76b673b 100644 --- a/lua/avante/ui/confirm.lua +++ b/lua/avante/ui/confirm.lua @@ -367,9 +367,16 @@ end function M:unbind_window_focus_keymaps() pcall(vim.keymap.del, { "n", "i" }, Config.mappings.confirm.focus_window) end -function M:cancel() - self.callback("no", "cancel") - return self:close() +---@param reason? string Optional reason for cancellation +function M:cancel(reason) + self:close() + self.callback("no", reason or "cancel") +end + +---@param reason? string Optional reason for confirmation +function M:confirm(reason) + self:close() + self.callback("yes", reason or "confirm") end function M:close() diff --git a/lua/avante/utils/diff_display.lua b/lua/avante/utils/diff_display.lua new file mode 100644 index 000000000..d4eff1781 --- /dev/null +++ b/lua/avante/utils/diff_display.lua @@ -0,0 +1,563 @@ +---@class avante.utils.diff_display +local M = {} + +local Utils = require("avante.utils") +local Highlights = require("avante.highlights") +local Config = require("avante.config") + +M.NAMESPACE = vim.api.nvim_create_namespace("avante-diff-display") +M.KEYBINDING_NAMESPACE = vim.api.nvim_create_namespace("avante-diff-keybinding") + +---Find character-level changes between two lines +---@param old_line string +---@param new_line string +---@return {old_start: integer, old_end: integer, new_start: integer, new_end: integer}|nil +local function find_inline_change(old_line, new_line) + if old_line == new_line then return nil end + + -- Find common prefix + local prefix_len = 0 + local min_len = math.min(#old_line, #new_line) + for i = 1, min_len do + if old_line:sub(i, i) == new_line:sub(i, i) then + prefix_len = i + else + break + end + end + + -- Find common suffix (after the prefix) + local suffix_len = 0 + for i = 1, min_len - prefix_len do + if old_line:sub(#old_line - i + 1, #old_line - i + 1) == new_line:sub(#new_line - i + 1, #new_line - i + 1) then + suffix_len = i + else + break + end + end + + -- Calculate change regions + local old_start = prefix_len + local old_end = #old_line - suffix_len + local new_start = prefix_len + local new_end = #new_line - suffix_len + + -- If no changes found, return nil + if old_start >= old_end and new_start >= new_end then return nil end + + return { + old_start = old_start, + old_end = old_end, + new_start = new_start, + new_end = new_end, + } +end + +---@class avante.DiffDisplayInstance +---@field bufnr integer Buffer number +---@field diff_blocks avante.DiffBlock[] List of diff blocks (mutable reference) +---@field augroup integer Autocommand group ID +---@field show_keybinding_hint_extmark_id integer? Current keybinding hint extmark ID +---@field has_accept_reject_keybindings boolean Whether accept/reject keybindings are registered +local DiffDisplayInstance = {} +DiffDisplayInstance.__index = DiffDisplayInstance + +---Create a new diff display instance +---@param opts { bufnr: integer, diff_blocks: avante.DiffBlock[] } +---@return avante.DiffDisplayInstance +function M.new(opts) + local augroup = vim.api.nvim_create_augroup("avante-diff-display-" .. opts.bufnr, { clear = true }) + local instance = setmetatable({ + bufnr = opts.bufnr, + diff_blocks = opts.diff_blocks, + augroup = augroup, + show_keybinding_hint_extmark_id = nil, + has_accept_reject_keybindings = false, + }, DiffDisplayInstance) + + vim.api.nvim_create_autocmd({ "BufDelete", "BufWipeout" }, { + buffer = opts.bufnr, + group = augroup, + once = true, + callback = function() instance:clear() end, + }) + + return instance +end + +---Get the current diff block under cursor +---@return avante.DiffBlock?, integer? The diff block and its index, or nil if not found +function DiffDisplayInstance:get_current_diff_block() + local winid = Utils.get_winid(self.bufnr) + if not winid then return nil, nil end + + local cursor_line = Utils.get_cursor_pos(winid) + + for idx, diff_block in ipairs(self.diff_blocks) do + if cursor_line >= diff_block.start_line and cursor_line <= diff_block.end_line then return diff_block, idx end + end + return nil, nil +end + +---Get the previous diff block +---@return avante.DiffBlock? The previous diff block, or nil if not found +function DiffDisplayInstance:get_prev_diff_block() + local winid = Utils.get_winid(self.bufnr) + + if not winid then return nil end + + local cursor_line = Utils.get_cursor_pos(winid) + local distance = nil + local idx = nil + for i, diff_block in ipairs(self.diff_blocks) do + if cursor_line >= diff_block.start_line and cursor_line <= diff_block.end_line then + local new_i = i - 1 + if new_i < 1 then return self.diff_blocks[#self.diff_blocks] end + return self.diff_blocks[new_i] + end + if diff_block.start_line < cursor_line then + local distance_ = cursor_line - diff_block.start_line + if distance == nil or distance_ < distance then + distance = distance_ + idx = i + end + end + end + if idx ~= nil then return self.diff_blocks[idx] end + if #self.diff_blocks > 0 then return self.diff_blocks[#self.diff_blocks] end + return nil +end + +---Get the next diff block +---@return avante.DiffBlock? The next diff block, or nil if not found +function DiffDisplayInstance:get_next_diff_block() + local winid = Utils.get_winid(self.bufnr) + + if not winid then return nil end + + local cursor_line = Utils.get_cursor_pos(winid) + local distance = nil + local idx = nil + for i, diff_block in ipairs(self.diff_blocks) do + if cursor_line >= diff_block.start_line and cursor_line <= diff_block.end_line then + local new_i = i + 1 + if new_i > #self.diff_blocks then return self.diff_blocks[1] end + return self.diff_blocks[new_i] + end + if diff_block.start_line > cursor_line then + local distance_ = diff_block.start_line - cursor_line + if distance == nil or distance_ < distance then + distance = distance_ + idx = i + end + end + end + if idx ~= nil then return self.diff_blocks[idx] end + if #self.diff_blocks > 0 then return self.diff_blocks[1] end + return nil +end + +---@param on_complete? function Optional callback to run after scroll completes +function DiffDisplayInstance:scroll_to_first_diff(on_complete) + if not self.bufnr or not vim.api.nvim_buf_is_valid(self.bufnr) then return end + if #self.diff_blocks == 0 then return end + + local first_diff = self.diff_blocks[1] + local bufnr = self.bufnr + + -- Schedule the scroll to happen after the UI settles and confirmation dialog is shown + vim.schedule(function() + if not vim.api.nvim_buf_is_valid(bufnr) then return end + + local winnr = Utils.get_winid(bufnr) + + -- If buffer is not visible in any window, open it in a suitable window + if not winnr then + local sidebar = require("avante").get() + local target_winid = nil + + -- Try to find a code window (non-sidebar window) + if + sidebar + and sidebar.code.winid + and sidebar.code.winid ~= 0 + and vim.api.nvim_win_is_valid(sidebar.code.winid) + then + target_winid = sidebar.code.winid + else + -- Find first non-sidebar window in the current tab + local all_wins = vim.api.nvim_tabpage_list_wins(0) + for _, winid in ipairs(all_wins) do + if vim.api.nvim_win_is_valid(winid) and (not sidebar or not sidebar:is_sidebar_winid(winid)) then + target_winid = winid + break + end + end + end + + -- If we found a suitable window, open the buffer in it + if target_winid then + pcall(vim.api.nvim_win_set_buf, target_winid, bufnr) + winnr = target_winid + else + return + end + end + + if not winnr then return end + + local line_count = vim.api.nvim_buf_line_count(bufnr) + local target_line = math.min(first_diff.start_line, line_count) + local current_win = vim.api.nvim_get_current_win() + + -- Respect auto_focus_on_diff_view config when deciding whether to switch windows + local should_switch_window = Config.behaviour.auto_focus_on_diff_view and winnr ~= current_win + + if should_switch_window then pcall(vim.api.nvim_set_current_win, winnr) end + + pcall(vim.api.nvim_win_set_cursor, winnr, { target_line, 0 }) + pcall(vim.api.nvim_win_call, winnr, function() vim.cmd("normal! zz") end) + + -- If auto_focus_on_diff_view is true, stay in the code window + -- Otherwise, return to the original window + if should_switch_window and not Config.behaviour.auto_focus_on_diff_view then + vim.schedule(function() + if vim.api.nvim_win_is_valid(current_win) then pcall(vim.api.nvim_set_current_win, current_win) end + end) + end + + -- Call completion callback if provided + if on_complete and type(on_complete) == "function" then vim.schedule(function() pcall(on_complete) end) end + end) +end + +function DiffDisplayInstance:highlight() + if not self.bufnr or not vim.api.nvim_buf_is_valid(self.bufnr) then return end + + local line_count = vim.api.nvim_buf_line_count(self.bufnr) + vim.api.nvim_buf_clear_namespace(self.bufnr, M.NAMESPACE, 0, -1) + local max_col = vim.o.columns + + for _, diff_block in ipairs(self.diff_blocks) do + -- Use original positions directly (no offset calculation needed since buffer unchanged) + local start_line = diff_block.start_line + local end_line = diff_block.end_line + + local is_modification = #diff_block.old_lines == #diff_block.new_lines and #diff_block.old_lines > 0 + + -- Highlight OLD content in buffer with background for DELETED lines + if #diff_block.old_lines > 0 then + -- end_row is 0-indexed and exclusive, so use end_line directly + local end_row = math.min(end_line, line_count) + + local ok_deleted_bg, deleted_bg_extmark_id = pcall( + vim.api.nvim_buf_set_extmark, + self.bufnr, + M.NAMESPACE, + math.min(math.max(start_line - 1, 0), line_count - 1), + 0, + { + hl_group = Highlights.DIFF_DELETED, + hl_eol = true, + hl_mode = "combine", + end_row = end_row, + priority = 100, + } + ) + + if ok_deleted_bg then diff_block.delete_extmark_id = deleted_bg_extmark_id end + + -- Word-level highlighting on OLD content in buffer + if is_modification then + for i, old_line in ipairs(diff_block.old_lines) do + local new_line = diff_block.new_lines[i] + local ok_change, change = pcall(find_inline_change, old_line, new_line) + if ok_change and change then + local line_nr = start_line - 1 + (i - 1) + + if change.old_end > change.old_start then + pcall(vim.api.nvim_buf_set_extmark, self.bufnr, M.NAMESPACE, line_nr, change.old_start, { + hl_group = Highlights.DIFF_DELETED_WORD, + end_col = change.old_end, + priority = 200, + }) + end + end + end + end + end + + -- Build virtual lines for NEW content (incoming changes) + if #diff_block.new_lines > 0 then + local incoming_virt_lines = {} + for i, new_line in ipairs(diff_block.new_lines) do + if is_modification then + local old_line = diff_block.old_lines[i] + local ok_change, change = pcall(find_inline_change, old_line, new_line) + + if ok_change and change and change.new_end > change.new_start then + local virt_line = {} + if change.new_start > 0 then + table.insert(virt_line, { new_line:sub(1, change.new_start), Highlights.DIFF_INCOMING }) + end + table.insert( + virt_line, + { new_line:sub(change.new_start + 1, change.new_end), Highlights.DIFF_INCOMING_WORD } + ) + + if change.new_end < #new_line then + table.insert(virt_line, { new_line:sub(change.new_end + 1), Highlights.DIFF_INCOMING }) + end + + local line_len = #new_line + if line_len < max_col and max_col > 0 then + table.insert(virt_line, { string.rep(" ", max_col - line_len), Highlights.DIFF_INCOMING }) + end + table.insert(incoming_virt_lines, virt_line) + else + -- No inline changes, use full line background + local line_ = new_line .. string.rep(" ", max_col - #new_line) + table.insert(incoming_virt_lines, { { line_, Highlights.DIFF_INCOMING } }) + end + else + -- Pure addition - use full line background + local line_ = new_line .. string.rep(" ", max_col - #new_line) + table.insert(incoming_virt_lines, { { line_, Highlights.DIFF_INCOMING } }) + end + end + + -- Place virtual lines below old content + local extmark_line = math.min(math.max(end_line - 1, 0), line_count - 1) + + local ok_incoming_virt, incoming_virt_extmark_id = + pcall(vim.api.nvim_buf_set_extmark, self.bufnr, M.NAMESPACE, extmark_line, 0, { + virt_lines = incoming_virt_lines, + virt_lines_above = false, + hl_eol = true, + hl_mode = "combine", + }) + + if ok_incoming_virt then diff_block.incoming_extmark_id = incoming_virt_extmark_id end + end + end +end + +function DiffDisplayInstance:register_navigation_keybindings() + if not self.bufnr or not vim.api.nvim_buf_is_valid(self.bufnr) then return end + + local keymap_opts = { buffer = self.bufnr } + + vim.keymap.set({ "n", "v" }, Config.mappings.diff.next, function() + if not vim.api.nvim_buf_is_valid(self.bufnr) then return end + if self.show_keybinding_hint_extmark_id then + pcall(vim.api.nvim_buf_del_extmark, self.bufnr, M.KEYBINDING_NAMESPACE, self.show_keybinding_hint_extmark_id) + self.show_keybinding_hint_extmark_id = nil + end + local diff_block = self:get_next_diff_block() + if not diff_block then return end + local winnr = Utils.get_winid(self.bufnr) + + if not winnr then return end + + local line_count = vim.api.nvim_buf_line_count(self.bufnr) + local target_line = math.min(diff_block.start_line, line_count) + vim.api.nvim_win_set_cursor(winnr, { target_line, 0 }) + vim.api.nvim_win_call(winnr, function() vim.cmd("normal! zz") end) + end, keymap_opts) + + vim.keymap.set({ "n", "v" }, Config.mappings.diff.prev, function() + if not vim.api.nvim_buf_is_valid(self.bufnr) then return end + if self.show_keybinding_hint_extmark_id then + pcall(vim.api.nvim_buf_del_extmark, self.bufnr, M.KEYBINDING_NAMESPACE, self.show_keybinding_hint_extmark_id) + self.show_keybinding_hint_extmark_id = nil + end + local diff_block = self:get_prev_diff_block() + if not diff_block then return end + local winnr = Utils.get_winid(self.bufnr) + + if not winnr then return end + + local line_count = vim.api.nvim_buf_line_count(self.bufnr) + local target_line = math.min(diff_block.start_line, line_count) + vim.api.nvim_win_set_cursor(winnr, { target_line, 0 }) + vim.api.nvim_win_call(winnr, function() vim.cmd("normal! zz") end) + end, keymap_opts) +end + +---@param on_accept function(idx: integer) Callback when user accepts a hunk +---@param on_reject function(idx: integer) Callback when user rejects a hunk +function DiffDisplayInstance:register_accept_reject_keybindings(on_accept, on_reject) + if not self.bufnr or not vim.api.nvim_buf_is_valid(self.bufnr) then return end + + self.has_accept_reject_keybindings = true + local keymap_opts = { buffer = self.bufnr } + + -- "co" - Choose OURS (reject incoming changes, keep original) + vim.keymap.set({ "n", "v" }, Config.mappings.diff.ours, function() + if not vim.api.nvim_buf_is_valid(self.bufnr) then return end + local diff_block, idx = self:get_current_diff_block() + if not diff_block then return end + + -- Clear all extmarks in this diff block's range (background, virtual text, word highlights, and hints) + local line_count = vim.api.nvim_buf_line_count(self.bufnr) + local clear_start = math.max(0, diff_block.start_line - 1) + local clear_end = math.min(line_count, diff_block.end_line + 1) + pcall(vim.api.nvim_buf_clear_namespace, self.bufnr, M.NAMESPACE, clear_start, clear_end) + pcall(vim.api.nvim_buf_clear_namespace, self.bufnr, M.KEYBINDING_NAMESPACE, clear_start, clear_end) + + -- Clear the stored hint ID + self.show_keybinding_hint_extmark_id = nil + + diff_block.incoming_extmark_id = nil + diff_block.delete_extmark_id = nil + + -- Remove the diff block from the list so it's no longer navigable + table.remove(self.diff_blocks, idx) + + if on_reject then on_reject(idx) end + + -- Navigate to next diff block (if any) + local next_diff_block = self:get_next_diff_block() + if next_diff_block then + local winnr = Utils.get_winid(self.bufnr) + if not winnr then return end + + vim.api.nvim_win_set_cursor(winnr, { math.min(next_diff_block.start_line, line_count), 0 }) + vim.api.nvim_win_call(winnr, function() vim.cmd("normal! zz") end) + end + end, keymap_opts) + + -- "ct" - Choose THEIRS (accept incoming changes) + vim.keymap.set({ "n", "v" }, Config.mappings.diff.theirs, function() + if not vim.api.nvim_buf_is_valid(self.bufnr) then return end + local diff_block, idx = self:get_current_diff_block() + if not diff_block then return end + + local ok = pcall( + vim.api.nvim_buf_set_lines, + self.bufnr, + diff_block.start_line - 1, + diff_block.end_line, + false, + diff_block.new_lines + ) + + if not ok then + Utils.error("Failed to apply changes to buffer") + return + end + + -- Clear all extmarks in this diff block's range (background, virtual text, word highlights, and hints) + local line_count = vim.api.nvim_buf_line_count(self.bufnr) + local clear_start = math.max(0, diff_block.start_line - 1) + local clear_end = math.min(line_count, diff_block.start_line - 1 + #diff_block.new_lines + 1) + pcall(vim.api.nvim_buf_clear_namespace, self.bufnr, M.NAMESPACE, clear_start, clear_end) + pcall(vim.api.nvim_buf_clear_namespace, self.bufnr, M.KEYBINDING_NAMESPACE, clear_start, clear_end) + + self.show_keybinding_hint_extmark_id = nil + + diff_block.incoming_extmark_id = nil + diff_block.delete_extmark_id = nil + + -- Remove the diff block from the list so it's no longer navigable + table.remove(self.diff_blocks, idx) + + if on_accept then on_accept(idx) end + + -- Navigate to next diff block (if any) + local next_diff_block = self:get_next_diff_block() + if next_diff_block then + local winnr = Utils.get_winid(self.bufnr) + if not winnr then return end + + vim.api.nvim_win_set_cursor(winnr, { math.min(next_diff_block.start_line, line_count), 0 }) + vim.api.nvim_win_call(winnr, function() vim.cmd("normal! zz") end) + end + end, keymap_opts) +end + +function DiffDisplayInstance:register_cursor_move_events() + if not self.bufnr or not vim.api.nvim_buf_is_valid(self.bufnr) then return end + + local function show_keybinding_hint(lnum) + if not vim.api.nvim_buf_is_valid(self.bufnr) then return end + if self.show_keybinding_hint_extmark_id then + pcall(vim.api.nvim_buf_del_extmark, self.bufnr, M.KEYBINDING_NAMESPACE, self.show_keybinding_hint_extmark_id) + self.show_keybinding_hint_extmark_id = nil + end + + -- Show different hints based on whether accept/reject keybindings are registered + -- API providers: show OURS/THEIRS/PREV/NEXT (full partial accept/reject support) + -- ACP providers: show PREV/NEXT only (navigation only, no partial accept/reject) + local hint + if not self.has_accept_reject_keybindings then + hint = string.format("[<%s>: PREV, <%s>: NEXT]", Config.mappings.diff.prev, Config.mappings.diff.next) + else + hint = string.format( + "[<%s>: OURS, <%s>: THEIRS, <%s>: PREV, <%s>: NEXT]", + Config.mappings.diff.ours, + Config.mappings.diff.theirs, + Config.mappings.diff.prev, + Config.mappings.diff.next + ) + end + + self.show_keybinding_hint_extmark_id = + vim.api.nvim_buf_set_extmark(self.bufnr, M.KEYBINDING_NAMESPACE, lnum - 1, -1, { + hl_group = "AvanteInlineHint", + virt_text = { { hint, "AvanteInlineHint" } }, + virt_text_pos = "right_align", + priority = (vim.hl or vim.highlight).priorities.user, + }) + end + + vim.api.nvim_create_autocmd({ "CursorMoved", "CursorMovedI", "WinLeave" }, { + buffer = self.bufnr, + group = self.augroup, + callback = function(event) + if not vim.api.nvim_buf_is_valid(self.bufnr) then return end + local diff_block = self:get_current_diff_block() + if (event.event == "CursorMoved" or event.event == "CursorMovedI") and diff_block then + show_keybinding_hint(diff_block.start_line) + else + vim.api.nvim_buf_clear_namespace(self.bufnr, M.KEYBINDING_NAMESPACE, 0, -1) + end + end, + }) +end + +function DiffDisplayInstance:unregister_keybindings() + if not self.bufnr or not vim.api.nvim_buf_is_valid(self.bufnr) then return end + + -- We need to pcall each del separately to avoid stopping on first error, `del` errors if keymap doesn't exist + pcall(vim.keymap.del, "n", Config.mappings.diff.next, { buffer = self.bufnr }) + pcall(vim.keymap.del, "v", Config.mappings.diff.next, { buffer = self.bufnr }) + pcall(vim.keymap.del, "n", Config.mappings.diff.prev, { buffer = self.bufnr }) + pcall(vim.keymap.del, "v", Config.mappings.diff.prev, { buffer = self.bufnr }) + pcall(vim.keymap.del, "n", Config.mappings.diff.ours, { buffer = self.bufnr }) + pcall(vim.keymap.del, "v", Config.mappings.diff.ours, { buffer = self.bufnr }) + pcall(vim.keymap.del, "n", Config.mappings.diff.theirs, { buffer = self.bufnr }) + pcall(vim.keymap.del, "v", Config.mappings.diff.theirs, { buffer = self.bufnr }) +end + +function DiffDisplayInstance:clear() + self:unregister_keybindings() + + pcall(vim.api.nvim_del_augroup_by_id, self.augroup) + pcall(vim.api.nvim_buf_clear_namespace, self.bufnr, M.NAMESPACE, 0, -1) + pcall(vim.api.nvim_buf_clear_namespace, self.bufnr, M.KEYBINDING_NAMESPACE, 0, -1) + + -- Clear extmark IDs from diff_blocks to help GC + for _, block in ipairs(self.diff_blocks or {}) do + block.incoming_extmark_id = nil + block.delete_extmark_id = nil + end + + -- Clear references to help GC + self.bufnr = nil + self.diff_blocks = nil + self.augroup = nil + self.show_keybinding_hint_extmark_id = nil +end + +return M diff --git a/lua/avante/utils/init.lua b/lua/avante/utils/init.lua index ee89d286e..e4422e1dc 100644 --- a/lua/avante/utils/init.lua +++ b/lua/avante/utils/init.lua @@ -660,6 +660,40 @@ function M.try_find_match(original_lines, target_lines, compare_fn) return start_line, end_line end +---@param original_lines string[] +---@param target_lines string[] +---@param compare_fn fun(line_a: string, line_b: string): boolean +---@return table[] matches Array of {start_line, end_line} pairs +function M.try_find_all_matches(original_lines, target_lines, compare_fn) + local matches = {} + + if #original_lines == 0 or #target_lines == 0 or #target_lines > #original_lines then return matches end + + local i = 1 + while i <= #original_lines - #target_lines + 1 do + local match = true + for j = 1, #target_lines do + local line_a = original_lines[i + j - 1] + local line_b = target_lines[j] + + if not line_a or not line_b or not compare_fn(line_a, line_b) then + match = false + break + end + end + if match then + local start_line = i + local end_line = i + #target_lines - 1 + table.insert(matches, { start_line = start_line, end_line = end_line }) + -- Skip past this match to avoid overlapping + i = end_line + 1 + else + i = i + 1 + end + end + return matches +end + ---@param original_lines string[] ---@param target_lines string[] ---@return integer | nil start_line @@ -703,6 +737,52 @@ function M.fuzzy_match(original_lines, target_lines) return start_line, end_line end +---@param original_lines string[] +---@param target_lines string[] +---@return table[] matches Array of {start_line, end_line} pairs, empty if no matches +function M.find_all_matches(original_lines, target_lines) + -- Try exact match first + local matches = M.try_find_all_matches( + original_lines, + target_lines, + function(line_a, line_b) return line_a == line_b end + ) + if #matches > 0 then return matches end + + -- Try fuzzy match (trim trailing spaces/tabs) + matches = M.try_find_all_matches( + original_lines, + target_lines, + function(line_a, line_b) return M.trim(line_a, { suffix = " \t" }) == M.trim(line_b, { suffix = " \t" }) end + ) + if #matches > 0 then return matches end + + -- Try trim_space match + matches = M.try_find_all_matches( + original_lines, + target_lines, + function(line_a, line_b) return M.trim_space(line_a) == M.trim_space(line_b) end + ) + if #matches > 0 then return matches end + + -- Try trim slashes match + matches = M.try_find_all_matches( + original_lines, + target_lines, + function(line_a, line_b) return line_a == M.trim_escapes(line_b) end + ) + if #matches > 0 then return matches end + + -- Try trim slashes and trim_space match + matches = M.try_find_all_matches( + original_lines, + target_lines, + function(line_a, line_b) return M.trim_space(line_a) == M.trim_space(M.trim_escapes(line_b)) end + ) + + return matches +end + function M.relative_path(absolute) local project_root = M.get_project_root() return M.make_relative_path(absolute, project_root) diff --git a/tests/ui/acp_confirm_adapter_spec.lua b/tests/acp/acp_confirm_adapter_spec.lua similarity index 98% rename from tests/ui/acp_confirm_adapter_spec.lua rename to tests/acp/acp_confirm_adapter_spec.lua index 0fd2d0fa3..8334803bd 100644 --- a/tests/ui/acp_confirm_adapter_spec.lua +++ b/tests/acp/acp_confirm_adapter_spec.lua @@ -1,4 +1,4 @@ -local ACPConfirmAdapter = require("avante.ui.acp_confirm_adapter") +local ACPConfirmAdapter = require("avante.acp.acp_confirm_adapter") describe("ACPConfirmAdapter", function() describe("map_acp_options", function() diff --git a/tests/data/acp_diff/fixtures.lua b/tests/data/acp_diff/fixtures.lua new file mode 100644 index 000000000..91522b7f0 --- /dev/null +++ b/tests/data/acp_diff/fixtures.lua @@ -0,0 +1,263 @@ +---@class avante.test.acp_diff.fixtures +--- Anonymized ACP tool_call fixtures for testing acp_diff_handler +--- Based on real ACP session/update protocol messages + +local M = {} + +-- Simple single-line edit (most common case) +M.simple_single_line_edit = { + content = { + { + type = "diff", + path = "/project/README.md", + oldText = "# Platform Frontend", + newText = "# Platform Front-end", + }, + }, + rawInput = { + file_path = "/project/README.md", + old_string = "# Platform Frontend", + new_string = "# Platform Front-end", + }, + kind = "edit", + locations = { { path = "/project/README.md" } }, + status = "pending", + title = "Edit `/project/README.md`", + toolCallId = "test-tool-call-001", +} + +-- Replace all occurrences (replace_all = true) +M.replace_all_occurrences = { + content = { + { + type = "diff", + path = "/project/app.lua", + oldText = "config", + newText = "configuration", + }, + }, + rawInput = { + file_path = "/project/app.lua", + old_string = "config", + new_string = "configuration", + replace_all = true, + }, + kind = "edit", + locations = { { path = "/project/app.lua" } }, + status = "pending", + title = "Edit `/project/app.lua`", + toolCallId = "test-tool-call-002", +} + +-- CRITICAL BUG TEST: Special characters in replacement text +-- Tests Lua pattern special chars: %1, %2, etc. should be literal +M.special_chars_in_replacement = { + rawInput = { + file_path = "/project/lib.lua", + old_string = "variable", + new_string = "result%1", + replace_all = true, + }, + kind = "edit", + locations = { { path = "/project/lib.lua" } }, + status = "pending", + title = "Edit `/project/lib.lua`", + toolCallId = "test-tool-call-003", +} + +-- More special characters: backslashes, percent signs +M.special_chars_backslash = { + rawInput = { + file_path = "/project/paths.lua", + old_string = "path", + new_string = "C:\\Users\\path", + replace_all = false, + }, + kind = "edit", + toolCallId = "test-tool-call-004", +} + +-- Multiple content items for same file +M.multiple_edits_same_file = { + content = { + { + type = "diff", + path = "/project/config.lua", + oldText = "foo", + newText = "bar", + }, + { + type = "diff", + path = "/project/config.lua", + oldText = "baz", + newText = "qux", + }, + }, + kind = "edit", + locations = { { path = "/project/config.lua" } }, + status = "pending", + title = "Edit `/project/config.lua`", + toolCallId = "test-tool-call-005", +} + +-- New file creation (oldText is empty or nil) +M.new_file_creation_empty_string = { + content = { + { + type = "diff", + path = "/project/new_module.lua", + oldText = "", + newText = "local M = {}\n\nfunction M.init()\n return true\nend\n\nreturn M", + }, + }, + rawInput = { + file_path = "/project/new_module.lua", + old_string = "", + new_string = "local M = {}\n\nfunction M.init()\n return true\nend\n\nreturn M", + }, + kind = "edit", + locations = { { path = "/project/new_module.lua" } }, + status = "pending", + title = "Create `/project/new_module.lua`", + toolCallId = "test-tool-call-006", +} + +-- New file creation with vim.NIL +M.new_file_creation_vim_nil = { + content = { + { + type = "diff", + path = "/project/another_module.lua", + oldText = vim.NIL, + newText = "-- New file\nreturn {}", + }, + }, + rawInput = { + file_path = "/project/another_module.lua", + old_string = vim.NIL, + new_string = "-- New file\nreturn {}", + }, + kind = "edit", + locations = { { path = "/project/another_module.lua" } }, + status = "pending", + title = "Create `/project/another_module.lua`", + toolCallId = "test-tool-call-007", +} + +-- Multi-line replacement +M.multiline_function_edit = { + content = { + { + type = "diff", + path = "/project/utils.lua", + oldText = "function process(data)\n return data\nend", + newText = "function process(data)\n -- Add validation\n if not data then return nil end\n return data\nend", + }, + }, + rawInput = { + file_path = "/project/utils.lua", + old_string = "function process(data)\n return data\nend", + new_string = "function process(data)\n -- Add validation\n if not data then return nil end\n return data\nend", + }, + kind = "edit", + locations = { { path = "/project/utils.lua" } }, + status = "pending", + title = "Edit `/project/utils.lua`", + toolCallId = "test-tool-call-008", +} + +-- Multiple diff blocks in same file (for testing cumulative offset) +M.multiple_diff_blocks_offset_test = { + content = { + { + type = "diff", + path = "/project/main.lua", + oldText = "local a = 1", + newText = "local a = 1\nlocal b = 2", + }, + { + type = "diff", + path = "/project/main.lua", + oldText = "return result", + newText = "return result", + }, + }, + kind = "edit", + locations = { { path = "/project/main.lua" } }, + status = "pending", + title = "Edit `/project/main.lua`", + toolCallId = "test-tool-call-009", +} + +-- Edge case: Only rawInput present (no content array) +M.only_raw_input = { + rawInput = { + file_path = "/project/settings.lua", + old_string = "debug = false", + new_string = "debug = true", + replace_all = false, + }, + kind = "edit", + locations = { { path = "/project/settings.lua" } }, + status = "pending", + title = "Edit `/project/settings.lua`", + toolCallId = "test-tool-call-010", +} + +-- Edge case: Single-line file edit +M.single_line_file_edit = { + content = { + { + type = "diff", + path = "/project/.gitignore", + oldText = "node_modules", + newText = "node_modules\n.env", + }, + }, + rawInput = { + file_path = "/project/.gitignore", + old_string = "node_modules", + new_string = "node_modules\n.env", + }, + kind = "edit", + toolCallId = "test-tool-call-011", +} + +-- Edge case: Deletion (new_string is empty) +M.delete_lines = { + content = { + { + type = "diff", + path = "/project/temp.lua", + oldText = "-- TODO: Remove this\nlocal unused = 1", + newText = "", + }, + }, + rawInput = { + file_path = "/project/temp.lua", + old_string = "-- TODO: Remove this\nlocal unused = 1", + new_string = "", + }, + kind = "edit", + toolCallId = "test-tool-call-012", +} + +-- Edge case: Substring replacement within line (not full line) +M.substring_within_line = { + content = { { + type = "diff", + path = "/project/code.lua", + oldText = "old", + newText = "new", + } }, + rawInput = { + file_path = "/project/code.lua", + old_string = "old", + new_string = "new", + replace_all = false, + }, + kind = "edit", + toolCallId = "test-tool-call-013", +} + +return M diff --git a/tests/data/acp_diff/sample_files.lua b/tests/data/acp_diff/sample_files.lua new file mode 100644 index 000000000..a317a76e1 --- /dev/null +++ b/tests/data/acp_diff/sample_files.lua @@ -0,0 +1,130 @@ +---@class avante.test.acp_diff.sample_files +--- Mock file contents for testing acp_diff_handler +--- Each entry represents the current state of a file before edits + +local M = {} + +-- Simple README file with single line +M.readme_simple = { + "# Platform Frontend", + "", + "This is a test project.", +} + +-- File with multiple occurrences of 'config' +M.app_with_config = { + "local config = require('config')", + "local function setup()", + " config.init()", + " return config", + "end", +} + +-- File with 'variable' keyword for special char testing +M.lib_with_variable = { + "local variable = 'test'", + "local another_variable = 'value'", + "local variable_name = 'foo'", + "return variable", +} + +-- File with path keyword +M.paths_file = { + "local path = '/usr/local'", + "return path", +} + +-- Config file with foo and baz +M.config_with_foo_baz = { + "local M = {}", + "M.foo = 'original'", + "M.baz = 'original'", + "return M", +} + +-- Empty file (for new file creation tests) +M.empty_file = {} + +-- Utils file with function +M.utils_with_function = { + "local M = {}", + "", + "function process(data)", + " return data", + "end", + "", + "return M", +} + +-- Main file with multiple sections for offset testing +M.main_file_for_offset = { + "local a = 1", + "", + "local function work()", + " print('working')", + "end", + "", + "return result", +} + +-- Settings file +M.settings_file = { + "return {", + " debug = false,", + " verbose = true,", + "}", +} + +-- Single line gitignore +M.gitignore_single_line = { + "node_modules", +} + +-- Temp file with code to delete +M.temp_file_with_todo = { + "local M = {}", + "", + "-- TODO: Remove this", + "local unused = 1", + "", + "return M", +} + +-- File with 'old' substring within longer line +M.code_with_substring = { + "local old_value = 123", + "local very_old_code = true", + "return old_value", +} + +-- File with duplicate text on multiple lines +M.file_with_duplicates = { + "config = 1", + "local config = 2", + " config = 3", + "return config", +} + +-- Multi-line file for minimize_diff testing +M.file_for_minimize_diff = { + "line 1 - change me", + "line 2 - keep me", + "line 3 - change me", + "line 4 - keep me", + "line 5 - change me", +} + +-- File with special characters +M.file_with_special_chars = { + "local pattern = 'test%d+'", + "local regex = [[\\w+]]", + "return pattern", +} + +-- Large file for performance testing (optional) +M.large_file = {} +for i = 1, 100 do + table.insert(M.large_file, "line " .. i) +end + +return M diff --git a/tests/llm_tools/acp_diff_handler_spec.lua b/tests/llm_tools/acp_diff_handler_spec.lua new file mode 100644 index 000000000..549530658 --- /dev/null +++ b/tests/llm_tools/acp_diff_handler_spec.lua @@ -0,0 +1,439 @@ +---@diagnostic disable: undefined-field +local M = require("avante.acp.acp_diff_handler") +local Utils = require("avante.utils") +local Config = require("avante.config") +local fixtures = require("tests.data.acp_diff.fixtures") +local sample_files = require("tests.data.acp_diff.sample_files") +local stub = require("luassert.stub") + +describe("acp_diff_handler", function() + local original_behaviour + + before_each(function() + -- Initialize Config.behaviour if it doesn't exist + if not Config.behaviour then Config.behaviour = {} end + + -- Store original config value + original_behaviour = vim.deepcopy(Config.behaviour) + + -- Set minimize_diff to false for predictable tests + Config.behaviour.minimize_diff = false + end) + + after_each(function() + -- Restore original config + if original_behaviour then Config.behaviour = original_behaviour end + end) + + describe("has_diff_content", function() + it("should detect diff in content array", function() + local result = M.has_diff_content(fixtures.simple_single_line_edit) + assert.is_true(result) + end) + + it("should detect diff in rawInput with new_string", function() + local tool_call = { + rawInput = { + new_string = "text content", + }, + } + assert.is_true(M.has_diff_content(tool_call)) + end) + + it("should return false when no diff present", function() + local tool_call = {} + assert.is_false(M.has_diff_content(tool_call)) + end) + + it("should return false when rawInput.new_string is nil", function() + local tool_call = { + rawInput = { + new_string = nil, + }, + } + assert.is_false(M.has_diff_content(tool_call)) + end) + + it("should return false when rawInput.new_string is vim.NIL", function() + local tool_call = { + rawInput = { + new_string = vim.NIL, + }, + } + assert.is_false(M.has_diff_content(tool_call)) + end) + end) + + describe("extract_diff_blocks", function() + local path_stub, read_stub, fuzzy_stub + + before_each(function() + -- Default stubs that can be overridden in specific tests + path_stub = stub(Utils, "to_absolute_path", function(path) + return path -- Return as-is for testing + end) + end) + + after_each(function() + if path_stub then path_stub:revert() end + if read_stub then read_stub:revert() end + if fuzzy_stub then fuzzy_stub:revert() end + end) + + describe("simple single-line edits", function() + before_each(function() + read_stub = stub(Utils, "read_file_from_buf_or_disk", function() return sample_files.readme_simple, nil end) + fuzzy_stub = stub(Utils, "fuzzy_match", function(file_lines, search_lines) + -- Find exact match + local search_str = search_lines[1] + for i, line in ipairs(file_lines) do + if line == search_str then return i, i + #search_lines - 1 end + end + return nil, nil + end) + end) + + it("should extract simple single-line replacement from content", function() + local result = M.extract_diff_blocks(fixtures.simple_single_line_edit) + + assert.is_not_nil(result["/project/README.md"]) + assert.equals(1, #result["/project/README.md"]) + + local block = result["/project/README.md"][1] + assert.equals(1, block.start_line) + assert.equals(1, block.end_line) + assert.same({ "# Platform Frontend" }, block.old_lines) + assert.same({ "# Platform Front-end" }, block.new_lines) + -- Note: new_start_line and new_end_line are not populated during extraction + -- (virtual-first approach - these fields are optional and set later if needed) + end) + end) + + describe("replace_all behavior", function() + before_each(function() + read_stub = stub(Utils, "read_file_from_buf_or_disk", function() return sample_files.app_with_config, nil end) + fuzzy_stub = stub(Utils, "fuzzy_match", function() + return nil, nil -- Force fallback to substring search + end) + end) + + it("should replace all occurrences when replace_all is true", function() + local find_all_stub = stub(Utils, "find_all_matches", function(file_lines, search_lines) + local matches = {} + local search_str = search_lines[1] + for i, line in ipairs(file_lines) do + if line:find(search_str, 1, true) then table.insert(matches, { start_line = i, end_line = i }) end + end + return matches + end) + + local result = M.extract_diff_blocks(fixtures.replace_all_occurrences) + + assert.is_not_nil(result["/project/app.lua"]) + -- Should find 3 occurrences: lines 1, 3, 4 + assert.equals(3, #result["/project/app.lua"]) + + find_all_stub:revert() + end) + + it("should only replace first occurrence when replace_all is false", function() + read_stub:revert() + fuzzy_stub:revert() + + read_stub = stub( + Utils, + "read_file_from_buf_or_disk", + function() return sample_files.file_with_duplicates, nil end + ) + fuzzy_stub = stub(Utils, "fuzzy_match", function() + return nil, nil -- Force substring replacement + end) + + local tool_call = { + rawInput = { + file_path = "/project/app.lua", + old_string = "config", + new_string = "configuration", + replace_all = false, + }, + } + + local result = M.extract_diff_blocks(tool_call) + + assert.is_not_nil(result["/project/app.lua"]) + -- Should only find 1 occurrence (first match) + assert.equals(1, #result["/project/app.lua"]) + end) + end) + + describe("CRITICAL BUG: special characters in replacement", function() + before_each(function() + read_stub = stub( + Utils, + "read_file_from_buf_or_disk", + function() return sample_files.lib_with_variable, nil end + ) + fuzzy_stub = stub(Utils, "fuzzy_match", function() + return nil, nil -- Force substring replacement path + end) + end) + + it("should handle %1 in replacement text as literal (not backreference)", function() + local result = M.extract_diff_blocks(fixtures.special_chars_in_replacement) + + assert.is_not_nil(result["/project/lib.lua"]) + local blocks = result["/project/lib.lua"] + + -- Verify that at least one block was created + assert.truthy(#blocks > 0, "Expected at least one diff block") + + -- Verify that %1 appears literally in the result (escaped or literal) + local found_replacement = false + for _, block in ipairs(blocks) do + local new_text = table.concat(block.new_lines, "\n") + -- Should contain "result" and "%1" (possibly as "result%1") + if new_text:find("result", 1, true) and new_text:find("%%1", 1, false) then + found_replacement = true + break + end + end + assert.truthy(found_replacement, "Expected literal 'result%1' pattern in replacement") + end) + + it("should handle backslashes in replacement text", function() + read_stub:revert() + fuzzy_stub:revert() + + read_stub = stub(Utils, "read_file_from_buf_or_disk", function() return sample_files.paths_file, nil end) + fuzzy_stub = stub(Utils, "fuzzy_match", function() + return nil, nil -- Force substring replacement + end) + + local result = M.extract_diff_blocks(fixtures.special_chars_backslash) + + -- Allow for case where no match is found (backslash handling is complex) + if result["/project/paths.lua"] and #result["/project/paths.lua"] > 0 then + local block = result["/project/paths.lua"][1] + local new_text = table.concat(block.new_lines, "\n") + -- Just verify we got some replacement + assert.truthy(#new_text > 0) + end + end) + end) + + describe("multiple content items for same file", function() + before_each(function() + read_stub = stub( + Utils, + "read_file_from_buf_or_disk", + function() return sample_files.config_with_foo_baz, nil end + ) + fuzzy_stub = stub(Utils, "fuzzy_match", function() + return nil, nil -- Force substring replacement + end) + end) + + it("should handle multiple edits to same file", function() + local result = M.extract_diff_blocks(fixtures.multiple_edits_same_file) + + assert.is_not_nil(result["/project/config.lua"]) + -- Should have 2 diff blocks + assert.equals(2, #result["/project/config.lua"]) + + -- Blocks should be sorted by start_line + local blocks = result["/project/config.lua"] + assert.truthy(blocks[1].start_line <= blocks[2].start_line) + end) + end) + + describe("new file creation", function() + before_each(function() + read_stub = stub(Utils, "read_file_from_buf_or_disk", function() + return {}, nil -- Empty file + end) + fuzzy_stub = stub(Utils, "fuzzy_match", function() return nil, nil end) + end) + + it("should handle new file with empty string oldText", function() + local result = M.extract_diff_blocks(fixtures.new_file_creation_empty_string) + + assert.is_not_nil(result["/project/new_module.lua"]) + local block = result["/project/new_module.lua"][1] + + assert.equals(1, block.start_line) + assert.equals(0, block.end_line) -- New file marker + assert.same({}, block.old_lines) + -- The file content splits into 7 lines (including empty lines from \n\n) + assert.equals(7, #block.new_lines) + end) + + it("should handle new file with vim.NIL oldText", function() + local result = M.extract_diff_blocks(fixtures.new_file_creation_vim_nil) + + assert.is_not_nil(result["/project/another_module.lua"]) + local block = result["/project/another_module.lua"][1] + + assert.equals(1, block.start_line) + assert.equals(0, block.end_line) + assert.same({}, block.old_lines) + assert.equals(2, #block.new_lines) + end) + end) + + describe("multi-line replacements", function() + before_each(function() + read_stub = stub( + Utils, + "read_file_from_buf_or_disk", + function() return sample_files.utils_with_function, nil end + ) + fuzzy_stub = stub(Utils, "fuzzy_match", function(file_lines, search_lines) + -- Find the function across multiple lines + if #search_lines == 3 and search_lines[1]:match("^function process") then + return 3, 5 -- Lines 3-5 in utils_with_function + end + return nil, nil + end) + end) + + it("should handle multi-line function replacement", function() + local result = M.extract_diff_blocks(fixtures.multiline_function_edit) + + assert.is_not_nil(result["/project/utils.lua"]) + local block = result["/project/utils.lua"][1] + + assert.equals(3, block.start_line) + assert.equals(5, block.end_line) + assert.equals(3, #block.old_lines) + assert.equals(5, #block.new_lines) -- Expanded to 5 lines + end) + end) + + -- Note: Cumulative offset calculation tests removed + -- The virtual-first approach no longer calculates new_start_line/new_end_line during extraction + -- These fields are optional and only populated later when diffs are applied to buffers + + describe("edge cases", function() + it("should handle empty file", function() + read_stub = stub(Utils, "read_file_from_buf_or_disk", function() return {}, nil end) + fuzzy_stub = stub(Utils, "fuzzy_match", function() return nil, nil end) + + local tool_call = { + content = { + { + type = "diff", + path = "/project/empty.lua", + oldText = "", + newText = "content", + }, + }, + } + + local result = M.extract_diff_blocks(tool_call) + assert.is_not_nil(result["/project/empty.lua"]) + end) + + it("should handle only rawInput present (no content array)", function() + if read_stub then read_stub:revert() end + if fuzzy_stub then fuzzy_stub:revert() end + + read_stub = stub(Utils, "read_file_from_buf_or_disk", function() return sample_files.settings_file, nil end) + fuzzy_stub = stub(Utils, "fuzzy_match", function(file_lines, search_lines) + -- Find "debug = false" in settings file + for i, line in ipairs(file_lines) do + if line:find(search_lines[1], 1, true) then return i, i end + end + return nil, nil + end) + + local result = M.extract_diff_blocks(fixtures.only_raw_input) + + assert.is_not_nil(result["/project/settings.lua"]) + assert.truthy(#result["/project/settings.lua"] > 0) + end) + + it("should handle deletion (newText is empty)", function() + read_stub = stub( + Utils, + "read_file_from_buf_or_disk", + function() return sample_files.temp_file_with_todo, nil end + ) + fuzzy_stub = stub(Utils, "fuzzy_match", function(file_lines, search_lines) + -- Find lines 3-4 + if #search_lines == 2 and search_lines[1]:match("TODO") then return 3, 4 end + return nil, nil + end) + + local result = M.extract_diff_blocks(fixtures.delete_lines) + + assert.is_not_nil(result["/project/temp.lua"]) + local block = result["/project/temp.lua"][1] + + assert.equals(3, block.start_line) + assert.equals(4, block.end_line) + assert.same({}, block.new_lines) + -- Note: new_start_line/new_end_line not populated during extraction (virtual-first approach) + end) + + it("should return empty table when no diff found", function() + read_stub = stub(Utils, "read_file_from_buf_or_disk", function() return { "unrelated content" }, nil end) + fuzzy_stub = stub(Utils, "fuzzy_match", function() + return nil, nil -- No match + end) + + local tool_call = { + content = { + { + type = "diff", + path = "/project/file.lua", + oldText = "nonexistent", + newText = "replacement", + }, + }, + } + + local result = M.extract_diff_blocks(tool_call) + + -- Should return empty table when no matches found + assert.truthy(next(result) == nil or result["/project/file.lua"] == nil) + end) + end) + end) + + -- Note: minimize_diff_blocks is a private function (in P table, not M table) + -- It's tested indirectly through extract_diff_blocks with Config.behaviour.minimize_diff = true + + describe("integration with Config.behaviour.minimize_diff", function() + it("should apply minimize_diff when config enabled", function() + Config.behaviour.minimize_diff = true + + local read_stub = stub( + Utils, + "read_file_from_buf_or_disk", + function() return sample_files.file_for_minimize_diff, nil end + ) + local fuzzy_stub = stub(Utils, "fuzzy_match", function() + return 1, 5 -- Match all 5 lines + end) + + local tool_call = { + content = { + { + type = "diff", + path = "/project/test.lua", + oldText = table.concat(sample_files.file_for_minimize_diff, "\n"), + newText = "CHANGED\nline 2 - keep me\nCHANGED\nline 4 - keep me\nCHANGED", + }, + }, + } + + local result = M.extract_diff_blocks(tool_call) + + read_stub:revert() + fuzzy_stub:revert() + + -- Should have multiple blocks (unchanged lines removed) + assert.truthy(#result["/project/test.lua"] > 1) + end) + end) +end)