Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 77 additions & 22 deletions internal/ai/gemini/commit_summarizer_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,71 @@ type (
}
)

// getCommitSuggestionSchema builds the response schema handed to the model:
// a JSON array whose elements are commit-suggestion objects.
//
// Each element must carry "title", "desc" and "files". The "analysis" and
// "requirements" objects are described by the schema but are not listed in
// the element's required set, so the model may omit them; when present,
// every field inside each of them is required.
func getCommitSuggestionSchema() *genai.Schema {
	// stringList returns a new array-of-strings schema on every call so that
	// no *genai.Schema node is shared between properties.
	stringList := func(description string) *genai.Schema {
		return &genai.Schema{
			Type:        genai.TypeArray,
			Items:       &genai.Schema{Type: genai.TypeString},
			Description: description,
		}
	}

	analysis := &genai.Schema{
		Type:     genai.TypeObject,
		Required: []string{"overview", "purpose", "impact"},
		Properties: map[string]*genai.Schema{
			"overview": {Type: genai.TypeString},
			"purpose":  {Type: genai.TypeString},
			"impact":   {Type: genai.TypeString},
		},
	}

	requirements := &genai.Schema{
		Type:     genai.TypeObject,
		Required: []string{"status", "missing", "completed_indices", "suggestions"},
		Properties: map[string]*genai.Schema{
			"status": {
				Type: genai.TypeString,
				Enum: []string{"full_met", "partially_met", "not_met"},
			},
			"missing": stringList(""),
			"completed_indices": {
				Type:  genai.TypeArray,
				Items: &genai.Schema{Type: genai.TypeInteger},
			},
			"suggestions": stringList(""),
		},
	}

	suggestion := &genai.Schema{
		Type:     genai.TypeObject,
		Required: []string{"title", "desc", "files"},
		Properties: map[string]*genai.Schema{
			"title": {
				Type:        genai.TypeString,
				Description: "Commit title (type(scope): message)",
			},
			"desc": {
				Type:        genai.TypeString,
				Description: "Detailed explanation in first person",
			},
			"files":        stringList("Array of file paths as strings"),
			"analysis":     analysis,
			"requirements": requirements,
		},
	}

	return &genai.Schema{
		Type:  genai.TypeArray,
		Items: suggestion,
	}
}

func NewGeminiCommitSummarizer(ctx context.Context, cfg *config.Config, onConfirmation ai.ConfirmationCallback) (*GeminiCommitSummarizer, error) {
providerCfg, exists := cfg.AIProviders["gemini"]
if !exists || providerCfg.APIKey == "" {
Expand Down Expand Up @@ -97,13 +162,11 @@ func NewGeminiCommitSummarizer(ctx context.Context, cfg *config.Config, onConfir

func (s *GeminiCommitSummarizer) defaultGenerate(ctx context.Context, mName string, p string) (interface{}, *models.TokenUsage, error) {
log := logger.FromContext(ctx)

log.Debug("calling gemini API",
"model", mName,
"prompt_length", len(p))

genConfig := GetGenerateConfig(mName, "application/json")

schema := getCommitSuggestionSchema()
genConfig := GetGenerateConfig(mName, "application/json", schema)
resp, err := s.Client.Models.GenerateContent(ctx, mName, genai.Text(p), genConfig)
if err != nil {
log.Error("gemini API call failed",
Expand All @@ -116,23 +179,24 @@ func (s *GeminiCommitSummarizer) defaultGenerate(ctx context.Context, mName stri
strings.Contains(errMsg, "resource exhausted") {
return nil, nil, domainErrors.ErrGeminiQuotaExceeded.WithError(err)
}

if strings.Contains(errMsg, "invalid") ||
strings.Contains(errMsg, "unauthorized") ||
strings.Contains(errMsg, "api key") {
return nil, nil, domainErrors.ErrGeminiAPIKeyInvalid.WithError(err)
}

return nil, nil, domainErrors.ErrAIGeneration.WithError(err)
}

usage := extractUsage(resp)

log.Debug("gemini API response received",
"input_tokens", usage.InputTokens,
"output_tokens", usage.OutputTokens,
"candidates", len(resp.Candidates))

if usage != nil {
log.Debug("gemini API response received",
"input_tokens", usage.InputTokens,
"output_tokens", usage.OutputTokens,
"candidates", len(resp.Candidates))
} else {
log.Debug("gemini API response received",
"candidates", len(resp.Candidates),
"usage", "nil")
}
return resp, usage, nil
}

Expand Down Expand Up @@ -216,42 +280,33 @@ func (s *GeminiCommitSummarizer) parseSuggestionsJSON(responseText string) ([]mo
if responseText == "" {
return nil, fmt.Errorf("empty response text from AI")
}

responseText = ExtractJSON(responseText)

var jsonSuggestions []CommitSuggestionJSON
if err := json.Unmarshal([]byte(responseText), &jsonSuggestions); err != nil {
// Log at default level (no context available here)
return nil, fmt.Errorf("error parsing JSON: %w", err)
}

suggestions := make([]models.CommitSuggestion, 0, len(jsonSuggestions))
for _, js := range jsonSuggestions {
suggestion := models.CommitSuggestion{
CommitTitle: js.Title,
Explanation: js.Desc,
Files: js.Files,
}

if js.Analysis != nil {
suggestion.CodeAnalysis = models.CodeAnalysis{
ChangesOverview: js.Analysis.OverView,
PrimaryPurpose: js.Analysis.Purpose,
TechnicalImpact: js.Analysis.Impact,
}
}

if js.Requirements != nil {
suggestion.RequirementsAnalysis = models.RequirementsAnalysis{
CriteriaStatus: models.CriteriaStatus(js.Requirements.Status),
MissingCriteria: js.Requirements.Missing,
ImprovementSuggestions: js.Requirements.Suggestions,
}
}

suggestions = append(suggestions, suggestion)
}

return suggestions, nil
}

Expand Down
2 changes: 0 additions & 2 deletions internal/ai/gemini/commit_summarizer_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,6 @@ func TestGeminiCommitSummarizer(t *testing.T) {
// assert
assert.Contains(t, prompt, "commit", "El prompt debería contener 'commit'")
assert.Contains(t, prompt, "Archivos Modificados", "El prompt debería contener 'Archivos modificados'")
assert.Contains(t, prompt, "Explicación", "El prompt debería contener 'Explicación'")
assert.Contains(t, prompt, "feat", "El prompt debería contener tipos de commit")
assert.Contains(t, prompt, "fix", "El prompt debería contener tipos de commit")
assert.Contains(t, prompt, "refactor", "El prompt debería contener tipos de commit")
Expand Down Expand Up @@ -237,7 +236,6 @@ func TestGeminiCommitSummarizer(t *testing.T) {
// assert
assert.Contains(t, prompt, "commit", "The prompt should contain 'commit'")
assert.Contains(t, prompt, "Modified Files", "The prompt should contain 'Modified files'")
assert.Contains(t, prompt, "explanation", "The prompt should contain 'Explanation'")
assert.Contains(t, prompt, "feat", "The prompt should contain commit types")
assert.Contains(t, prompt, "fix", "The prompt should contain commit types")
assert.Contains(t, prompt, "refactor", "The prompt should contain commit types")
Expand Down
109 changes: 4 additions & 105 deletions internal/ai/gemini/helper.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
package gemini

import (
"encoding/json"
"strings"

"github.com/thomas-vilte/matecommit/internal/models"
"github.com/thomas-vilte/matecommit/internal/regex"
"google.golang.org/genai"
)

Expand All @@ -22,7 +20,7 @@ func extractUsage(resp *genai.GenerateContentResponse) *models.TokenUsage {
}

// GetGenerateConfig returns the optimal configuration for the model, enabling Thinking Mode if compatible.
func GetGenerateConfig(modelName string, responseType string) *genai.GenerateContentConfig {
func GetGenerateConfig(modelName string, responseType string, schema *genai.Schema) *genai.GenerateContentConfig {
config := &genai.GenerateContentConfig{
Temperature: float32Ptr(0.3),
MaxOutputTokens: int32(10000),
Expand All @@ -31,6 +29,9 @@ func GetGenerateConfig(modelName string, responseType string) *genai.GenerateCon

if responseType == "application/json" {
config.ResponseMIMEType = "application/json"
if schema != nil {
config.ResponseJsonSchema = schema
}
}

if strings.HasPrefix(modelName, "gemini-3") {
Expand All @@ -43,108 +44,6 @@ func GetGenerateConfig(modelName string, responseType string) *genai.GenerateCon
return config
}

// ExtractJSON pulls the most plausible JSON payload out of free-form model
// output, which may wrap the JSON in Markdown fences or surround it with
// extra text (for example from "Thinking" mode).
//
// Strategy, in order:
//  1. Every capture of regex.MarkdownJSONBlock is trimmed and sanitized; the
//     longest one that is valid JSON wins.
//  2. Otherwise the text is scanned for spans that start at '{' or '[' and end
//     at the matching closer; the longest sanitized span that is valid JSON
//     wins.
//  3. Otherwise the whole trimmed text is returned after sanitizing, without
//     any validity guarantee.
func ExtractJSON(text string) string {
	text = strings.TrimSpace(text)

	// Pass 1: Markdown code blocks.
	longestFenced := ""
	for _, groups := range regex.MarkdownJSONBlock.FindAllStringSubmatch(text, -1) {
		if len(groups) < 2 {
			continue
		}
		candidate := SanitizeJSON(strings.TrimSpace(groups[1]))
		if json.Valid([]byte(candidate)) && len(candidate) > len(longestFenced) {
			longestFenced = candidate
		}
	}
	if longestFenced != "" {
		return longestFenced
	}

	// Pass 2: bracket-balanced spans anywhere in the text.
	longestSpan := ""
	pos := 0
	for pos < len(text) {
		rel := strings.IndexAny(text[pos:], "{[")
		if rel < 0 {
			break
		}
		start := pos + rel

		open := text[start]
		closing := byte(']')
		if open == '{' {
			closing = '}'
		}

		// Walk forward tracking nesting depth of this bracket kind only,
		// ignoring brackets that appear inside double-quoted strings.
		end := -1
		depth := 0
		quoted, afterBackslash := false, false
	scan:
		for j := start; j < len(text); j++ {
			c := text[j]
			switch {
			case afterBackslash:
				// The character following a backslash is never interpreted.
				afterBackslash = false
			case c == '\\':
				afterBackslash = true
			case c == '"':
				quoted = !quoted
			case quoted:
				// Inside a string literal: brackets do not count.
			case c == open:
				depth++
			case c == closing:
				depth--
				if depth == 0 {
					end = j
					break scan
				}
			}
		}

		if end < 0 {
			// Unbalanced opener: resume right after it.
			pos = start + 1
			continue
		}

		candidate := SanitizeJSON(text[start : end+1])
		if json.Valid([]byte(candidate)) && len(candidate) > len(longestSpan) {
			longestSpan = candidate
		}
		// Continue after the balanced span, whether or not it was valid.
		pos = end + 1
	}

	if longestSpan != "" {
		return longestSpan
	}

	return SanitizeJSON(text)
}

// SanitizeJSON repairs a common LLM output defect: raw line breaks inside
// string literals. Within every match of regex.JSONString (presumably a JSON
// string literal; the pattern is defined elsewhere), each newline character
// is replaced by the two-character escape sequence `\n`. Text outside the
// matches is returned unchanged.
func SanitizeJSON(s string) string {
	escapeNewlines := func(literal string) string {
		return strings.ReplaceAll(literal, "\n", "\\n")
	}
	return regex.JSONString.ReplaceAllStringFunc(s, escapeNewlines)
}

// float32Ptr returns a pointer to a fresh copy of f, for populating optional
// *float32 fields from a literal value.
func float32Ptr(f float32) *float32 {
	value := f
	return &value
}
Expand Down
Loading
Loading