diff --git a/cmd/store/import.go b/cmd/store/import.go index 16bdecc8..2a3d9198 100644 --- a/cmd/store/import.go +++ b/cmd/store/import.go @@ -44,6 +44,7 @@ const ( progressBarSleepDelay = 10 // time.Millisecond progressBarThrottleValue = 65 progressBarUpdateDelay = 5 * time.Millisecond + maxAssertionsPerWrite = 100 ) // createStore creates a new store with the given client configuration and store data. @@ -226,9 +227,15 @@ func importAssertions( StoreId: &storeID, } + if len(assertions) > maxAssertionsPerWrite { + fmt.Fprintf(os.Stderr, "Warning: %d test assertions found, but only the first %d will be written\n", + len(assertions), maxAssertionsPerWrite) + assertions = assertions[:maxAssertionsPerWrite] + } + _, err := fgaClient.WriteAssertions(ctx).Body(assertions).Options(writeOptions).Execute() if err != nil { - return fmt.Errorf("failed to import assertions: %w", err) + return fmt.Errorf("failed to import test assertions: %w", err) } } @@ -236,7 +243,15 @@ func importAssertions( } func getCheckAssertions(checkTests []storetest.ModelTestCheck) []client.ClientAssertion { - var assertions []client.ClientAssertion + totalAssertions := 0 + + for _, checkTest := range checkTests { + users := storetest.GetEffectiveUsers(checkTest) + objects := storetest.GetEffectiveObjects(checkTest) + totalAssertions += len(users) * len(objects) * len(checkTest.Assertions) + } + + assertions := make([]client.ClientAssertion, 0, totalAssertions) for _, checkTest := range checkTests { users := storetest.GetEffectiveUsers(checkTest) diff --git a/cmd/store/import_test.go b/cmd/store/import_test.go index edb0e031..c65b20af 100644 --- a/cmd/store/import_test.go +++ b/cmd/store/import_test.go @@ -12,6 +12,11 @@ import ( "github.com/openfga/cli/internal/storetest" ) +const ( + testModelID = "model-1" + testStoreID = "store-1" +) + func TestImportStore(t *testing.T) { t.Parallel() @@ -51,7 +56,7 @@ func TestImportStore(t *testing.T) { Expectation: true, }, } - modelID, storeID := "model-1", "store-1" + modelID, storeID := testModelID, testStoreID expectedOptions := client.ClientWriteAssertionsOptions{AuthorizationModelId: &modelID, StoreId: &storeID} importStoreTests := []struct { @@ -215,6 +220,64 @@ func TestImportStore(t *testing.T) { } } +func TestImportStoreWithTruncatedAssertions(t *testing.T) { + t.Parallel() + + modelID, storeID := testModelID, testStoreID + expectedOptions := client.ClientWriteAssertionsOptions{AuthorizationModelId: &modelID, StoreId: &storeID} + + // Generate 150 users to create 150 assertions (exceeding 100 limit) + users := make([]string, 150) + for i := range 150 { + users[i] = "user:" + string(rune('a'+i/26)) + string(rune('a'+i%26)) + } + + // Only the first 100 assertions should be written + first100Assertions := make([]client.ClientAssertion, 100) + for i := range 100 { + first100Assertions[i] = client.ClientAssertion{ + User: users[i], + Relation: "reader", + Object: "document:doc1", + Expectation: true, + } + } + + mockCtrl := gomock.NewController(t) + mockFgaClient := mockclient.NewMockSdkClient(mockCtrl) + + defer mockCtrl.Finish() + + // Only expect a single write with the first 100 assertions + setupWriteAssertionsMock(mockCtrl, mockFgaClient, first100Assertions, expectedOptions) + setupWriteModelMock(mockCtrl, mockFgaClient, modelID) + setupCreateStoreMock(mockCtrl, mockFgaClient, storeID) + + testStore := storetest.StoreData{ + Model: `type user + type document + relations + define reader: [user]`, + Tests: []storetest.ModelTest{ + { + Name: "Test", + Check: []storetest.ModelTestCheck{ + { + Users: users, + Object: "document:doc1", + Assertions: map[string]bool{"reader": true}, + }, + }, + }, + }, + } + + _, err := importStore(t.Context(), &fga.ClientConfig{}, mockFgaClient, &testStore, "", "", 10, 1, "") + if err != nil { + t.Errorf("expected no error, got %v", err) + } +} + func TestUpdateStore(t *testing.T) { t.Parallel() @@ -225,8 +288,8 @@ func TestUpdateStore(t *testing.T) { Expectation: true, }} - modelID := "model-1" - storeID := "store-1" + modelID := testModelID + storeID := testStoreID sampleTime := time.Now() expectedOptions := client.ClientWriteAssertionsOptions{ AuthorizationModelId: &modelID, diff --git a/internal/storetest/localtest.go b/internal/storetest/localtest.go index a297d896..17870108 100644 --- a/internal/storetest/localtest.go +++ b/internal/storetest/localtest.go @@ -26,10 +26,10 @@ func RunLocalCheckTest( tuples []client.ClientContextualTupleKey, options ModelTestOptions, ) []ModelTestCheckSingleResult { - results := []ModelTestCheckSingleResult{} users := GetEffectiveUsers(checkTest) - objects := GetEffectiveObjects(checkTest) + results := make([]ModelTestCheckSingleResult, 0, len(users)*len(objects)*len(checkTest.Assertions)) + for _, user := range users { for _, object := range objects { for relation, expectation := range checkTest.Assertions { @@ -101,7 +101,7 @@ func RunLocalListObjectsTest( tuples []client.ClientContextualTupleKey, options ModelTestOptions, ) []ModelTestListObjectsSingleResult { - results := []ModelTestListObjectsSingleResult{} + results := make([]ModelTestListObjectsSingleResult, 0, len(listObjectsTest.Assertions)) for relation, expectation := range listObjectsTest.Assertions { result := ModelTestListObjectsSingleResult{ @@ -168,8 +168,7 @@ func RunLocalListUsersTest( tuples []client.ClientContextualTupleKey, options ModelTestOptions, ) []ModelTestListUsersSingleResult { - results := []ModelTestListUsersSingleResult{} - + results := make([]ModelTestListUsersSingleResult, 0, len(listUsersTest.Assertions)) object, pbObject := convertStoreObjectToObject(listUsersTest.Object) userFilter := &pb.UserTypeFilter{ diff --git a/internal/storetest/remotetest.go b/internal/storetest/remotetest.go index ebebcf62..9112a444 100644 --- a/internal/storetest/remotetest.go +++ b/internal/storetest/remotetest.go @@ -33,10 +33,9 @@ func RunRemoteCheckTest( checkTest ModelTestCheck, tuples []client.ClientContextualTupleKey, ) []ModelTestCheckSingleResult { - results := []ModelTestCheckSingleResult{} - users := GetEffectiveUsers(checkTest) objects := GetEffectiveObjects(checkTest) + results := make([]ModelTestCheckSingleResult, 0, len(users)*len(objects)*len(checkTest.Assertions)) for _, user := range users { for _, object := range objects { @@ -89,7 +88,7 @@ func RunRemoteListObjectsTest( listObjectsTest ModelTestListObjects, tuples []client.ClientContextualTupleKey, ) []ModelTestListObjectsSingleResult { - results := []ModelTestListObjectsSingleResult{} + results := make([]ModelTestListObjectsSingleResult, 0, len(listObjectsTest.Assertions)) for relation, expectation := range listObjectsTest.Assertions { result := RunSingleRemoteListObjectsTest(ctx, fgaClient, @@ -138,9 +137,9 @@ func RunRemoteListUsersTest( listUsersTest ModelTestListUsers, tuples []client.ClientContextualTupleKey, ) []ModelTestListUsersSingleResult { - results := []ModelTestListUsersSingleResult{} - + results := make([]ModelTestListUsersSingleResult, 0, len(listUsersTest.Assertions)) object, _ := convertStoreObjectToObject(listUsersTest.Object) + for relation, expectation := range listUsersTest.Assertions { result := RunSingleRemoteListUsersTest(ctx, fgaClient, client.ClientListUsersRequest{ @@ -165,21 +164,21 @@ func RunRemoteTest( test ModelTest, testTuples []client.ClientContextualTupleKey, ) TestResult { - checkResults := []ModelTestCheckSingleResult{} + checkResults := make([]ModelTestCheckSingleResult, 0, len(test.Check)) for index := range test.Check { results := RunRemoteCheckTest(ctx, fgaClient, test.Check[index], testTuples) checkResults = append(checkResults, results...) } - listObjectResults := []ModelTestListObjectsSingleResult{} + listObjectResults := make([]ModelTestListObjectsSingleResult, 0, len(test.ListObjects)) for index := range test.ListObjects { results := RunRemoteListObjectsTest(ctx, fgaClient, test.ListObjects[index], testTuples) listObjectResults = append(listObjectResults, results...) } - listUserResults := []ModelTestListUsersSingleResult{} + listUserResults := make([]ModelTestListUsersSingleResult, 0, len(test.ListUsers)) for index := range test.ListUsers { results := RunRemoteListUsersTest(ctx, fgaClient, test.ListUsers[index], testTuples) diff --git a/internal/storetest/testresult.go b/internal/storetest/testresult.go index 266930de..9b921400 100644 --- a/internal/storetest/testresult.go +++ b/internal/storetest/testresult.go @@ -366,12 +366,12 @@ func (test TestResults) FriendlyDisplay() string { //nolint:cyclop func (test TestResults) FriendlyBody() string { fullOutput := test.FriendlyDisplay() - headerIndex := strings.Index(fullOutput, "# Test Summary #") - if headerIndex == -1 { + before, _, ok := strings.Cut(fullOutput, "# Test Summary #") + if !ok { return fullOutput } - return strings.TrimSpace(fullOutput[:headerIndex]) + return strings.TrimSpace(before) } func buildTestSummary(failedTestCount int, summary string, totalTestCount int,