Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ When you use `@copilot`, the LLM can call functions like `glob`, `file`, `gitdif
| - | `gh` | Show help message |

> [!WARNING]
> Some plugins (e.g. `copilot.vim`) may also map common keys like `<Tab>` in insert mode.
> Some plugins (e.g. `copilot.vim`) may also map common keys like `<Tab>` in insert mode.
> To avoid conflicts, disable Copilot's default `<Tab>` mapping with:
>
> ```lua
Expand Down Expand Up @@ -404,6 +404,21 @@ Add custom AI providers:
- `copilot` - GitHub Copilot (default)
- `github_models` - GitHub Marketplace models (disabled by default)

## Github Enterprise

If your employer provides access to Copilot via a Github Enterprise instance ("GHEC") you can provide the respective URLs with the following config keys:

```lua
{
-- github instance main address w/o protocol prefix, default: "github.com" (without "https://"). E.g. a github-enterprise address might look like this: "mycorp.ghe.com"
github_instance_url = 'mycorp.ghe.com',
-- github instance api address w/o protocol prefix, default: "api.github.com" (without "https://"). E.g.: "mycorp.ghe.com/api/v3"
github_instance_api_url = 'mycorp.ghe.com/api/v3',
}
```

(These keys are used in the default Copilot "provider", this is an alternative to defining a full custom provider)

# API Reference

## Core
Expand Down
12 changes: 7 additions & 5 deletions lua/CopilotChat/client.lua
Original file line number Diff line number Diff line change
Expand Up @@ -250,19 +250,21 @@ function Client:models()
ipairs(get_cached(self.provider_cache[provider_name], 'models', function()
notify.publish(notify.STATUS, 'Fetching models from ' .. provider_name)

local ok, headers = pcall(self.authenticate, self, provider_name)
local ok, headers_or_err = pcall(self.authenticate, self, provider_name)
if not ok then
log.warn('Failed to authenticate with ' .. provider_name .. ': ' .. headers)
log.error('Failed to authenticate with ' .. provider_name .. ': ' .. headers_or_err)
error(headers_or_err)
return {}
end

local ok, models = pcall(provider.get_models, headers)
local ok, models_or_err = pcall(provider.get_models, headers_or_err)
if not ok then
log.warn('Failed to fetch models from ' .. provider_name .. ': ' .. models)
log.error('Failed to fetch models from ' .. provider_name .. ': ' .. models_or_err)
error(models_or_err)
return {}
end

return models or {}
return models_or_err or {}
end))
do
model.provider = provider_name
Expand Down
5 changes: 5 additions & 0 deletions lua/CopilotChat/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
---@field functions table<string, CopilotChat.config.functions.Function>?
---@field prompts table<string, CopilotChat.config.prompts.Prompt|string>?
---@field mappings CopilotChat.config.mappings?
---@field github_instance_url string?
---@field github_instance_api_url string?
return {

-- Shared config starts here (can be passed to functions at runtime and configured via setup function)
Expand Down Expand Up @@ -102,6 +104,9 @@ return {

chat_autocomplete = true, -- Enable chat autocompletion (when disabled, requires manual `mappings.complete` trigger)

github_instance_url = 'github.com', -- github instance main address w/o protocol prefix (without "https://"). E.g. a github-enterprise address might look like this: "mycorp.ghe.com"
github_instance_api_url = 'api.github.com', -- github instance api address w/o protocol prefix (without "https://"). E.g.: "api.mycorp.ghe.com"

log_path = vim.fn.stdpath('state') .. '/CopilotChat.log', -- Default path to log file
history_path = vim.fn.stdpath('data') .. '/copilotchat_history', -- Default path to stored history

Expand Down
101 changes: 68 additions & 33 deletions lua/CopilotChat/config/providers.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,22 @@ local constants = require('CopilotChat.constants')
local notify = require('CopilotChat.notify')
local utils = require('CopilotChat.utils')
local plenary_utils = require('plenary.async.util')
local log = require('plenary.log')

local EDITOR_VERSION = 'Neovim/' .. vim.version().major .. '.' .. vim.version().minor .. '.' .. vim.version().patch

---@class CopilotChat
---@field config CopilotChat.config.Config
---@field chat CopilotChat.ui.chat.Chat
local MC = setmetatable({}, {
__index = function(t, key)
if key == 'config' then
return require('CopilotChat.config')
end
return rawget(t, key)
end,
})

local token_cache = nil
local unsaved_token_cache = {}
local function load_tokens()
Expand Down Expand Up @@ -50,35 +63,48 @@ end
---@return string
local function github_device_flow(tag, client_id, scope)
local function request_device_code()
local res = utils.curl_post('https://github.com/login/device/code', {
local res, err = utils.curl_post('https://' .. MC.config.github_instance_url .. '/login/device/code', {
body = {
client_id = client_id,
scope = scope,
},
headers = { ['Accept'] = 'application/json' },
})

local data = vim.json.decode(res.body)
if not res then
error('failed to request device code: ' .. (err or 'unknown error'))
end

local ok, data = pcall(vim.json.decode, res.body)
if not ok then
error('failed to decode device code response: ' .. tostring(res.body))
end

return data
end

local function poll_for_token(device_code, interval)
while true do
plenary_utils.sleep(interval * 1000)

local res = utils.curl_post('https://github.com/login/oauth/access_token', {
local res, err = utils.curl_post('https://' .. MC.config.github_instance_url .. '/login/oauth/access_token', {
body = {
client_id = client_id,
device_code = device_code,
grant_type = 'urn:ietf:params:oauth:grant-type:device_code',
},
headers = { ['Accept'] = 'application/json' },
})

if not res then
error('failed polling for token: ' .. (err or 'unknown error'))
end

local data = vim.json.decode(res.body)
if data.access_token then
return data.access_token
elseif data.error ~= 'authorization_pending' then
error('Auth error: ' .. (data.error or 'unknown'))
error('auth error: ' .. (data.error or 'unknown'))
end
end
end
Expand Down Expand Up @@ -124,7 +150,6 @@ local function get_github_copilot_token(tag)
return token
end

-- loading token from the environment only in GitHub Codespaces
local codespaces = os.getenv('CODESPACES')
token = os.getenv('GITHUB_TOKEN')
if token and codespaces then
Expand All @@ -134,7 +159,6 @@ local function get_github_copilot_token(tag)
-- loading token from the file
local config_path = config_path()
if config_path then
-- token can be sometimes in apps.json sometimes in hosts.json
local file_paths = {
config_path .. '/github-copilot/hosts.json',
config_path .. '/github-copilot/apps.json',
Expand All @@ -146,7 +170,7 @@ local function get_github_copilot_token(tag)
local parsed_data = utils.json_decode(file_data)
if parsed_data then
for key, value in pairs(parsed_data) do
if string.find(key, 'github.com') and value and value.oauth_token then
if string.find(key, MC.config.github_instance_url) and value and value.oauth_token then
return set_token(tag, value.oauth_token, false)
end
end
Expand All @@ -155,7 +179,7 @@ local function get_github_copilot_token(tag)
end
end

return github_device_flow(tag, 'Iv1.b507a08c87ecfe98', '')
return github_device_flow(tag, '<your-enterprise-client-id>', '')
end

local function get_github_models_token(tag)
Expand All @@ -173,7 +197,7 @@ local function get_github_models_token(tag)

-- loading token from gh cli if available
if vim.fn.executable('gh') == 0 then
local result = utils.system({ 'gh', 'auth', 'token', '-h', 'github.com' })
local result = utils.system({ 'gh', 'auth', 'token', '-h', MC.config.github_instance_url })
if result and result.code == 0 and result.stdout then
local gh_token = vim.trim(result.stdout)
if gh_token ~= '' and not gh_token:find('no oauth token') then
Expand All @@ -182,7 +206,7 @@ local function get_github_models_token(tag)
end
end

return github_device_flow(tag, '178c6fc778ccc68e1d6a', 'read:user copilot')
return github_device_flow(tag, '<your-enterprise-client-id>', 'read:user copilot')
end

---@class CopilotChat.config.providers.Options
Expand All @@ -205,21 +229,35 @@ end
---@field prepare_input nil|fun(inputs:table<CopilotChat.client.Message>, opts:CopilotChat.config.providers.Options):table
---@field prepare_output nil|fun(output:table, opts:CopilotChat.config.providers.Options):CopilotChat.config.providers.Output
---@field get_url nil|fun(opts:CopilotChat.config.providers.Options):string
---@field endpoints_api string?

---@type table<string, CopilotChat.config.providers.Provider>
local M = {}

M.copilot = {
endpoints_api = '',

get_headers = function()
local response, err = utils.curl_get('https://api.github.com/copilot_internal/v2/token', {
local url = 'https://' .. MC.config.github_instance_api_url .. '/copilot_internal/v2/token'
log.debug('get headers - get ' .. url)
local response, err = utils.curl_get(url, {
json_response = true,
headers = {
['Authorization'] = 'Token ' .. get_github_copilot_token('github_copilot'),
['Authorization'] = 'Token ' .. get_github_copilot_token(MC.config.github_instance_api_url),
},
})

if err then
error(err)
if not response then
error('failed to fetch headers: ' .. (err or 'unknown error'))
end

if response.body and response.body.endpoints and response.body.endpoints.api then
log.info('get_headers ok, authenticated. Use api endpoint: ' .. response.body.endpoints.api)
M.endpoints_api = response.body.endpoints.api
else
log.error('get_headers authenticated, but missing key "endpoints.api" in server response. response: '
.. utils.to_string(response))
error('get_headers authenticated, but missing key "endpoints.api" in server response. check log for details')
end

return {
Expand All @@ -232,15 +270,15 @@ M.copilot = {
end,

get_info = function(headers)
local response, err = utils.curl_get('https://api.github.com/copilot_internal/user', {
local response, err = utils.curl_get('https://' .. MC.config.github_instance_url .. '/copilot_internal/user', {
json_response = true,
headers = {
['Authorization'] = 'Token ' .. get_github_copilot_token('github_copilot'),
['Authorization'] = 'Token ' .. get_github_copilot_token(MC.config.github_instance_url),
},
})

if err then
error(err)
if not response then
error('failed to get copilot info: ' .. (err or 'unknown error'))
end

local stats = response.body
Expand All @@ -251,12 +289,8 @@ M.copilot = {
end

local function usage_line(name, snap)
if not snap then
return
end

if not snap then return end
table.insert(lines, string.format(' **%s**', name))

if snap.unlimited then
table.insert(lines, ' Usage: Unlimited')
else
Expand All @@ -282,13 +316,14 @@ M.copilot = {
end,

get_models = function(headers)
local response, err = utils.curl_get('https://api.githubcopilot.com/models', {
log.info('getting models .. headers: ' .. utils.to_string(headers))
local response, err = utils.curl_get(M.endpoints_api .. '/models', {
json_response = true,
headers = headers,
})

if err then
error(err)
if not response then
error('failed to fetch models: ' .. (err or 'unknown error'))
end

local models = vim
Expand Down Expand Up @@ -322,7 +357,7 @@ M.copilot = {

for _, model in ipairs(models) do
if not model.policy then
utils.curl_post('https://api.githubcopilot.com/models/' .. model.id .. '/policy', {
utils.curl_post(M.endpoints_api .. '/models/' .. model.id .. '/policy', {
headers = headers,
json_request = true,
body = { state = 'enabled' },
Expand Down Expand Up @@ -405,8 +440,8 @@ M.copilot = {

local choice
if output.choices and #output.choices > 0 then
for _, choice in ipairs(output.choices) do
local message = choice.message or choice.delta
for _, choice_item in ipairs(output.choices) do
local message = choice_item.message or choice_item.delta
if message and message.tool_calls then
for i, tool_call in ipairs(message.tool_calls) do
local fn = tool_call['function']
Expand All @@ -423,7 +458,6 @@ M.copilot = {
end
end
end

choice = output.choices[1]
else
choice = output
Expand All @@ -448,7 +482,7 @@ M.copilot = {
end,

get_url = function()
return 'https://api.githubcopilot.com/chat/completions'
return M.endpoints_api .. '/chat/completions'
end,
}

Expand All @@ -467,8 +501,8 @@ M.github_models = {
headers = headers,
})

if err then
error(err)
if not response then
error('failed to fetch github models: ' .. (err or 'unknown error'))
end

return vim
Expand Down Expand Up @@ -500,3 +534,4 @@ M.github_models = {
}

return M

38 changes: 38 additions & 0 deletions lua/CopilotChat/utils.lua
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,44 @@ M.curl_post = async.wrap(function(url, opts, callback)
curl.post(url, args)
end, 3)

function M.to_string(tbl)
-- credit: http://lua-users.org/wiki/TableSerialization (universal tostring)
local function table_print(tt, indent, done)
done = done or {}
indent = indent or 0
if type(tt) == 'table' then
local sb = {}
for key, value in pairs(tt) do
table.insert(sb, string.rep(' ', indent)) -- indent it
if type(value) == 'table' and not done[value] then
done[value] = true
table.insert(sb, key .. ' = {\n')
table.insert(sb, table_print(value, indent + 2, done))
table.insert(sb, string.rep(' ', indent)) -- indent it
table.insert(sb, '}\n')
elseif 'number' == type(key) then
table.insert(sb, string.format('"%s"\n', tostring(value)))
else
table.insert(sb, string.format('%s = "%s"\n', tostring(key), tostring(value)))
end
end
return table.concat(sb)
else
return tt .. '\n'
end
end

if 'nil' == type(tbl) then
return tostring(nil)
elseif 'table' == type(tbl) then
return table_print(tbl)
elseif 'string' == type(tbl) then
return tbl
else
return tostring(tbl)
end
end

local function filter_files(files, max_count)
local filetype = require('plenary.filetype')

Expand Down