diff --git a/internal/ai/gemini/helpers.go b/internal/ai/gemini/helpers.go new file mode 100644 index 0000000..74c6992 --- /dev/null +++ b/internal/ai/gemini/helpers.go @@ -0,0 +1,41 @@ +package gemini + +import ( + "strings" +) + +// CleanLabels cleans and validates labels, keeping only the allowed ones. +// It accepts a list of labels to clean and a list of available labels from the repository. +// If availableLabels is empty, it falls back to a default list of common labels. +func CleanLabels(labels []string, availableLabels []string) []string { + allowedLabels := make(map[string]bool) + + if len(availableLabels) > 0 { + for _, l := range availableLabels { + allowedLabels[strings.ToLower(l)] = true + } + } else { + // Fallback to default list if no repo labels provided + defaultLabels := []string{ + "feature", "fix", "refactor", "docs", "test", "infra", + "enhancement", "bug", "good first issue", "help wanted", + "chore", "performance", "security", "tech-debt", "breaking-change", + } + for _, l := range defaultLabels { + allowedLabels[l] = true + } + } + + cleaned := make([]string, 0) + seen := make(map[string]bool) + + for _, label := range labels { + trimmed := strings.TrimSpace(strings.ToLower(label)) + if trimmed != "" && allowedLabels[trimmed] && !seen[trimmed] { + cleaned = append(cleaned, trimmed) + seen[trimmed] = true + } + } + + return cleaned +} diff --git a/internal/ai/gemini/issue_content_generator.go b/internal/ai/gemini/issue_content_generator.go index 77cb508..90364c5 100644 --- a/internal/ai/gemini/issue_content_generator.go +++ b/internal/ai/gemini/issue_content_generator.go @@ -72,8 +72,33 @@ func NewGeminiIssueContentGenerator(ctx context.Context, cfg *config.Config, onC return service, nil } +func getIssueSchema() *genai.Schema { + return &genai.Schema{ + Type: genai.TypeObject, + Required: []string{"title", "description", "labels"}, + Properties: map[string]*genai.Schema{ + "title": { + Type: genai.TypeString, + Description: "The title of the issue", + }, + "description": { + Type: genai.TypeString, + Description: "The body of the issue in markdown format", + }, + "labels": { + Type: genai.TypeArray, + Items: &genai.Schema{ + Type: genai.TypeString, + }, + Description: "List of labels (e.g. bug, feature, refactor, good first issue)", + }, + }, + } +} + func (s *GeminiIssueContentGenerator) defaultGenerate(ctx context.Context, mName string, p string) (interface{}, *models.TokenUsage, error) { - genConfig := GetGenerateConfig(mName, "", nil) + schema := getIssueSchema() + genConfig := GetGenerateConfig(mName, "application/json", schema) log := logger.FromContext(ctx) resp, err := s.Client.Models.GenerateContent(ctx, mName, genai.Text(p), genConfig) @@ -167,6 +192,7 @@ func (s *GeminiIssueContentGenerator) GenerateIssueContent(ctx context.Context, return nil, domainErrors.NewAppError(domainErrors.TypeAI, "error parsing AI response", err) } + result.Labels = CleanLabels(result.Labels, request.AvailableLabels) result.Usage = usage log.Info("issue content generated successfully via gemini", @@ -179,7 +205,7 @@ func (s *GeminiIssueContentGenerator) GenerateIssueContent(ctx context.Context, // buildIssuePrompt builds the prompt to generate issue content. func (s *GeminiIssueContentGenerator) buildIssuePrompt(request models.IssueGenerationRequest) string { if request.Description != "" && request.Diff == "" && request.Hint == "" && - request.Template == nil && len(request.ChangedFiles) == 0 { + request.Template == nil && len(request.ChangedFiles) == 0 && len(request.AvailableLabels) == 0 { return request.Description } @@ -227,37 +253,8 @@ func (s *GeminiIssueContentGenerator) buildIssuePrompt(request models.IssueGener return "" } - if request.Template != nil { - rendered += ` - -━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ - -🚨 FINAL REMINDER - CRITICAL OUTPUT REQUIREMENT 🚨 - -YOU MUST OUTPUT **ONLY** VALID JSON. - -The template structure above should be used to FILL the "description" field with markdown content. - -BUT your actual response MUST be a JSON object like this: -{ - "title": "string here", - "description": "markdown content following the template structure", - "labels": ["array", "of", "strings"] -} - -❌ DO NOT output prose like "Here is a high-quality GitHub issue..." -❌ DO NOT output markdown text directly -❌ DO NOT output explanations - -✅ ONLY output the JSON object -✅ Use the template to structure the markdown in the "description" field -✅ Return valid parseable JSON - -BEGIN YOUR JSON OUTPUT NOW:` - - logger.Debug(context.Background(), "full prompt with template and final JSON reminder", - "prompt_length", len(rendered), - "prompt", rendered) + if len(request.AvailableLabels) > 0 { + rendered += fmt.Sprintf("\n\nAvailable Labels (Select ONLY from this list):\n%s", strings.Join(request.AvailableLabels, ", ")) } return rendered @@ -270,6 +267,8 @@ func (s *GeminiIssueContentGenerator) parseIssueResponse(content string) (*model return nil, domainErrors.NewAppError(domainErrors.TypeAI, "empty response from AI", nil) } + content = strings.TrimSpace(content) + if len(content) > 0 { preview := content if len(content) > 200 { @@ -307,40 +306,12 @@ func (s *GeminiIssueContentGenerator) parseIssueResponse(content string) (*model result := &models.IssueGenerationResult{ Title: strings.TrimSpace(jsonResult.Title), Description: strings.TrimSpace(jsonResult.Description), - Labels: s.cleanLabels(jsonResult.Labels), + Labels: jsonResult.Labels, } if result.Title == "" { result.Title = "Generated Issue" } - if result.Description == "" { - result.Description = content - } return result, nil } - -// cleanLabels cleans and validates labels, keeping only the allowed ones. -func (s *GeminiIssueContentGenerator) cleanLabels(labels []string) []string { - allowedLabels := map[string]bool{ - "feature": true, - "fix": true, - "refactor": true, - "docs": true, - "test": true, - "infra": true, - } - - cleaned := make([]string, 0) - seen := make(map[string]bool) - - for _, label := range labels { - trimmed := strings.TrimSpace(strings.ToLower(label)) - if trimmed != "" && allowedLabels[trimmed] && !seen[trimmed] { - cleaned = append(cleaned, trimmed) - seen[trimmed] = true - } - } - - return cleaned -} diff --git a/internal/ai/gemini/issue_content_generator_test.go b/internal/ai/gemini/issue_content_generator_test.go index 6530d7d..cd28af9 100644 --- a/internal/ai/gemini/issue_content_generator_test.go +++ b/internal/ai/gemini/issue_content_generator_test.go @@ -71,6 +71,15 @@ func TestBuildIssuePrompt(t *testing.T) { }, contains: []string{"Code Changes (git diff)", "user description", "special hint"}, }, + { + name: "with available labels", + request: models.IssueGenerationRequest{ + Description: "user description", + Language: "en", + AvailableLabels: []string{"bug", "enhancement"}, + }, + contains: []string{"Available Labels", "bug, enhancement"}, + }, } for _, tt := range tests { @@ -106,26 +115,6 @@ func TestBuildIssuePrompt_WithTemplate(t *testing.T) { // Should contain the template assert.Contains(t, prompt, "Bug Report") - - // Should contain the final JSON reminder - assert.Contains(t, prompt, "🚨 FINAL REMINDER - CRITICAL OUTPUT REQUIREMENT 🚨") - assert.Contains(t, prompt, "YOU MUST OUTPUT **ONLY** VALID JSON") - assert.Contains(t, prompt, "BEGIN YOUR JSON OUTPUT NOW:") - - // Should contain instructions about using template in description field - assert.Contains(t, prompt, "The template structure above should be used to FILL the \"description\" field") - - // Should contain prohibitions - assert.Contains(t, prompt, "❌ DO NOT output prose like \"Here is a high-quality GitHub issue...\"") - assert.Contains(t, prompt, "❌ DO NOT output markdown text directly") - - // Verify the reminder is at the end - lastIndex := len(prompt) - 500 - if lastIndex < 0 { - lastIndex = 0 - } - finalSection := prompt[lastIndex:] - assert.Contains(t, finalSection, "BEGIN YOUR JSON OUTPUT NOW:") }) t.Run("does NOT add final reminder when no template", func(t *testing.T) { @@ -137,9 +126,8 @@ func TestBuildIssuePrompt_WithTemplate(t *testing.T) { prompt := gen.buildIssuePrompt(request) - // Should NOT contain the final JSON reminder - assert.NotContains(t, prompt, "🚨 FINAL REMINDER - CRITICAL OUTPUT REQUIREMENT 🚨") - assert.NotContains(t, prompt, "BEGIN YOUR JSON OUTPUT NOW:") + // Verification is just that prompt exists and is relevant + assert.Contains(t, prompt, "Code Changes") }) t.Run("includes template in Spanish", func(t *testing.T) { @@ -159,10 +147,6 @@ func TestBuildIssuePrompt_WithTemplate(t *testing.T) { // Should contain the template assert.Contains(t, prompt, "Reporte de Bug") - - // Should still contain the final JSON reminder (in English for consistency) - assert.Contains(t, prompt, "🚨 FINAL REMINDER - CRITICAL OUTPUT REQUIREMENT 🚨") - assert.Contains(t, prompt, "BEGIN YOUR JSON OUTPUT NOW:") }) t.Run("handles template with all fields", func(t *testing.T) { @@ -188,28 +172,11 @@ func TestBuildIssuePrompt_WithTemplate(t *testing.T) { // Should contain changed files assert.Contains(t, prompt, "main.go") assert.Contains(t, prompt, "test.go") - - // Should contain the final reminder - assert.Contains(t, prompt, "🚨 FINAL REMINDER - CRITICAL OUTPUT REQUIREMENT 🚨") - assert.Contains(t, prompt, "BEGIN YOUR JSON OUTPUT NOW:") }) t.Run("reminder contains complete JSON structure example", func(t *testing.T) { - template := &models.IssueTemplate{ - Name: "Test Template", - } - - request := models.IssueGenerationRequest{ - Template: template, - Language: "en", - } - - prompt := gen.buildIssuePrompt(request) - - // Should show the expected JSON structure - assert.Contains(t, prompt, `"title": "string here"`) - assert.Contains(t, prompt, `"description": "markdown content following the template structure"`) - assert.Contains(t, prompt, `"labels": ["array", "of", "strings"]`) + // This test is now obsolete as structure is enforced by Schema, not prompt text. + // We can remove it or just check nothing. }) } @@ -268,33 +235,42 @@ func TestParseIssueResponse(t *testing.T) { } func TestCleanLabels(t *testing.T) { - gen := &GeminiIssueContentGenerator{} tests := []struct { - name string - input []string - expected []string + name string + input []string + availableLabels []string + expected []string }{ { - name: "only allowed labels", - input: []string{"fix", "feature", "bug", "invalid"}, - expected: []string{"fix", "feature"}, + name: "default whitelist - allowed", + input: []string{"fix", "feature", "bug", "invalid"}, + availableLabels: nil, + expected: []string{"fix", "feature", "bug"}, + }, + { + name: "default whitelist - mixed case", + input: []string{" Fix ", "FEATURE", "test"}, + availableLabels: nil, + expected: []string{"fix", "feature", "test"}, }, { - name: "mixed case and spaces", - input: []string{" Fix ", "FEATURE", "test"}, - expected: []string{"fix", "feature", "test"}, + name: "strict available labels", + input: []string{"custom-1", "custom-2", "fix"}, + availableLabels: []string{"custom-1", "custom-2"}, + expected: []string{"custom-1", "custom-2"}, }, { - name: "duplicates", - input: []string{"fix", "fix", "FIX"}, - expected: []string{"fix"}, + name: "strict available labels - excludes non-existent", + input: []string{"custom-1", "random"}, + availableLabels: []string{"custom-1"}, + expected: []string{"custom-1"}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := gen.cleanLabels(tt.input) + result := CleanLabels(tt.input, tt.availableLabels) assert.ElementsMatch(t, tt.expected, result) }) } diff --git a/internal/ai/gemini/pull_requests_summarizer_service.go b/internal/ai/gemini/pull_requests_summarizer_service.go index d1a5526..21b8422 100644 --- a/internal/ai/gemini/pull_requests_summarizer_service.go +++ b/internal/ai/gemini/pull_requests_summarizer_service.go @@ -126,13 +126,14 @@ func (gps *GeminiPRSummarizer) defaultGenerate(ctx context.Context, mName string return resp, usage, nil } -func (gps *GeminiPRSummarizer) GeneratePRSummary(ctx context.Context, prContent string) (models.PRSummary, error) { +func (gps *GeminiPRSummarizer) GeneratePRSummary(ctx context.Context, prContent string, availableLabels []string) (models.PRSummary, error) { log := logger.FromContext(ctx) log.Info("generating PR summary via gemini", - "content_length", len(prContent)) + "content_length", len(prContent), + "available_labels_count", len(availableLabels)) - prompt := gps.generatePRPrompt(prContent) + prompt := gps.generatePRPrompt(prContent, availableLabels) log.Debug("calling gemini API for PR summary", "prompt_length", len(prompt)) @@ -199,12 +200,12 @@ func (gps *GeminiPRSummarizer) GeneratePRSummary(ctx context.Context, prContent return models.PRSummary{ Title: jsonSummary.Title, Body: jsonSummary.Body, - Labels: jsonSummary.Labels, + Labels: CleanLabels(jsonSummary.Labels, availableLabels), Usage: usage, }, nil } -func (gps *GeminiPRSummarizer) generatePRPrompt(prContent string) string { +func (gps *GeminiPRSummarizer) generatePRPrompt(prContent string, availableLabels []string) string { templateStr := ai.GetPRPromptTemplate(gps.config.Language) data := ai.PromptData{ PRContent: prContent, @@ -215,5 +216,9 @@ func (gps *GeminiPRSummarizer) generatePRPrompt(prContent string) string { return "" } + if len(availableLabels) > 0 { + rendered += fmt.Sprintf("\n\nAvailable Labels (Select ONLY from this list):\n%s", strings.Join(availableLabels, ", ")) + } + return rendered } diff --git a/internal/ai/gemini/pull_requests_summarizer_service_test.go b/internal/ai/gemini/pull_requests_summarizer_service_test.go index f404f26..5c8d511 100644 --- a/internal/ai/gemini/pull_requests_summarizer_service_test.go +++ b/internal/ai/gemini/pull_requests_summarizer_service_test.go @@ -44,7 +44,7 @@ func TestGeminiPRSummarizer(t *testing.T) { assert.NoError(t, err, "Error creando summarizer") // Act - summary, err := summarizer.GeneratePRSummary(ctx, "") + summary, err := summarizer.GeneratePRSummary(ctx, "", nil) // Assert assert.Equal(t, models.PRSummary{}, summary, "No deberían generarse resúmenes con prompt vacío") @@ -65,7 +65,7 @@ func TestGeminiPRSummarizer(t *testing.T) { prContent := "Some PR content to summarize" // Act - prompt := summarizer.generatePRPrompt(prContent) + prompt := summarizer.generatePRPrompt(prContent, nil) // Assert assert.Contains(t, prompt, "Some PR content to summarize", "El prompt debe contener el contenido del PR") @@ -152,7 +152,7 @@ func TestGeneratePRSummary_HappyPath(t *testing.T) { }, &models.TokenUsage{TotalTokens: 50}, nil } - summary, err := summarizer.GeneratePRSummary(ctx, "successful content") + summary, err := summarizer.GeneratePRSummary(ctx, "successful content", nil) assert.NoError(t, err) assert.Equal(t, "Awesome Feature", summary.Title) @@ -170,7 +170,7 @@ func TestGeneratePRSummary_HappyPath(t *testing.T) { }, &models.TokenUsage{}, nil } - summary, err := summarizer.GeneratePRSummary(ctx, "content with empty title") + summary, err := summarizer.GeneratePRSummary(ctx, "content with empty title", nil) assert.Error(t, err) assert.Contains(t, err.Error(), "invalid AI output format") diff --git a/internal/ai/interfaces.go b/internal/ai/interfaces.go index 8f5d616..08b2d63 100644 --- a/internal/ai/interfaces.go +++ b/internal/ai/interfaces.go @@ -15,7 +15,7 @@ type CommitSummarizer interface { // PRSummarizer defines the interface for services that summarize Pull Requests. type PRSummarizer interface { // GeneratePRSummary generates a summary of a Pull Request given a prompt. - GeneratePRSummary(ctx context.Context, prompt string) (models.PRSummary, error) + GeneratePRSummary(ctx context.Context, prompt string, availableLabels []string) (models.PRSummary, error) } // ReleaseNotesGenerator defines the interface to generate release notes. diff --git a/internal/models/issue_generation.go b/internal/models/issue_generation.go index 1d416a8..cec0e26 100644 --- a/internal/models/issue_generation.go +++ b/internal/models/issue_generation.go @@ -20,6 +20,9 @@ type IssueGenerationRequest struct { // Template is the project's issue template to guide generation (optional) Template *IssueTemplate + + // AvailableLabels is the list of labels available in the repository (optional) + AvailableLabels []string } // IssueGenerationResult contains the result of an issue's content generation. diff --git a/internal/services/issue_generator_service.go b/internal/services/issue_generator_service.go index 3d89806..3723982 100644 --- a/internal/services/issue_generator_service.go +++ b/internal/services/issue_generator_service.go @@ -108,16 +108,26 @@ func (s *IssueGeneratorService) GenerateFromDiff(ctx context.Context, hint strin template, _ = s.SelectTemplateWithAI(ctx, "", hint, changedFiles, nil) } + var availableLabels []string + if s.vcsClient != nil { + availableLabels, err = s.fetchAvailableLabels(ctx) + if err != nil { + logger.Warn(ctx, "failed to fetch repo labels, proceeding without them", err) + } + } + request := models.IssueGenerationRequest{ - Diff: diff, - ChangedFiles: changedFiles, - Hint: hint, - Language: s.config.Language, - Template: template, + Diff: diff, + ChangedFiles: changedFiles, + Hint: hint, + Language: s.config.Language, + Template: template, + AvailableLabels: availableLabels, } logger.Debug(ctx, "calling AI for issue generation from diff", - "has_template", template != nil) + "has_template", template != nil, + "available_labels_count", len(availableLabels)) result, err := s.ai.GenerateIssueContent(ctx, request) if err != nil { @@ -144,6 +154,13 @@ func (s *IssueGeneratorService) GenerateFromDiff(ctx context.Context, hint strin return result, nil } +func (s *IssueGeneratorService) fetchAvailableLabels(ctx context.Context) ([]string, error) { + if s.vcsClient == nil { + return nil, nil + } + return s.vcsClient.GetRepoLabels(ctx) +} + // GenerateFromDescription generates issue content based on a manual description. // Useful when the user wants to create an issue without having local changes. func (s *IssueGeneratorService) GenerateFromDescription(ctx context.Context, description string, skipLabels bool, autoTemplate bool) (*models.IssueGenerationResult, error) { @@ -170,16 +187,27 @@ func (s *IssueGeneratorService) GenerateFromDescription(ctx context.Context, des } } + var availableLabels []string + if s.vcsClient != nil { + var err error + availableLabels, err = s.fetchAvailableLabels(ctx) + if err != nil { + logger.Warn(ctx, "failed to fetch repo labels, proceeding without them", err) + } + } + request := models.IssueGenerationRequest{ - Description: description, - Template: template, + Description: description, + Template: template, + AvailableLabels: availableLabels, } if s.config != nil { request.Language = s.config.Language } logger.Debug(ctx, "calling AI for issue generation from description", - "has_template", template != nil) + "has_template", template != nil, + "available_labels_count", len(availableLabels)) result, err := s.ai.GenerateIssueContent(ctx, request) if err != nil { @@ -254,13 +282,23 @@ func (s *IssueGeneratorService) GenerateFromPR(ctx context.Context, prNumber int template, _ = s.SelectTemplateWithAI(ctx, prData.Title, prData.Description, changedFiles, prData.Labels) } + var availableLabels []string + if s.vcsClient != nil { + var err error + availableLabels, err = s.fetchAvailableLabels(ctx) + if err != nil { + logger.Warn(ctx, "failed to fetch repo labels, proceeding without them", err) + } + } + request := models.IssueGenerationRequest{ - Description: contextBuilder.String(), - Diff: prData.Diff, - ChangedFiles: changedFiles, - Hint: hint, - Language: s.config.Language, - Template: template, + Description: contextBuilder.String(), + Diff: prData.Diff, + ChangedFiles: changedFiles, + Hint: hint, + Language: s.config.Language, + Template: template, + AvailableLabels: availableLabels, } logger.Debug(ctx, "calling AI for issue generation from PR", diff --git a/internal/services/issue_generator_service_test.go b/internal/services/issue_generator_service_test.go index b65895c..f591378 100644 --- a/internal/services/issue_generator_service_test.go +++ b/internal/services/issue_generator_service_test.go @@ -339,6 +339,7 @@ index 0000000..1234567 +};`, } mockVCS.On("GetPR", ctx, 42).Return(prData, nil) + mockVCS.On("GetRepoLabels", ctx).Return([]string(nil), nil) expectedResult := &models.IssueGenerationResult{ Title: "Implement user profile page", @@ -372,6 +373,7 @@ index 0000000..1234567 Diff: "diff --git a/cache/service.go b/cache/service.go\n...", } mockVCS.On("GetPR", ctx, 123).Return(prData, nil) + mockVCS.On("GetRepoLabels", ctx).Return([]string(nil), nil) expectedResult := &models.IssueGenerationResult{ Title: "Memory leak in cache service", diff --git a/internal/services/mocks.go b/internal/services/mocks.go index 9444023..2ba3763 100644 --- a/internal/services/mocks.go +++ b/internal/services/mocks.go @@ -241,8 +241,8 @@ func (m *MockVCSClient) GetAuthenticatedUser(ctx context.Context) (string, error return args.String(0), args.Error(1) } -func (m *MockPRSummarizer) GeneratePRSummary(ctx context.Context, prompt string) (models.PRSummary, error) { - args := m.Called(ctx, prompt) +func (m *MockPRSummarizer) GeneratePRSummary(ctx context.Context, prompt string, availableLabels []string) (models.PRSummary, error) { + args := m.Called(ctx, prompt, availableLabels) return args.Get(0).(models.PRSummary), args.Error(1) } diff --git a/internal/services/pull_request_service.go b/internal/services/pull_request_service.go index 2d53465..8ddb7ad 100644 --- a/internal/services/pull_request_service.go +++ b/internal/services/pull_request_service.go @@ -17,11 +17,12 @@ type prVCSClient interface { GetPR(ctx context.Context, prNumber int) (models.PRData, error) GetPRIssues(ctx context.Context, branchName string, commitMessages []string, description string) ([]models.Issue, error) UpdatePR(ctx context.Context, prNumber int, summary models.PRSummary) error + GetRepoLabels(ctx context.Context) ([]string, error) } // prAIProvider defines the methods needed by PRService from an AI provider. type prAIProvider interface { - GeneratePRSummary(ctx context.Context, prompt string) (models.PRSummary, error) + GeneratePRSummary(ctx context.Context, prompt string, availableLabels []string) (models.PRSummary, error) } // prTemplateService defines the methods needed by PRService for template management. @@ -146,7 +147,16 @@ func (s *PRService) SummarizePR(ctx context.Context, prNumber int, hint string, log.Debug("calling AI for PR summary generation", "pr_number", prNumber) - summary, err := s.aiService.GeneratePRSummary(ctx, prompt) + var availableLabels []string + if s.vcsClient != nil { + var err error + availableLabels, err = s.vcsClient.GetRepoLabels(ctx) + if err != nil { + log.Warn("failed to fetch repo labels, proceeding without them", "error", err) + } + } + + summary, err := s.aiService.GeneratePRSummary(ctx, prompt, availableLabels) if err != nil { log.Error("failed to generate PR summary", "error", err, diff --git a/internal/services/pull_request_service_test.go b/internal/services/pull_request_service_test.go index ce1d8a8..4624feb 100644 --- a/internal/services/pull_request_service_test.go +++ b/internal/services/pull_request_service_test.go @@ -44,7 +44,8 @@ func TestPRService_SummarizePR_Success(t *testing.T) { mockVCS.On("GetPR", ctx, prNumber).Return(prData, nil) mockVCS.On("GetPRIssues", ctx, mock.Anything, mock.Anything, mock.Anything).Return([]models.Issue(nil), nil) - mockAI.On("GeneratePRSummary", ctx, mock.AnythingOfType("string")).Return(expectedSummary, nil) + mockVCS.On("GetRepoLabels", ctx).Return([]string{"enhancement"}, nil) + mockAI.On("GeneratePRSummary", ctx, mock.AnythingOfType("string"), []string{"enhancement"}).Return(expectedSummary, nil) mockVCS.On("UpdatePR", ctx, prNumber, mock.MatchedBy(func(s models.PRSummary) bool { return strings.Contains(s.Body, expectedSummary.Body) && strings.Contains(s.Body, "Test Plan") })).Return(nil) @@ -114,7 +115,8 @@ func TestPRService_SummarizePR_GenerateError(t *testing.T) { mockVCS.On("GetPR", ctx, prNumber).Return(prData, nil) mockVCS.On("GetPRIssues", ctx, mock.Anything, mock.Anything, mock.Anything).Return([]models.Issue(nil), nil) - mockAI.On("GeneratePRSummary", ctx, mock.Anything).Return(models.PRSummary{}, expectedError) + mockVCS.On("GetRepoLabels", ctx).Return([]string(nil), nil) + mockAI.On("GeneratePRSummary", ctx, mock.Anything, []string(nil)).Return(models.PRSummary{}, expectedError) service := NewPRService( WithPRVCSClient(mockVCS), @@ -157,7 +159,8 @@ func TestPRService_SummarizePR_UpdateError(t *testing.T) { mockVCS.On("GetPR", ctx, prNumber).Return(prData, nil) mockVCS.On("GetPRIssues", ctx, mock.Anything, mock.Anything, mock.Anything).Return([]models.Issue(nil), nil) - mockAI.On("GeneratePRSummary", ctx, mock.Anything).Return(summary, nil) + mockVCS.On("GetRepoLabels", ctx).Return([]string{"bug"}, nil) + mockAI.On("GeneratePRSummary", ctx, mock.Anything, []string{"bug"}).Return(summary, nil) mockVCS.On("UpdatePR", ctx, prNumber, mock.MatchedBy(func(s models.PRSummary) bool { return strings.Contains(s.Body, summary.Body) && strings.Contains(s.Body, "Test Plan") })).Return(expectedError) @@ -219,10 +222,11 @@ func TestPRService_SummarizePR_WithRelatedIssues(t *testing.T) { mockVCS.On("GetPR", ctx, prNumber).Return(prData, nil) mockVCS.On("GetPRIssues", ctx, prData.BranchName, []string{"Fix #789"}, prData.Description). Return(relatedIssues, nil) + mockVCS.On("GetRepoLabels", ctx).Return([]string(nil), nil) mockAI.On("GeneratePRSummary", ctx, mock.MatchedBy(func(prompt string) bool { return contextContains(prompt, "Bug 1", "Bug 2", "Bug 3") - })).Return(expectedSummary, nil) + }), []string(nil)).Return(expectedSummary, nil) mockVCS.On("UpdatePR", ctx, prNumber, mock.MatchedBy(func(s models.PRSummary) bool { return contextContains(s.Body, "Summary content", "Fixes #123", "Fixes #456", "Fixes #789", "Test Plan & Evidence") @@ -258,10 +262,11 @@ func TestPRService_SummarizePR_BreakingChanges(t *testing.T) { mockVCS.On("GetPR", ctx, prNumber).Return(prData, nil) mockVCS.On("GetPRIssues", ctx, mock.Anything, mock.Anything, mock.Anything).Return([]models.Issue(nil), nil) + mockVCS.On("GetRepoLabels", ctx).Return([]string(nil), nil) mockAI.On("GeneratePRSummary", ctx, mock.MatchedBy(func(prompt string) bool { return contextContains(prompt, "⚠️ Breaking Changes:", "feat!: breaking change here") - })).Return(expectedSummary, nil) + }), []string(nil)).Return(expectedSummary, nil) mockVCS.On("UpdatePR", ctx, prNumber, mock.MatchedBy(func(s models.PRSummary) bool { return contextContains(s.Body, "Breaking Changes", "feat!: breaking change here", "Test Plan & Evidence") @@ -493,10 +498,11 @@ func TestPRService_SummarizePR_WithTemplate(t *testing.T) { BodyContent: "## Checklist\n- [ ] Done", } mockTemplate.On("GetPRTemplate", ctx, "PULL_REQUEST_TEMPLATE.md").Return(templateContent, nil) + mockVCS.On("GetRepoLabels", ctx).Return([]string(nil), nil) mockAI.On("GeneratePRSummary", ctx, mock.MatchedBy(func(prompt string) bool { return strings.Contains(prompt, "## Checklist") && strings.Contains(prompt, "- [ ] Done") - })).Return(expectedSummary, nil) + }), []string(nil)).Return(expectedSummary, nil) mockVCS.On("UpdatePR", ctx, prNumber, mock.MatchedBy(func(s models.PRSummary) bool { return strings.Contains(s.Body, expectedSummary.Body) && strings.Contains(s.Body, "Test Plan") @@ -538,8 +544,9 @@ func TestPRService_SummarizePR_WithTemplateError(t *testing.T) { mockVCS.On("GetPRIssues", ctx, mock.Anything, mock.Anything, mock.Anything).Return([]models.Issue(nil), nil) mockTemplate.On("ListPRTemplates", ctx).Return([]models.TemplateMetadata(nil), errors.New("io error")) + mockVCS.On("GetRepoLabels", ctx).Return([]string(nil), nil) - mockAI.On("GeneratePRSummary", ctx, mock.Anything).Return(expectedSummary, nil) + mockAI.On("GeneratePRSummary", ctx, mock.Anything, []string(nil)).Return(expectedSummary, nil) mockVCS.On("UpdatePR", ctx, prNumber, mock.MatchedBy(func(s models.PRSummary) bool { return strings.Contains(s.Body, expectedSummary.Body) && strings.Contains(s.Body, "Test Plan") diff --git a/internal/vcs/github/client.go b/internal/vcs/github/client.go index 26324de..9dd9900 100644 --- a/internal/vcs/github/client.go +++ b/internal/vcs/github/client.go @@ -694,6 +694,13 @@ func (ghc *GitHubClient) CreateIssue(ctx context.Context, title string, body str "labels_count", len(labels), "assignees_count", len(assignees)) + if labels == nil { + labels = []string{} + } + if assignees == nil { + assignees = []string{} + } + issueRequest := &github.IssueRequest{ Title: github.Ptr(title), Body: github.Ptr(body),