diff --git a/README.md b/README.md index 554ed14..40fa634 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,8 @@ perfect commit message. when you run `git commit -v` - 🎯 Works from terminal or within Neovim (using vim-fugitive) - 🤝 Non-intrusive - if you start typing, AI suggestions are added as comments instead -- 🔑 Uses `GEMINI_API_KEY`, `OPENAI_API_KEY`, or `ANTHROPIC_API_KEY` environment variables for authentication +- 🔑 Uses `GEMINI_API_KEY`, `OPENAI_API_KEY`, `COPILOT_TOKEN`, or `ANTHROPIC_API_KEY` environment variables for authentication +- 🔐 Copilot: automatically uses existing authentication from copilot.lua or copilot.vim (no token needed) - ⚙️ Configurable model, temperature, and max tokens - 🔄 Optional push prompt after successful commits - ⬇️⬆️ Pull before push to reduce rejections (configurable with args) @@ -91,7 +92,9 @@ export GEMINI_API_KEY="your-api-key-here" export OPENAI_API_KEY="your-api-key-here" ``` -**For Anthropic:** +**For Copilot:** + +If you're already using [copilot.lua](https://github.com/zbirenbaum/copilot.lua) or [copilot.vim](https://github.com/github/copilot.vim), authentication is automatic, no setup needed. Otherwise: ```bash export COPILOT_TOKEN="your-github-copilot-token-here" @@ -127,8 +130,8 @@ require("ai_commit_msg").setup({ args = { "--rebase", "--autostash" }, -- arguments passed to `git pull` }, - -- Show spinner while generating - spinner = true, + -- Show spinner while generating (true for default, false to disable, or custom frames array) + spinner = true, -- Can also be false or { "⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏" } -- Show notifications notifications = true, @@ -352,7 +355,7 @@ git config --global core.editor nvim ## Tips -- The plugin uses Gemini API, OpenAI Chat Completions API, and Anthropic Messages API directly +- The plugin uses Gemini API, Copilot Chat Completions API, OpenAI Chat Completions API, and Anthropic Messages API directly - Lower temperature values (0.1-0.3) produce more consistent commit messages - Higher temperature values (0.5-0.8) produce more creative variations - The default model `gemini-2.5-flash-lite` provides excellent results at a very low cost diff --git a/lua/ai_commit_msg/config.lua b/lua/ai_commit_msg/config.lua index d826031..4cfa686 100644 --- a/lua/ai_commit_msg/config.lua +++ b/lua/ai_commit_msg/config.lua @@ -4,6 +4,9 @@ local M = {} local DEFAULT_PROMPT = [[{diff}]] local DEFAULT_SYSTEM_PROMPT = require("ai_commit_msg.prompts").DEFAULT_SYSTEM_PROMPT +-- Default spinner frames +M.DEFAULT_SPINNER_FRAMES = { "⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏" } + ---@class ProviderConfig ---@field model string Model to use for this provider ---@field temperature number|nil Temperature for the model (0.0 to 1.0) @@ -25,7 +28,7 @@ local DEFAULT_SYSTEM_PROMPT = require("ai_commit_msg.prompts").DEFAULT_SYSTEM_PR ---@field providers table Provider-specific configurations ---@field auto_push_prompt boolean Whether to prompt for push after commit ---@field pull_before_push { enabled: boolean, args: string[] } Whether and how to run `git pull` before pushing ----@field spinner boolean Whether to show a spinner while generating +---@field spinner string[]|boolean Array of spinner frames to animate, true for default frames, or false to disable spinner ---@field notifications boolean Whether to show notifications ---@field context_lines number Number of surrounding lines to include in git diff ---@field keymaps table Keymaps for commit buffer diff --git a/lua/ai_commit_msg/generator.lua b/lua/ai_commit_msg/generator.lua index d26e398..836810a 100644 --- a/lua/ai_commit_msg/generator.lua +++ b/lua/ai_commit_msg/generator.lua @@ -1,8 +1,7 @@ local M = {} -local function get_spinner() - local spinner = { "⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏" } - return spinner[math.floor(vim.uv.hrtime() / (1e6 * 80)) % #spinner + 1] +local function get_spinner(spinner_frames) + return spinner_frames[math.floor(vim.uv.hrtime() / (1e6 * 80)) % #spinner_frames + 1] end local function notify(msg, level, config) @@ -24,19 +23,34 @@ function M.generate(config, callback) end) local spinner_timer - local notif_id = "ai-commit-msg" + local notify_record + local notify_called = false + + -- Resolve spinner frames: true uses default, table uses custom, false disables + local spinner_frames = nil + if config.spinner == true then + spinner_frames = require("ai_commit_msg.config").DEFAULT_SPINNER_FRAMES + elseif type(config.spinner) == "table" and #config.spinner > 0 then + spinner_frames = config.spinner + end -- Start spinner if enabled - if config.spinner and config.notifications then + if spinner_frames and config.notifications then local function update_spinner() if not spinner_timer or spinner_timer:is_closing() then return end - vim.notify(get_spinner() .. " Generating commit message...", vim.log.levels.INFO, { - id = notif_id, + local opts = { title = "AI Commit", timeout = false, - }) + hide_from_history = notify_called, + } + if notify_record then + opts.replace = notify_record.id + end + notify_record = + vim.notify(get_spinner(spinner_frames) .. " Generating commit message...", vim.log.levels.INFO, opts) + notify_called = true end spinner_timer = vim.uv.new_timer() @@ -66,13 +80,16 @@ function M.generate(config, callback) vim.schedule(function() local error_msg = "Failed to get git diff: " .. (diff_res.stderr or "Unknown error") vim.notify("ai-commit-msg.nvim: " .. error_msg, vim.log.levels.ERROR) - -- Clear spinner notification with error message + -- Replace spinner notification with error message if config.notifications then - vim.notify("❌ " .. error_msg, vim.log.levels.ERROR, { - id = notif_id, + local opts = { title = "AI Commit", timeout = 3000, - }) + } + if notify_record then + opts.replace = notify_record.id + end + vim.notify("❌ " .. error_msg, vim.log.levels.ERROR, opts) end if callback then callback(false, error_msg) @@ -93,13 +110,16 @@ function M.generate(config, callback) vim.schedule(function() local error_msg = "No staged changes to commit" vim.notify("ai-commit-msg.nvim: " .. error_msg, vim.log.levels.WARN) - -- Clear spinner notification with warning message + -- Replace spinner notification with warning message if config.notifications then - vim.notify("⚠️ " .. error_msg, vim.log.levels.WARN, { - id = notif_id, + local opts = { title = "AI Commit", timeout = 3000, - }) + } + if notify_record then + opts.replace = notify_record.id + end + vim.notify("⚠️ " .. error_msg, vim.log.levels.WARN, opts) end if callback then callback(false, error_msg) @@ -125,13 +145,16 @@ function M.generate(config, callback) vim.schedule(function() if not success then vim.notify("ai-commit-msg.nvim: " .. result, vim.log.levels.ERROR) - -- Clear spinner notification with error message + -- Replace spinner notification with error message if config.notifications then - vim.notify("❌ " .. result, vim.log.levels.ERROR, { - id = notif_id, + local opts = { title = "AI Commit", timeout = 3000, - }) + } + if notify_record then + opts.replace = notify_record.id + end + vim.notify("❌ " .. result, vim.log.levels.ERROR, opts) end if callback then callback(false, result) @@ -153,13 +176,16 @@ function M.generate(config, callback) end vim.notify("ai-commit-msg.nvim: Generated message: " .. result:sub(1, 50) .. "...", vim.log.levels.DEBUG) - -- Clear spinner notification with success message + -- Replace spinner notification with success message if config.notifications then - vim.notify("✅ Commit message generated (" .. duration_cost_str .. ")", vim.log.levels.INFO, { - id = notif_id, + local opts = { title = "AI Commit", timeout = 2000, - }) + } + if notify_record then + opts.replace = notify_record.id + end + vim.notify("✅ Commit message generated (" .. duration_cost_str .. ")", vim.log.levels.INFO, opts) end if callback then callback(true, result) diff --git a/lua/ai_commit_msg/providers/copilot.lua b/lua/ai_commit_msg/providers/copilot.lua index 5f61dcc..e27bf9f 100644 --- a/lua/ai_commit_msg/providers/copilot.lua +++ b/lua/ai_commit_msg/providers/copilot.lua @@ -1,4 +1,5 @@ local M = {} +local EDITOR_VERSION = "Neovim/" .. vim.version().major .. "." .. vim.version().minor .. "." .. vim.version().patch -- Models that support reasoning_effort parameter local REASONING_EFFORT_MODELS = { @@ -7,16 +8,197 @@ local REASONING_EFFORT_MODELS = { ["gpt-5"] = true, } +-- Cache for OAuth and Copilot tokens +local _oauth_token = nil +local _copilot_token = nil +local _copilot_endpoints = nil +local _token_fetch_in_progress = false + local function model_supports_reasoning_effort(model) return REASONING_EFFORT_MODELS[model] or model:match("^gpt%-5") end +-- Find the configuration path for GitHub Copilot +local function find_config_path() + local path = os.getenv("XDG_CONFIG_HOME") + if path and vim.uv.fs_stat(path) then + return path + end + + local home = os.getenv("HOME") or os.getenv("USERPROFILE") + if not home then + return nil + end + + if vim.fn.has("win32") == 1 then + path = home .. "/AppData/Local" + if vim.uv.fs_stat(path) then + return path + end + else + path = home .. "/.config" + if vim.uv.fs_stat(path) then + return path + end + end + return nil +end + +-- Get OAuth token from environment or config files +local function get_oauth_token() + if _oauth_token then + return _oauth_token + end + + -- Check for GitHub Codespaces environment + local token = os.getenv("GITHUB_TOKEN") + local codespaces = os.getenv("CODESPACES") + if token and codespaces then + _oauth_token = token + return _oauth_token + end + + -- Look for token in config files + local config_path = find_config_path() + if not config_path then + return nil + end + + local file_paths = { + config_path .. "/github-copilot/hosts.json", + config_path .. "/github-copilot/apps.json", + } + + for _, file_path in ipairs(file_paths) do + local stat = vim.uv.fs_stat(file_path) + if stat and stat.type == "file" then + local fd = vim.uv.fs_open(file_path, "r", 438) + if fd then + local stat_result = vim.uv.fs_fstat(fd) + if stat_result then + local content = vim.uv.fs_read(fd, stat_result.size, 0) + vim.uv.fs_close(fd) + + if content then + local ok_decode, data = pcall(vim.json.decode, content) + if ok_decode and type(data) == "table" then + for key, value in pairs(data) do + if key:find("github.com") and type(value) == "table" and value.oauth_token then + _oauth_token = value.oauth_token + return _oauth_token + end + end + end + end + else + vim.uv.fs_close(fd) + end + end + end + end + + return nil +end + +-- Exchange OAuth token for Copilot token +local function get_copilot_token(callback) + -- Check if we have a valid cached token + if _copilot_token and _copilot_token.expires_at and _copilot_token.expires_at > os.time() then + callback(true, _copilot_token.token, _copilot_endpoints) + return + end + + -- Wait if another fetch is in progress + if _token_fetch_in_progress then + local max_wait = 100 -- 5 seconds (100 * 50ms) + local waited = 0 + vim.wait(50, function() + waited = waited + 1 + if waited >= max_wait then + return true + end + return _copilot_token ~= nil and _copilot_token.expires_at ~= nil and _copilot_token.expires_at > os.time() + end, 50) + + if _copilot_token and _copilot_token.expires_at and _copilot_token.expires_at > os.time() then + callback(true, _copilot_token.token, _copilot_endpoints) + return + end + end + + _token_fetch_in_progress = true + + local oauth_token = _oauth_token + if not oauth_token then + _token_fetch_in_progress = false + callback(false, "No OAuth token found") + return + end + + local curl_args = { + "curl", + "-X", + "GET", + "https://api.github.com/copilot_internal/v2/token", + "-H", + "Authorization: Bearer " .. oauth_token, + "-H", + "Accept: application/json", + "--silent", + "--show-error", + } + + vim.system(curl_args, {}, function(res) + _token_fetch_in_progress = false + + if res.code ~= 0 then + callback(false, "Failed to get Copilot token: " .. (res.stderr or "Unknown error")) + return + end + + local ok, token_data = pcall(vim.json.decode, res.stdout) + if not ok or type(token_data) ~= "table" then + callback(false, "Failed to parse Copilot token response") + return + end + + _copilot_token = token_data + _copilot_endpoints = token_data.endpoints + callback(true, token_data.token, token_data.endpoints) + end) +end + -- Copilot provider using GitHub Models API chat completions --- Reads token from `config.token` (no env var usage) function M.call_api(config, diff, callback) - local token = os.getenv("COPILOT_TOKEN") + -- First try COPILOT_TOKEN env var + local env_token = os.getenv("COPILOT_TOKEN") + + if env_token and env_token ~= "" then + M._make_api_call(env_token, nil, config, diff, callback) + return + end + + -- Fallback to OAuth token mechanism + local oauth_token = get_oauth_token() + if not oauth_token then + callback(false, "No Copilot token found. Set COPILOT_TOKEN env var or authenticate with GitHub Copilot") + return + end + + get_copilot_token(function(success, token, endpoints) + if not success then + callback(false, token) -- token contains error message here + return + end + + M._make_api_call(token, endpoints, config, diff, callback) + end) +end + +-- Internal function to make the actual API call +function M._make_api_call(token, endpoints, config, diff, callback) if not token or token == "" then - callback(false, "Copilot token not set in config") + callback(false, "Invalid Copilot token") return end @@ -68,15 +250,27 @@ function M.call_api(config, diff, callback) local payload = vim.json.encode(payload_data) + -- Use endpoint from Copilot token if available, otherwise use default + local api_url = "https://models.github.ai/inference/chat/completions" + if endpoints and endpoints.api then + api_url = endpoints.api .. "/chat/completions" + end + local curl_args = { "curl", "-X", "POST", - "https://models.github.ai/inference/chat/completions", + api_url, "-H", "Content-Type: application/json", "-H", "Authorization: Bearer " .. token, + "-H", + "Editor-Version: " .. EDITOR_VERSION, + "-H", + "Editor-Plugin-Version: ai-commit-msg.nvim/*", + "-H", + "Copilot-Integration-Id: vscode-chat", "-d", payload, "--silent", diff --git a/spec/config_spec.lua b/spec/config_spec.lua index 9eb8f02..de9cb62 100644 --- a/spec/config_spec.lua +++ b/spec/config_spec.lua @@ -24,7 +24,7 @@ describe("ai_commit_msg config", function() assert.is_number(config.temperature) assert.is_string(config.prompt) assert.is_boolean(config.auto_push_prompt) - assert.is_boolean(config.spinner) + assert.is_true(type(config.spinner) == "boolean" or type(config.spinner) == "table") assert.is_boolean(config.notifications) assert.is_table(config.keymaps) end)