Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,14 @@ func main() {

sort.Sort(tasks)
hashToTask := make(map[string]pkg.Task)
operationAliasToID := make(map[string]string)
var buf bytes.Buffer
for _, task := range tasks {
buf.Write([]byte(task.IDWithHostPort()))
hashToTask[pkg.Hash([]byte(task.ID()))] = task
hashToTask[task.Hash()] = task
if task.OperationAlias() != "" {
operationAliasToID[task.OperationAlias()] = task.OperationID()
}
}

newVersion := pkg.Hash(buf.Bytes())
Expand All @@ -108,7 +112,7 @@ func main() {
logger.Infof("%d tasks discovered:\n%s", len(tasks), tasks)
version = newVersion

err = taskUpdater.Update(ctx, hashToTask, version)
err = taskUpdater.Update(ctx, hashToTask, operationAliasToID, version)
if err != nil {
logger.Errorf("failed to update tasks: %v", err)
version = "" // drop version so we will retry update on next iteration
Expand Down
65 changes: 39 additions & 26 deletions server/pkg/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pkg

import (
"context"
"fmt"
"net/http"
"strings"
"sync"
Expand All @@ -18,12 +19,13 @@ import (
type authServer struct {
authv3.UnimplementedAuthorizationServer

mx sync.RWMutex
hashToTasks map[string]Task
yt ytsdk.Client
ytProxy string
logger *SimpleLogger
authCookieName string
mx sync.RWMutex
hashToTasks map[string]Task
operationAliasToID map[string]string
yt ytsdk.Client
ytProxy string
logger *SimpleLogger
authCookieName string
}

func CreateAuthServer(yt ytsdk.Client, ytProxy string, logger *SimpleLogger, authCookieName string) *authServer {
Expand All @@ -41,22 +43,11 @@ func (s *authServer) Check(ctx context.Context, req *authv3.CheckRequest) (*auth
httpAttrs := req.GetAttributes().GetRequest().GetHttp()
path := httpAttrs.GetPath()
headers := httpAttrs.GetHeaders()
host := httpAttrs.GetHost()

var hash string
if routerHeaderValue, ok := httpAttrs.Headers[routerHeaderName]; ok {
hash = routerHeaderValue
} else if host := httpAttrs.Host; host != "" {
hash = strings.Split(host, ".")[0]
} else {
s.logger.Warnf("authority (host) or %s headers are missing in request", routerHeaderName)
return deniedResponse, nil
}

s.logger.Debugf("checking auth for hash %q, path %q", hash, path)

task, ok := s.getHashToTasks()[hash]
if !ok {
s.logger.Warnf("no entry for hash %q in tasks registry", hash)
task, err := s.findTaskByRequest(host, headers)
if err != nil {
s.logger.Warnf("failed to find task during auth check: %s", err)
return deniedResponse, nil
}

Expand All @@ -66,7 +57,7 @@ func (s *authServer) Check(ctx context.Context, req *authv3.CheckRequest) (*auth
return okResponse, nil
}

s.logger.Debugf("auth for hash %q, path %q, task %v", hash, path, task)
s.logger.Debugf("auth for path %q, task %v", path, task)

allowed, err := s.checkOperationPermission(ctx, task.operationID, headers)
if err != nil {
Expand All @@ -80,21 +71,43 @@ func (s *authServer) Check(ctx context.Context, req *authv3.CheckRequest) (*auth
return okResponse, nil
}

func (s *authServer) SetHashToTasks(hashToTasks map[string]Task) {
func (s *authServer) SetTasksData(hashToTasks map[string]Task, operationAliasToID map[string]string) {
s.mx.Lock()
defer s.mx.Unlock()

s.hashToTasks = hashToTasks
s.operationAliasToID = operationAliasToID
}

func (s *authServer) getHashToTasks() map[string]Task {
func (s *authServer) findTaskByRequest(host string, headers map[string]string) (*Task, error) {
s.mx.RLock()
defer s.mx.RUnlock()

return s.hashToTasks
var hash string
if routerHeaderValue, ok := headers[routerHeaderName]; ok {
hash = routerHeaderValue
} else if host != "" {
subdomain := strings.Split(host, ".")[0]
if operationAlias, taskName, service, ok := tryParseAliasSubdomain(subdomain); ok {
operationID, ok := s.operationAliasToID[operationAlias]
if !ok {
return nil, fmt.Errorf("operation by alias %q from subdomain was not found", operationAlias)
}
hash = (&Task{operationID: operationID, taskName: taskName, service: service}).Hash()
} else {
hash = subdomain
}
} else {
return nil, fmt.Errorf("authority (host) or %s headers are missing in request", routerHeaderName)
}

if task, ok := s.hashToTasks[hash]; !ok {
return nil, fmt.Errorf("no entry for hash %q in tasks registry", hash)
} else {
return &task, nil
}
}

// TODO: temporary implementation, use YT Go SDK instead
func (s *authServer) checkOperationPermission(ctx context.Context, operationID string, headers map[string]string) (bool, error) {
userCredentials := s.getYTCredentialsFromHeaders(headers)
if userCredentials == nil {
Expand Down
164 changes: 164 additions & 0 deletions server/pkg/auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
package pkg

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestFindTaskByRequest(t *testing.T) {
// Setup test data
task1 := Task{
operationID: "op-123",
taskName: "worker",
service: "api",
}
task1Hash := task1.Hash()

task2 := Task{
operationID: "op-456",
operationAlias: "myalias",
taskName: "master",
service: "ui",
}
task2Hash := task2.Hash()

task3 := Task{
operationID: "op-789",
taskName: "executor",
service: "grpc",
}
task3Hash := task3.Hash()

hashToTasks := map[string]Task{
task1Hash: task1,
task2Hash: task2,
task3Hash: task3,
}

operationAliasToID := map[string]string{
"myalias": "op-456",
"anotheralias": "op-999",
}

server := CreateAuthServer(nil, "", &SimpleLogger{}, "")
server.SetTasksData(hashToTasks, operationAliasToID)

tests := []struct {
name string
host string
headers map[string]string
expectedID string
errorMsg string
}{
// Source 1: Direct hash from x-yt-taskproxy-id header
{
name: "hash from header - valid task",
host: "ignored.example.com",
headers: map[string]string{
"x-yt-taskproxy-id": task1Hash,
},
expectedID: task1.operationID,
},
{
name: "hash from header - invalid hash",
host: "ignored.example.com",
headers: map[string]string{
"x-yt-taskproxy-id": "nonexistent",
},
errorMsg: "no entry for hash \"nonexistent\" in tasks registry",
},
{
name: "hash from header - empty hash",
host: "ignored.example.com",
headers: map[string]string{
"x-yt-taskproxy-id": "",
},
errorMsg: "no entry for hash \"\" in tasks registry",
},
{
name: "hash from header - header takes precedence over host",
host: task3Hash + ".example.com",
headers: map[string]string{
"x-yt-taskproxy-id": task1Hash,
},
expectedID: task1.operationID,
},

// Source 2: Alias-based subdomain (format: alias-taskname-service)
{
name: "alias subdomain - valid alias",
host: "myalias-master-ui.example.com",
headers: map[string]string{
"other-header": "value",
},
expectedID: task2.operationID,
},
{
name: "alias subdomain - unknown alias",
host: "unknownalias-master-ui.example.com",
errorMsg: "operation by alias \"unknownalias\" from subdomain was not found",
},
{
name: "alias subdomain - valid alias but task not found",
host: "anotheralias-worker-api.example.com",
errorMsg: "no entry for hash",
},
{
name: "alias subdomain - with port",
host: "myalias-master-ui.example.com:8080",
expectedID: task2.operationID,
},

// Source 3: Direct hash from subdomain (fallback)
{
name: "direct hash subdomain - valid hash",
host: task1Hash + ".example.com",
expectedID: task1.operationID,
},
{
name: "direct hash subdomain - invalid hash",
host: "badhash.example.com",
errorMsg: "no entry for hash \"badhash\" in tasks registry",
},
{
name: "direct hash subdomain - single part (no dots)",
host: task3Hash,
expectedID: task3.operationID,
},

// Misc errors
{
name: "invalid alias domain format",
host: "part1-part2.example.com",
errorMsg: "no entry for hash \"part1-part2\" in tasks registry",
},
{
name: "empty host with other headers",
headers: map[string]string{
"authorization": "Bearer token",
},
errorMsg: "authority (host) or x-yt-taskproxy-id headers are missing in request",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
headers := tt.headers
if headers == nil {
headers = map[string]string{}
}
task, err := server.findTaskByRequest(tt.host, headers)

if tt.errorMsg != "" {
require.ErrorContains(t, err, tt.errorMsg)
assert.Nil(t, task)
} else {
require.NoError(t, err)
require.NotNil(t, task)
assert.Equal(t, tt.expectedID, task.operationID)
}
})
}
}
2 changes: 1 addition & 1 deletion server/pkg/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ func (d *taskDiscovery) listOperations(ctx context.Context) ([]ytsdk.OperationSt

for {
d.logger.Debugf(
"loading running operations chunk, limit %d, cursor %s, already loaded %d operations",
"loading running operations chunk, limit %d, cursor %v, already loaded %d operations",
limit,
cursor,
len(operations),
Expand Down
24 changes: 22 additions & 2 deletions server/pkg/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,25 @@ type Task struct {
jobs []HostPort
}

var valueRegexp = regexp.MustCompile(`^[a-z0-9]+$`)
var valueRegexp = regexp.MustCompile(`^[a-z0-9]{1,30}$`)

// Identifies task, for sorting and domain hash
func (t *Task) ID() string {
return t.operationID + t.taskName + t.service
}

func (t *Task) Hash() string {
return Hash([]byte(t.ID()))
}

func (t *Task) OperationID() string {
return t.operationID
}

func (t *Task) OperationAlias() string {
return t.operationAlias
}

// ID with jobs (host, port)-s to create correct version for xDS data (jobs can move between hosts)
func (t *Task) IDWithHostPort() string {
sb := strings.Builder{}
Expand Down Expand Up @@ -91,6 +103,14 @@ func getTaskAliasDomain(task Task, baseDomain string) string {
return fmt.Sprintf("%s-%s-%s.%s", task.operationAlias, task.taskName, task.service, baseDomain)
}

func tryParseAliasSubdomain(subdomain string) (string, string, string, bool) {
parts := strings.Split(subdomain, "-")
if len(parts) != 3 {
return "", "", "", false
}
return parts[0], parts[1], parts[2], true
}

func Hash(source []byte) string {
hash := fmt.Sprintf("%x", sha256.Sum256(source))
return hash[len(hash)-8:]
Expand All @@ -105,7 +125,7 @@ func (a TaskList) Less(i, j int) bool { return a[i].ID() < a[j].ID() }
func (a TaskList) String() string {
sb := strings.Builder{}
for _, task := range a {
sb.WriteString(fmt.Sprintf("\t%v\n", task))
fmt.Fprintf(&sb, "\t%v\n", task)
}
return sb.String()
}
16 changes: 13 additions & 3 deletions server/pkg/task_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func TestValidateTask(t *testing.T) {
taskName: "task",
service: "service",
},
err: errors.New("field \"operationAlias\" value \"ali-as\" does not match regexp \"^[a-z0-9]+$\""),
err: errors.New("field \"operationAlias\" value \"ali-as\" does not match regexp \"^[a-z0-9]{1,30}$\""),
},
{
name: "invalid task name",
Expand All @@ -40,7 +40,7 @@ func TestValidateTask(t *testing.T) {
taskName: "Task",
service: "service",
},
err: errors.New("field \"taskName\" value \"Task\" does not match regexp \"^[a-z0-9]+$\""),
err: errors.New("field \"taskName\" value \"Task\" does not match regexp \"^[a-z0-9]{1,30}$\""),
},
{
name: "invalid service",
Expand All @@ -50,7 +50,17 @@ func TestValidateTask(t *testing.T) {
taskName: "task",
service: "$ervice",
},
err: errors.New("field \"service\" value \"$ervice\" does not match regexp \"^[a-z0-9]+$\""),
err: errors.New("field \"service\" value \"$ervice\" does not match regexp \"^[a-z0-9]{1,30}$\""),
},
{
name: "invalid service",
task: Task{
operationID: "123",
operationAlias: "alias",
taskName: "task",
service: "serviceserviceserviceserviceserviceserviceserviceservice",
},
err: errors.New("field \"service\" value \"serviceserviceserviceserviceserviceserviceserviceservice\" does not match regexp \"^[a-z0-9]{1,30}$\""),
},
{
name: "do not check if no alias",
Expand Down
Loading
Loading