Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 36 additions & 48 deletions pkg/groupmapper/groupmapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,16 @@ package groupmapper
import (
"context"
"fmt"
"time"
"slices"

"k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apimachinery/pkg/util/wait"
kuser "k8s.io/apiserver/pkg/authentication/user"

userv1 "github.com/openshift/api/user/v1"
userclient "github.com/openshift/client-go/user/clientset/versioned/typed/user/v1"
userinformer "github.com/openshift/client-go/user/informers/externalversions/user/v1"
userlisterv1 "github.com/openshift/client-go/user/listers/user/v1"
usercache "github.com/openshift/library-go/pkg/oauth/usercache"

authapi "github.com/openshift/oauth-server/pkg/api"
)
Expand Down Expand Up @@ -58,18 +55,12 @@ func (w *UserInfoGroupsWrapper) GetGroups() []string {
type UserGroupsMapper struct {
delegatedUserMapper authapi.UserIdentityMapper
groupsClient userclient.GroupInterface
groupsLister userlisterv1.GroupLister
groupsCache *usercache.GroupCache
groupsSynced func() bool
}

func NewUserGroupsMapper(delegate authapi.UserIdentityMapper, groupInformer userinformer.GroupInformer, groupsClient userclient.GroupInterface, groupsLister userlisterv1.GroupLister) *UserGroupsMapper {
return &UserGroupsMapper{
delegatedUserMapper: delegate,
groupsClient: groupsClient,
groupsLister: groupsLister,
groupsCache: usercache.NewGroupCache(groupInformer),
groupsSynced: groupInformer.Informer().HasSynced,
}
}

Expand All @@ -90,60 +81,61 @@ func (m *UserGroupsMapper) UserFor(identityInfo authapi.UserIdentityInfo) (kuser
}, nil
}

func (m *UserGroupsMapper) processGroups(idpName, username string, groups sets.String) error {
err := wait.PollImmediate(1*time.Second, 5*time.Second, func() (bool, error) {
return m.groupsSynced(), nil
})
// processGroups synchronizes the user's group memberships with the identity provider.
// NOTE: This makes a direct API call to list all groups on every login to ensure
// correctness and avoid cache staleness issues (see OCPBUGS-63228).
func (m *UserGroupsMapper) processGroups(idpName, username string, providerGroups sets.String) error {
ctx := context.Background()
clusterGroupsList, err := m.groupsClient.List(ctx, metav1.ListOptions{})
if err != nil {
return err
return fmt.Errorf("could not list cluster groups: %v", err)
}

cachedGroups, err := m.groupsCache.GroupsFor(username)
if err != nil {
return err
clusterGroups := map[string]*userv1.Group{}
for _, g := range clusterGroupsList.Items {
clusterGroups[g.Name] = &g
}

userGroupsForIDP := make([]*userv1.Group, 0)
for _, g := range clusterGroupsList.Items {
if g.Annotations[fmt.Sprintf(groupSyncedKeyFmt, idpName)] == "synced" && slices.Contains(g.Users, username) {
userGroupsForIDP = append(userGroupsForIDP, &g)
}
}

removeGroups, addGroups := groupsDiff(cachedGroups, groups)
removeGroups, addGroups := groupsDiff(userGroupsForIDP, providerGroups)
for _, g := range removeGroups {
if err := m.removeUserFromGroup(idpName, username, g); err != nil {
if err := m.removeUserFromGroup(ctx, idpName, username, clusterGroups[g]); err != nil {
return err
}
}

for _, g := range addGroups {
if err := m.addUserToGroup(idpName, username, g); err != nil {
if err := m.addUserToGroup(ctx, idpName, username, g, clusterGroups[g]); err != nil {
return err
}
}

return nil
}

func (m *UserGroupsMapper) removeUserFromGroup(idpName, username, group string) error {
updatedGroup, err := m.groupsLister.Get(group)
if err != nil {
if errors.IsNotFound(err) {
return nil
}
return err
}

if len(updatedGroup.Users) == 0 {
func (m *UserGroupsMapper) removeUserFromGroup(ctx context.Context, idpName, username string, group *userv1.Group) error {
if group == nil || len(group.Users) == 0 {
return nil
}

if len(updatedGroup.Users) == 1 && updatedGroup.Users[0] == username && updatedGroup.Annotations[groupGeneratedKey] == "true" {
return m.groupsClient.Delete(context.TODO(), group, metav1.DeleteOptions{})
if len(group.Users) == 1 && group.Users[0] == username && group.Annotations[groupGeneratedKey] == "true" {
return m.groupsClient.Delete(ctx, group.Name, metav1.DeleteOptions{})
}

// don't perform any actions on the group if it hasn't been synced for this IdP
if updatedGroup.Annotations[fmt.Sprintf(groupSyncedKeyFmt, idpName)] != "synced" {
if group.Annotations[fmt.Sprintf(groupSyncedKeyFmt, idpName)] != "synced" {
return nil
}

// find the user and remove it from the slice
userIdx := -1
for i, groupUser := range updatedGroup.Users {
for i, groupUser := range group.Users {
if groupUser == username {
userIdx = i
break
Expand All @@ -155,25 +147,24 @@ func (m *UserGroupsMapper) removeUserFromGroup(idpName, username, group string)
case -1:
return nil
case 0:
newUsers = updatedGroup.Users[1:]
newUsers = group.Users[1:]
default:
newUsers = append(updatedGroup.Users[0:userIdx], updatedGroup.Users[userIdx+1:]...)
newUsers = append(group.Users[0:userIdx], group.Users[userIdx+1:]...)
}

updatedGroupCopy := updatedGroup.DeepCopy()
updatedGroupCopy := group.DeepCopy()
updatedGroupCopy.Users = newUsers

_, err = m.groupsClient.Update(context.TODO(), updatedGroupCopy, metav1.UpdateOptions{})
_, err := m.groupsClient.Update(ctx, updatedGroupCopy, metav1.UpdateOptions{})
return err
}

func (m *UserGroupsMapper) addUserToGroup(idpName, username, group string) error {
updatedGroup, err := m.groupsLister.Get(group)
if errors.IsNotFound(err) {
_, err = m.groupsClient.Create(context.TODO(),
func (m *UserGroupsMapper) addUserToGroup(ctx context.Context, idpName, username, groupName string, updatedGroup *userv1.Group) error {
if updatedGroup == nil {
_, err := m.groupsClient.Create(ctx,
&userv1.Group{
ObjectMeta: metav1.ObjectMeta{
Name: group,
Name: groupName,
Annotations: map[string]string{
fmt.Sprintf(groupSyncedKeyFmt, idpName): "synced",
groupGeneratedKey: "true",
Expand All @@ -185,9 +176,6 @@ func (m *UserGroupsMapper) addUserToGroup(idpName, username, group string) error
)
return err
}
if err != nil {
return err
}

if updatedGroup.Annotations == nil {
updatedGroup.Annotations = map[string]string{}
Expand All @@ -210,7 +198,7 @@ func (m *UserGroupsMapper) addUserToGroup(idpName, username, group string) error
}
updatedGroupCopy.Annotations[fmt.Sprintf(groupSyncedKeyFmt, idpName)] = "synced"

_, err = m.groupsClient.Update(context.TODO(), updatedGroupCopy, metav1.UpdateOptions{})
_, err := m.groupsClient.Update(ctx, updatedGroupCopy, metav1.UpdateOptions{})
return err
}

Expand Down
31 changes: 5 additions & 26 deletions pkg/groupmapper/groupmapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,9 @@ import (
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apimachinery/pkg/watch"
kuser "k8s.io/apiserver/pkg/authentication/user"
"k8s.io/client-go/tools/cache"

userv1 "github.com/openshift/api/user/v1"
fakeuserclient "github.com/openshift/client-go/user/clientset/versioned/fake"
userinformer "github.com/openshift/client-go/user/informers/externalversions"
userlisterv1 "github.com/openshift/client-go/user/listers/user/v1"
usercache "github.com/openshift/library-go/pkg/oauth/usercache"

authapi "github.com/openshift/oauth-server/pkg/api"
)
Expand Down Expand Up @@ -78,27 +74,14 @@ func TestUserGroupsMapper_UserFor(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {

groupObjs := []runtime.Object{}
indexer := cache.NewIndexer(cache.MetaNamespaceKeyFunc, cache.Indexers{})
for _, g := range basicGroups {
groupObjs = append(groupObjs, g)
require.NoError(t, indexer.Add(g))
}
fakeGroupsClient := fakeuserclient.NewSimpleClientset(groupObjs...)

userInformer := userinformer.NewSharedInformerFactory(fakeGroupsClient, 5*time.Second)
require.NoError(t, userInformer.User().V1().Groups().Informer().AddIndexers(cache.Indexers{
usercache.ByUserIndexName: usercache.ByUserIndexKeys,
}))
testCtx, cancelCtx := context.WithCancel(context.Background())
go userInformer.Start(testCtx.Done())
defer cancelCtx()

m := &UserGroupsMapper{
delegatedUserMapper: &mockUserMapper{userInfo: kuser.DefaultInfo{Name: tt.username, UID: "tehUserUID", Groups: []string{"system:one", "system:two"}}},
groupsClient: fakeGroupsClient.UserV1().Groups(),
groupsLister: userlisterv1.NewGroupLister(indexer),
groupsCache: usercache.NewGroupCache(userInformer.User().V1().Groups()),
groupsSynced: userInformer.User().V1().Groups().Informer().HasSynced,
}

identityInfo := &authapi.DefaultUserIdentityInfo{ProviderName: testIDPName, ProviderUserName: tt.username, ProviderGroups: tt.idpGroups}
Expand Down Expand Up @@ -130,10 +113,12 @@ func TestUserGroupsMapper_UserFor(t *testing.T) {
require.NoError(t, err)
for _, g := range groups.Items {
assertion := require.False
assertionStr := "require user '%s' not present in group '%s'"
if userGroups.Has(g.Name) {
assertion = require.True
assertionStr = "require user '%s' present in group '%s'"
}
assertion(t, userPresent(tt.username, g.Users))
assertion(t, userPresent(tt.username, g.Users), fmt.Sprintf(assertionStr, tt.username, g.Name))
userGroups.Delete(g.Name)
}
require.True(t, userGroups.Len() == 0)
Expand Down Expand Up @@ -223,11 +208,9 @@ func TestUserGroupsMapper_removeUserFromGroup(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
indexer := cache.NewIndexer(cache.MetaNamespaceKeyFunc, cache.Indexers{})
groups := []runtime.Object{}
if tt.group != nil {
groups = append(groups, tt.group)
require.NoError(t, indexer.Add(tt.group))
}
fakeUserClient := fakeuserclient.NewSimpleClientset(groups...)
testCtx := context.Background()
Expand All @@ -236,7 +219,6 @@ func TestUserGroupsMapper_removeUserFromGroup(t *testing.T) {
defer groupWatcher.Stop()

m := &UserGroupsMapper{
groupsLister: userlisterv1.NewGroupLister(indexer),
groupsClient: fakeUserClient.UserV1().Groups(),
}

Expand All @@ -250,7 +232,7 @@ func TestUserGroupsMapper_removeUserFromGroup(t *testing.T) {
go watchForGroupEvents(groupWatcher, tt.expectedGroup, tt.expectEvent, expectedEventType, failed, finished, timedCtx)

go func() {
if err := m.removeUserFromGroup(testIDPName, tt.username, testGroupName); (err != nil) != tt.wantErr {
if err := m.removeUserFromGroup(t.Context(), testIDPName, tt.username, tt.group); (err != nil) != tt.wantErr {
t.Errorf("UserGroupsMapper.removeUserFromGroup() error = %v, wantErr %v", err, tt.wantErr)
}

Expand Down Expand Up @@ -324,11 +306,9 @@ func TestUserGroupsMapper_addUserToGroup(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
indexer := cache.NewIndexer(cache.MetaNamespaceKeyFunc, cache.Indexers{})
groups := []runtime.Object{}
if tt.group != nil {
groups = append(groups, tt.group)
require.NoError(t, indexer.Add(tt.group))
}
fakeUserClient := fakeuserclient.NewSimpleClientset(groups...)
testCtx := context.Background()
Expand All @@ -337,7 +317,6 @@ func TestUserGroupsMapper_addUserToGroup(t *testing.T) {
defer groupWatcher.Stop()

m := &UserGroupsMapper{
groupsLister: userlisterv1.NewGroupLister(indexer),
groupsClient: fakeUserClient.UserV1().Groups(),
}

Expand All @@ -351,7 +330,7 @@ func TestUserGroupsMapper_addUserToGroup(t *testing.T) {
go watchForGroupEvents(groupWatcher, tt.expectedGroup, tt.expectEvent, expectedEventType, failed, finished, timedCtx)

go func() {
if err := m.addUserToGroup(testIDPName, tt.username, testGroupName); (err != nil) != tt.wantErr {
if err := m.addUserToGroup(t.Context(), testIDPName, tt.username, testGroupName, tt.group); (err != nil) != tt.wantErr {
t.Errorf("UserGroupsMapper.addUserToGroup() error = %v, wantErr %v", err, tt.wantErr)
}

Expand Down