Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 108 additions & 79 deletions group/cache.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2018-2023 CERN
// Copyright 2018-2026 CERN
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -19,15 +19,17 @@
package rest

import (
"context"
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
"time"

cboxutils "github.com/cernbox/reva-plugins/utils"
grouppb "github.com/cs3org/go-cs3apis/cs3/identity/group/v1beta1"
userpb "github.com/cs3org/go-cs3apis/cs3/identity/user/v1beta1"
"github.com/cs3org/reva/v3/pkg/appctx"
"github.com/gomodule/redigo/redis"
)

Expand All @@ -40,45 +42,24 @@ const (
groupInternalIDPrefix = "internal:"
)

func initRedisPool(address, username, password string) *redis.Pool {
return &redis.Pool{

MaxIdle: 50,
MaxActive: 1000,
IdleTimeout: 240 * time.Second,

Dial: func() (redis.Conn, error) {
var c redis.Conn
var err error
switch {
case username != "":
c, err = redis.Dial("tcp", address,
redis.DialUsername(username),
redis.DialPassword(password),
)
case password != "":
c, err = redis.Dial("tcp", address,
redis.DialPassword(password),
)
default:
c, err = redis.Dial("tcp", address)
}

if err != nil {
return nil, err
}
return c, err
},

TestOnBorrow: func(c redis.Conn, t time.Time) error {
_, err := c.Do("PING")
return err
},
type redisPools struct {
read *redis.Pool
write *redis.Pool
}

func initRedisPools(ctx context.Context, address, sentinelAddress, username, password string, sentinelMode bool, masterName string) *redisPools {
log := appctx.GetLogger(ctx)
p, err := cboxutils.NewRedisPoolsWithSentinelAddress(ctx, address, sentinelAddress, username, password, sentinelMode, masterName)
if err != nil {
log.Error().Err(err).Msg("rest: failed to initialize redis pools")
return &redisPools{}
}
log.Debug().Msg("rest: successfully initialized redis pools")
return &redisPools{read: p.Read, write: p.Write}
}

func (m *manager) setVal(key, val string, expiration int) error {
conn := m.redisPool.Get()
conn := m.redisPoolsField.write.Get()
defer conn.Close()
if conn != nil {
args := []interface{}{key, val}
Expand All @@ -93,63 +74,111 @@ func (m *manager) setVal(key, val string, expiration int) error {
return errors.New("rest: unable to get connection from redis pool")
}

func (m *manager) getVal(key string) (string, error) {
conn := m.redisPool.Get()
defer conn.Close()
if conn != nil {
func (m *manager) getVal(ctx context.Context, key string) (string, error) {
log := appctx.GetLogger(ctx)

// Try read pool first
conn := m.redisPoolsField.read.Get()
if conn != nil && conn.Err() == nil {
val, err := redis.String(conn.Do("GET", key))
if err != nil {
return "", err
conn.Close() // release back to read pool

// If it succeeded or simply wasn't found, we're done
if err == nil || err == redis.ErrNil {
return val, err
}
return val, nil
log.Debug().Err(err).Msg("rest: read pool GET failed, falling back to write pool")
} else if conn != nil {
log.Debug().Err(conn.Err()).Msg("rest: read pool connection failed, falling back to write pool")
conn.Close() // close broken read connection
} else {
log.Debug().Msg("rest: read pool provided nil connection, falling back to write pool")
}
return "", errors.New("rest: unable to get connection from redis pool")
}

func (m *manager) findCachedGroups(query string) ([]*grouppb.Group, error) {
conn := m.redisPool.Get()
// Fallback: try write pool
conn = m.redisPoolsField.write.Get()
if conn == nil {
return "", errors.New("rest: unable to get connection from redis pool")
}
defer conn.Close()
if conn != nil {
query = fmt.Sprintf("%s*%s*", groupPrefix, strings.ReplaceAll(strings.ToLower(query), " ", "_"))
keys, err := redis.Strings(conn.Do("KEYS", query))

val, err := redis.String(conn.Do("GET", key))
if err != nil {
return "", err
}
log.Debug().Any("key", key).Msg("rest: successfully got value from write pool")
return val, nil
}

func (m *manager) findCachedGroups(ctx context.Context, query string) ([]*grouppb.Group, error) {
query = fmt.Sprintf("%s*%s*", groupPrefix, strings.ReplaceAll(strings.ToLower(query), " ", "_"))
var keys []string
var err error
log := appctx.GetLogger(ctx)
// Try read pool first
conn := m.redisPoolsField.read.Get()
if conn != nil && conn.Err() == nil {
keys, err = redis.Strings(conn.Do("KEYS", query))
if err != nil {
return nil, err
}
var args []interface{}
for _, k := range keys {
args = append(args, k)
log.Debug().Err(err).Msg("rest: read pool KEYS failed, falling back to write pool")
conn.Close() // close broken read connection on failure
conn = nil
}
} else if conn != nil {
log.Debug().Err(conn.Err()).Msg("rest: read pool connection failed, falling back to write pool")
conn.Close() // close broken read connection
conn = nil
} else {
log.Debug().Msg("rest: read pool provided nil connection, falling back to write pool")
}

if len(args) == 0 {
return []*grouppb.Group{}, nil
// Fallback: try write pool if read pool failed
if conn == nil {
conn = m.redisPoolsField.write.Get()
if conn == nil {
return nil, errors.New("rest: unable to get connection from redis pool")
}

// Fetch the groups for all these keys
groupStrings, err := redis.Strings(conn.Do("MGET", args...))
keys, err = redis.Strings(conn.Do("KEYS", query))
if err != nil {
conn.Close()
return nil, err
}
groupMap := make(map[string]*grouppb.Group)
for _, group := range groupStrings {
g := grouppb.Group{}
if err = json.Unmarshal([]byte(group), &g); err == nil {
groupMap[g.Id.OpaqueId] = &g
}
}
}
defer conn.Close()

var groups []*grouppb.Group
for _, g := range groupMap {
groups = append(groups, g)
var args []interface{}
for _, k := range keys {
args = append(args, k)
}

if len(args) == 0 {
return []*grouppb.Group{}, nil
}

// Fetch the groups for all these keys
groupStrings, err := redis.Strings(conn.Do("MGET", args...))
if err != nil {
return nil, err
}
groupMap := make(map[string]*grouppb.Group)
for _, group := range groupStrings {
g := grouppb.Group{}
if err = json.Unmarshal([]byte(group), &g); err == nil {
groupMap[g.Id.OpaqueId] = &g
}
}

return groups, nil
var groups []*grouppb.Group
for _, g := range groupMap {
groups = append(groups, g)
}

return nil, errors.New("rest: unable to get connection from redis pool")
log.Debug().Any("query", query).Int("results", len(groups)).Msg("rest: successfully found cached groups")
return groups, nil
}

func (m *manager) fetchCachedGroupDetails(gid *grouppb.GroupId) (*grouppb.Group, error) {
group, err := m.getVal(groupPrefix + idPrefix + gid.OpaqueId)
func (m *manager) fetchCachedGroupDetails(ctx context.Context, gid *grouppb.GroupId) (*grouppb.Group, error) {
group, err := m.getVal(ctx, groupPrefix+idPrefix+gid.OpaqueId)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -183,8 +212,8 @@ func (m *manager) cacheGroupDetails(g *grouppb.Group) error {
return nil
}

func (m *manager) fetchCachedGroupByParam(field, claim string) (*grouppb.Group, error) {
group, err := m.getVal(groupPrefix + field + ":" + strings.ToLower(claim))
func (m *manager) fetchCachedGroupByParam(ctx context.Context, field, claim string) (*grouppb.Group, error) {
group, err := m.getVal(ctx, groupPrefix+field+":"+strings.ToLower(claim))
if err != nil {
return nil, err
}
Expand All @@ -196,8 +225,8 @@ func (m *manager) fetchCachedGroupByParam(field, claim string) (*grouppb.Group,
return &g, nil
}

func (m *manager) fetchCachedGroupMembers(gid *grouppb.GroupId) ([]*userpb.UserId, error) {
members, err := m.getVal(groupPrefix + groupMembersPrefix + strings.ToLower(gid.OpaqueId))
func (m *manager) fetchCachedGroupMembers(ctx context.Context, gid *grouppb.GroupId) ([]*userpb.UserId, error) {
members, err := m.getVal(ctx, groupPrefix+groupMembersPrefix+strings.ToLower(gid.OpaqueId))
if err != nil {
return nil, err
}
Expand Down
27 changes: 18 additions & 9 deletions group/rest.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2018-2023 CERN
// Copyright 2018-2026 CERN
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -36,7 +36,6 @@ import (
"github.com/cs3org/reva/v3/pkg/group"
"github.com/cs3org/reva/v3/pkg/utils/cfg"
"github.com/cs3org/reva/v3/pkg/utils/list"
"github.com/gomodule/redigo/redis"
"github.com/rs/zerolog/log"
)

Expand All @@ -46,7 +45,7 @@ func init() {

type manager struct {
conf *config
redisPool *redis.Pool
redisPoolsField *redisPools
apiTokenManager *utils.APITokenManager
}

Expand All @@ -60,10 +59,16 @@ func (manager) RevaPlugin() reva.PluginInfo {
type config struct {
// The address at which the redis server is running
RedisAddress string `mapstructure:"redis_address" docs:"localhost:6379"`
// The address at which the redis sentinel is running (only used when redis_sentinel_mode is enabled)
RedisSentinelAddress string `mapstructure:"redis_sentinel_address" docs:""`
// The username for connecting to the redis server
RedisUsername string `mapstructure:"redis_username" docs:""`
// The password for connecting to the redis server
RedisPassword string `mapstructure:"redis_password" docs:""`
// The name of the master node in case Redis Sentinel mode is enabled
RedisMasterName string `mapstructure:"redis_master_name" docs:""`
// Whether to use Redis Sentinel mode
RedisSentinelMode bool `mapstructure:"redis_sentinel_mode" docs:""`
// The time in minutes for which the members of a group would be cached
GroupMembersCacheExpiration int `mapstructure:"group_members_cache_expiration" docs:"5"`
// The OIDC Provider
Expand All @@ -90,6 +95,9 @@ func (c *config) ApplyDefaults() {
if c.RedisAddress == "" {
c.RedisAddress = ":6379"
}
if c.RedisSentinelAddress == "" {
c.RedisSentinelAddress = c.RedisAddress
}
if c.APIBaseURL == "" {
c.APIBaseURL = "https://authorization-service-api-dev.web.cern.ch"
}
Expand All @@ -113,16 +121,17 @@ func New(ctx context.Context, m map[string]interface{}) (group.Manager, error) {
if err := cfg.Decode(m, &c); err != nil {
return nil, err
}
c.ApplyDefaults()

redisPool := initRedisPool(c.RedisAddress, c.RedisUsername, c.RedisPassword)
pools := initRedisPools(ctx, c.RedisAddress, c.RedisSentinelAddress, c.RedisUsername, c.RedisPassword, c.RedisSentinelMode, c.RedisMasterName)
apiTokenManager, err := utils.InitAPITokenManager(m)
if err != nil {
return nil, err
}

mgr := &manager{
conf: &c,
redisPool: redisPool,
redisPoolsField: pools,
apiTokenManager: apiTokenManager,
}
go mgr.fetchAllGroups(context.Background())
Expand Down Expand Up @@ -212,7 +221,7 @@ func (m *manager) parseAndCacheGroup(ctx context.Context, g *Group) (*grouppb.Gr
}

func (m *manager) GetGroup(ctx context.Context, gid *grouppb.GroupId, skipFetchingMembers bool) (*grouppb.Group, error) {
g, err := m.fetchCachedGroupDetails(gid)
g, err := m.fetchCachedGroupDetails(ctx, gid)
if err != nil {
return nil, err
}
Expand All @@ -233,7 +242,7 @@ func (m *manager) GetGroupByClaim(ctx context.Context, claim, value string, skip
return m.GetGroup(ctx, &grouppb.GroupId{OpaqueId: value}, skipFetchingMembers)
}

g, err := m.fetchCachedGroupByParam(claim, value)
g, err := m.fetchCachedGroupByParam(ctx, claim, value)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -264,11 +273,11 @@ func (m *manager) FindGroups(ctx context.Context, query string, skipFetchingMemb
}
}

return m.findCachedGroups(query)
return m.findCachedGroups(ctx, query)
}

func (m *manager) GetMembers(ctx context.Context, gid *grouppb.GroupId) ([]*userpb.UserId, error) {
users, err := m.fetchCachedGroupMembers(gid)
users, err := m.fetchCachedGroupMembers(ctx, gid)
if err == nil {
return users, nil
}
Expand Down
Loading
Loading