diff --git a/database/dynamodb/errors.go b/database/dynamodb/errors.go new file mode 100644 index 0000000..e4629a7 --- /dev/null +++ b/database/dynamodb/errors.go @@ -0,0 +1,36 @@ +package dynamodb + +import ( + "github.com/aws/aws-sdk-go-v2/aws/awserr" + awsdynamodb "github.com/aws/aws-sdk-go-v2/service/dynamodb" +) + +// CheckConditionalCheckFailed maps a DynamoDB ConditionalCheckFailedException +// to the provided error. This is the DynamoDB equivalent of checking for a +// unique constraint violation. +func CheckConditionalCheckFailed(inErr, outErr error) error { + if inErr != nil { + if awsErr, ok := inErr.(awserr.Error); ok { + if awsErr.Code() == awsdynamodb.ErrCodeConditionalCheckFailedException { + return outErr + } + } + } + return inErr +} + +// IsConditionalCheckFailed returns true if the error is a DynamoDB +// ConditionalCheckFailedException. +func IsConditionalCheckFailed(err error) bool { + if err == nil { + return false + } + + if awsErr, ok := err.(awserr.Error); ok { + if awsErr.Code() == awsdynamodb.ErrCodeConditionalCheckFailedException { + return true + } + } + + return false +} diff --git a/database/dynamodb/test/util.go b/database/dynamodb/test/util.go new file mode 100644 index 0000000..f61a885 --- /dev/null +++ b/database/dynamodb/test/util.go @@ -0,0 +1,65 @@ +package test + +import ( + "context" + "fmt" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/defaults" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/ory/dockertest/v3" + "github.com/pkg/errors" + + "github.com/code-payments/ocp-server/retry" + "github.com/code-payments/ocp-server/retry/backoff" +) + +const ( + containerRepository = "amazon/dynamodb-local" + containerTag = "latest" + containerAutoKill = 120 // seconds + + port = 8000 +) + +// StartDynamoDB starts a Docker container using the amazon/dynamodb-local image and returns a DynamoDB client for testing purposes. +func StartDynamoDB(pool *dockertest.Pool) (client *dynamodb.Client, closeFunc func(), err error) { + closeFunc = func() {} + + resource, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: containerRepository, + Tag: containerTag, + Cmd: []string{"-jar", "DynamoDBLocal.jar", "-inMemory"}, + }) + if err != nil { + return nil, closeFunc, errors.Wrap(err, "failed to start resource") + } + + resource.Expire(containerAutoKill) + + endpoint := fmt.Sprintf("http://localhost:%s", resource.GetPort(fmt.Sprintf("%d/tcp", port))) + + cfg := defaults.Config() + cfg.Region = "us-east-1" + cfg.EndpointResolver = aws.ResolveWithEndpointURL(endpoint) + cfg.Credentials = aws.NewStaticCredentialsProvider("dummy", "dummy", "") + + client = dynamodb.New(cfg) + + // Wait for the container to be ready by issuing a ListTables request. + _, err = retry.Retry( + func() error { + _, listErr := client.ListTablesRequest(&dynamodb.ListTablesInput{}).Send(context.Background()) + return listErr + }, + retry.Limit(50), + retry.Backoff(backoff.Constant(500*time.Millisecond), 500*time.Second), + ) + if err != nil { + resource.Close() + return nil, closeFunc, errors.Wrap(err, "timed out waiting for dynamodb container to become available") + } + + return client, func() { resource.Close() }, nil +} diff --git a/ocp/data/internal.go b/ocp/data/internal.go index 4f5809f..d229ee5 100644 --- a/ocp/data/internal.go +++ b/ocp/data/internal.go @@ -7,7 +7,6 @@ import ( "sync" "time" - "github.com/google/uuid" "github.com/jmoiron/sqlx" "github.com/code-payments/ocp-server/cache" @@ -27,7 +26,6 @@ import ( "github.com/code-payments/ocp-server/ocp/data/deposit" "github.com/code-payments/ocp-server/ocp/data/fulfillment" "github.com/code-payments/ocp-server/ocp/data/intent" - "github.com/code-payments/ocp-server/ocp/data/messaging" "github.com/code-payments/ocp-server/ocp/data/nonce" "github.com/code-payments/ocp-server/ocp/data/rendezvous" "github.com/code-payments/ocp-server/ocp/data/swap" @@ -45,7 +43,6 @@ import ( deposit_memory_client "github.com/code-payments/ocp-server/ocp/data/deposit/memory" fulfillment_memory_client "github.com/code-payments/ocp-server/ocp/data/fulfillment/memory" intent_memory_client "github.com/code-payments/ocp-server/ocp/data/intent/memory" - messaging_memory_client "github.com/code-payments/ocp-server/ocp/data/messaging/memory" nonce_memory_client "github.com/code-payments/ocp-server/ocp/data/nonce/memory" rendezvous_memory_client "github.com/code-payments/ocp-server/ocp/data/rendezvous/memory" swap_memory_client "github.com/code-payments/ocp-server/ocp/data/swap/memory" @@ -63,7 +60,6 @@ import ( deposit_postgres_client "github.com/code-payments/ocp-server/ocp/data/deposit/postgres" fulfillment_postgres_client "github.com/code-payments/ocp-server/ocp/data/fulfillment/postgres" intent_postgres_client "github.com/code-payments/ocp-server/ocp/data/intent/postgres" - messaging_postgres_client "github.com/code-payments/ocp-server/ocp/data/messaging/postgres" nonce_postgres_client "github.com/code-payments/ocp-server/ocp/data/nonce/postgres" rendezvous_postgres_client "github.com/code-payments/ocp-server/ocp/data/rendezvous/postgres" swap_postgres_client "github.com/code-payments/ocp-server/ocp/data/swap/postgres" @@ -182,12 +178,6 @@ type DatabaseData interface { GetTransactedAmountForAntiMoneyLaundering(ctx context.Context, owner string, since time.Time) (uint64, float64, error) GetUsdCostBasis(ctx context.Context, owner string, mint string) (float64, error) - // Messaging - // -------------------------------------------------------------------------------- - CreateMessage(ctx context.Context, record *messaging.Record) error - GetMessages(ctx context.Context, account string) ([]*messaging.Record, error) - DeleteMessage(ctx context.Context, account string, messageID uuid.UUID) error - // Nonces // -------------------------------------------------------------------------------- GetNonce(ctx context.Context, address string) (*nonce.Record, error) @@ -276,7 +266,6 @@ type DatabaseProvider struct { deposits deposit.Store fulfillments fulfillment.Store intents intent.Store - messages messaging.Store nonces nonce.Store rendezvous rendezvous.Store swaps swap.Store @@ -322,7 +311,6 @@ func NewDatabaseProvider(dbConfig *pg.Config) (DatabaseData, error) { deposits: deposit_postgres_client.New(db), fulfillments: fulfillment_postgres_client.New(db), intents: intent_postgres_client.New(db), - messages: messaging_postgres_client.New(db), nonces: nonce_postgres_client.New(db), rendezvous: rendezvous_postgres_client.New(db), swaps: swap_postgres_client.New(db), @@ -349,7 +337,6 @@ func NewTestDatabaseProvider() DatabaseData { deposits: deposit_memory_client.New(), fulfillments: fulfillment_memory_client.New(), intents: intent_memory_client.New(), - messages: messaging_memory_client.New(), nonces: nonce_memory_client.New(), rendezvous: rendezvous_memory_client.New(), swaps: swap_memory_client.New(), @@ -677,18 +664,6 @@ func (dp *DatabaseProvider) GetUsdCostBasis(ctx context.Context, owner string, m return dp.intents.GetUsdCostBasis(ctx, owner, mint) } -// Messaging -// -------------------------------------------------------------------------------- -func (dp *DatabaseProvider) CreateMessage(ctx context.Context, record *messaging.Record) error { - return dp.messages.Insert(ctx, record) -} -func (dp *DatabaseProvider) GetMessages(ctx context.Context, account string) ([]*messaging.Record, error) { - return dp.messages.Get(ctx, account) -} -func (dp *DatabaseProvider) DeleteMessage(ctx context.Context, account string, messageID uuid.UUID) error { - return dp.messages.Delete(ctx, account, messageID) -} - // Nonces // -------------------------------------------------------------------------------- func (dp *DatabaseProvider) GetNonce(ctx context.Context, address string) (*nonce.Record, error) { diff --git a/ocp/data/messaging/dynamodb/model.go b/ocp/data/messaging/dynamodb/model.go new file mode 100644 index 0000000..b511eae --- /dev/null +++ b/ocp/data/messaging/dynamodb/model.go @@ -0,0 +1,128 @@ +package dynamodb + +import ( + "context" + "fmt" + "strconv" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/google/uuid" + "github.com/pkg/errors" + + dynamodbutil "github.com/code-payments/ocp-server/database/dynamodb" + + "github.com/code-payments/ocp-server/ocp/data/messaging" +) + +type model struct { + Account string + MessageID string + Message []byte + CreatedAt time.Time +} + +func toModel(record *messaging.Record) (*model, error) { + if err := record.Validate(); err != nil { + return nil, err + } + + return &model{ + Account: record.Account, + MessageID: record.MessageID.String(), + Message: record.Message, + // The only time we call toModel is on create, so it's fine to default + // to UTC now. + CreatedAt: time.Now().UTC(), + }, nil +} + +func fromModel(m *model) (*messaging.Record, error) { + parsedMessageID, err := uuid.Parse(m.MessageID) + if err != nil { + return nil, errors.Wrap(err, "failure parsing message id") + } + + return &messaging.Record{ + Account: m.Account, + MessageID: parsedMessageID, + Message: m.Message, + }, nil +} + +func (m *model) dbPut(ctx context.Context, client *dynamodb.Client, tableName string) error { + req := client.PutItemRequest(&dynamodb.PutItemInput{ + TableName: aws.String(tableName), + Item: map[string]dynamodb.AttributeValue{ + "account": {S: aws.String(m.Account)}, + "message_id": {S: aws.String(m.MessageID)}, + "message": {B: m.Message}, + "created_at": {N: aws.String(fmt.Sprintf("%d", m.CreatedAt.Unix()))}, + }, + ConditionExpression: aws.String("attribute_not_exists(account) AND attribute_not_exists(message_id)"), + }) + + _, err := req.Send(ctx) + if err != nil { + return dynamodbutil.CheckConditionalCheckFailed(err, messaging.ErrDuplicateMessageID) + } + + return nil +} + +func dbGetAllForAccount(ctx context.Context, client *dynamodb.Client, tableName string, account string) ([]*model, error) { + req := client.QueryRequest(&dynamodb.QueryInput{ + TableName: aws.String(tableName), + KeyConditionExpression: aws.String("account = :account"), + ExpressionAttributeValues: map[string]dynamodb.AttributeValue{ + ":account": {S: aws.String(account)}, + }, + }) + + resp, err := req.Send(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to query messages") + } + + models := make([]*model, len(resp.Items)) + for i, item := range resp.Items { + m := &model{} + if v, ok := item["account"]; ok && v.S != nil { + m.Account = *v.S + } + if v, ok := item["message_id"]; ok && v.S != nil { + m.MessageID = *v.S + } + if v, ok := item["message"]; ok { + m.Message = v.B + } + if v, ok := item["created_at"]; ok && v.N != nil { + seconds, err := strconv.ParseInt(*v.N, 10, 64) + if err != nil { + return nil, errors.Wrap(err, "failed to parse created_at") + } + m.CreatedAt = time.Unix(seconds, 0).UTC() + } + models[i] = m + } + + return models, nil +} + +func dbDelete(ctx context.Context, client *dynamodb.Client, tableName string, account, messageID string) error { + req := client.DeleteItemRequest(&dynamodb.DeleteItemInput{ + TableName: aws.String(tableName), + Key: map[string]dynamodb.AttributeValue{ + "account": {S: aws.String(account)}, + "message_id": {S: aws.String(messageID)}, + }, + }) + + _, err := req.Send(ctx) + if err != nil { + return errors.Wrap(err, "failed to delete message") + } + + return nil +} diff --git a/ocp/data/messaging/postgres/store.go b/ocp/data/messaging/dynamodb/store.go similarity index 52% rename from ocp/data/messaging/postgres/store.go rename to ocp/data/messaging/dynamodb/store.go index ac96187..d586001 100644 --- a/ocp/data/messaging/postgres/store.go +++ b/ocp/data/messaging/dynamodb/store.go @@ -1,25 +1,24 @@ -package postgres +package dynamodb import ( "context" - "database/sql" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" "github.com/google/uuid" - "github.com/jmoiron/sqlx" "github.com/code-payments/ocp-server/ocp/data/messaging" ) -// todo: This doesn't support TTL expiries, which is fine for now. We can -// manually delete old entries while in an invite-only testing phase. type store struct { - db *sqlx.DB + client *dynamodb.Client + tableName string } -// New returns a postgres backed messaging.Store. -func New(db *sql.DB) messaging.Store { +// New returns a DynamoDB backed messaging.Store. +func New(client *dynamodb.Client, tableName string) messaging.Store { return &store{ - db: sqlx.NewDb(db, "pgx"), + client: client, + tableName: tableName, } } @@ -30,27 +29,25 @@ func (s *store) Insert(ctx context.Context, record *messaging.Record) error { return err } - return model.dbSave(ctx, s.db) + return model.dbPut(ctx, s.client, s.tableName) } // Delete implements messaging.Store.Delete. func (s *store) Delete(ctx context.Context, account string, messageID uuid.UUID) error { - return dbDelete(ctx, s.db, account, messageID.String()) + return dbDelete(ctx, s.client, s.tableName, account, messageID.String()) } // Get implements messaging.Store.Get. func (s *store) Get(ctx context.Context, account string) ([]*messaging.Record, error) { - models, err := dbGetAllForAccount(ctx, s.db, account) + models, err := dbGetAllForAccount(ctx, s.client, s.tableName, account) if err != nil { return nil, err } records := make([]*messaging.Record, len(models)) - for i, model := range models { - record, err := fromModel(model) + for i, m := range models { + record, err := fromModel(m) if err != nil { - // todo(safety): this is the equivalent QoS brick case, although should be less problematic. - // we could have a valve to ignore, and also to delete return nil, err } records[i] = record diff --git a/ocp/data/messaging/dynamodb/store_test.go b/ocp/data/messaging/dynamodb/store_test.go new file mode 100644 index 0000000..207c431 --- /dev/null +++ b/ocp/data/messaging/dynamodb/store_test.go @@ -0,0 +1,115 @@ +package dynamodb + +import ( + "context" + "os" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/ory/dockertest/v3" + "go.uber.org/zap" + + "github.com/code-payments/ocp-server/ocp/data/messaging" + "github.com/code-payments/ocp-server/ocp/data/messaging/tests" + + dynamodbtest "github.com/code-payments/ocp-server/database/dynamodb/test" +) + +const ( + testTableName = "test-messaging" +) + +var ( + testStore messaging.Store + teardown func() +) + +func TestMain(m *testing.M) { + log := zap.Must(zap.NewDevelopment()) + + testPool, err := dockertest.NewPool("") + if err != nil { + log.With(zap.Error(err)).Error("Error creating docker pool") + os.Exit(1) + } + + client, cleanUpFunc, err := dynamodbtest.StartDynamoDB(testPool) + if err != nil { + log.With(zap.Error(err)).Error("Error starting dynamodb-local image") + os.Exit(1) + } + + if err := createTestTable(client); err != nil { + log.With(zap.Error(err)).Error("Error creating test table") + cleanUpFunc() + os.Exit(1) + } + + testStore = New(client, testTableName) + teardown = func() { + if pc := recover(); pc != nil { + cleanUpFunc() + panic(pc) + } + + if err := resetTestTable(client); err != nil { + log.With(zap.Error(err)).Error("Error resetting test table") + cleanUpFunc() + os.Exit(1) + } + } + + code := m.Run() + cleanUpFunc() + os.Exit(code) +} + +func TestMessagingDynamoDBStore(t *testing.T) { + tests.RunTests(t, testStore, teardown) +} + +func createTestTable(client *dynamodb.Client) error { + req := client.CreateTableRequest(&dynamodb.CreateTableInput{ + TableName: aws.String(testTableName), + AttributeDefinitions: []dynamodb.AttributeDefinition{ + { + AttributeName: aws.String("account"), + AttributeType: dynamodb.ScalarAttributeTypeS, + }, + { + AttributeName: aws.String("message_id"), + AttributeType: dynamodb.ScalarAttributeTypeS, + }, + }, + KeySchema: []dynamodb.KeySchemaElement{ + { + AttributeName: aws.String("account"), + KeyType: dynamodb.KeyTypeHash, + }, + { + AttributeName: aws.String("message_id"), + KeyType: dynamodb.KeyTypeRange, + }, + }, + ProvisionedThroughput: &dynamodb.ProvisionedThroughput{ + ReadCapacityUnits: aws.Int64(5), + WriteCapacityUnits: aws.Int64(5), + }, + }) + + _, err := req.Send(context.Background()) + return err +} + +func resetTestTable(client *dynamodb.Client) error { + deleteReq := client.DeleteTableRequest(&dynamodb.DeleteTableInput{ + TableName: aws.String(testTableName), + }) + _, err := deleteReq.Send(context.Background()) + if err != nil { + return err + } + + return createTestTable(client) +} diff --git a/ocp/data/messaging/postgres/model.go b/ocp/data/messaging/postgres/model.go deleted file mode 100644 index 0a3646c..0000000 --- a/ocp/data/messaging/postgres/model.go +++ /dev/null @@ -1,101 +0,0 @@ -package postgres - -import ( - "context" - "database/sql" - "time" - - "github.com/google/uuid" - "github.com/jmoiron/sqlx" - "github.com/pkg/errors" - - pgutil "github.com/code-payments/ocp-server/database/postgres" - - "github.com/code-payments/ocp-server/ocp/data/messaging" -) - -const ( - tableName = "ocp__core_message" -) - -type model struct { - Id sql.NullInt64 `db:"id"` - Account string `db:"account"` - MessageID string `db:"message_id"` - Message []byte `db:"message"` - CreatedAt time.Time `db:"created_at"` -} - -func toModel(record *messaging.Record) (*model, error) { - if err := record.Validate(); err != nil { - return nil, err - } - - if len(record.Account) == 0 { - return nil, errors.New("empty account") - } - - if record.Message == nil || len(record.Message) == 0 { - return nil, errors.New("empty message id") - } - - return &model{ - Account: record.Account, - MessageID: record.MessageID.String(), - Message: record.Message, - // The only time we call toModel is on create, so it's fine to default - // to UTC now. - CreatedAt: time.Now().UTC(), - }, nil -} - -func fromModel(obj *model) (*messaging.Record, error) { - parsedMessageID, err := uuid.Parse(obj.MessageID) - if err != nil { - return nil, errors.Wrap(err, "failure parsing message id") - } - - return &messaging.Record{ - Account: obj.Account, - MessageID: parsedMessageID, - Message: obj.Message, - }, nil -} - -func (m *model) dbSave(ctx context.Context, db *sqlx.DB) error { - return pgutil.ExecuteInTx(ctx, db, sql.LevelDefault, func(tx *sqlx.Tx) error { - query := `INSERT INTO ` + tableName + ` - ( - account, message_id, message, created_at - ) VALUES ($1,$2,$3,$4) RETURNING *;` - - err := tx.QueryRowxContext(ctx, query, - m.Account, - m.MessageID, - m.Message, - m.CreatedAt, - ).StructScan(m) - - return pgutil.CheckUniqueViolation(err, messaging.ErrDuplicateMessageID) - }) -} - -func dbGetAllForAccount(ctx context.Context, db *sqlx.DB, account string) ([]*model, error) { - res := []*model{} - - query := `SELECT account, message_id, message FROM ` + tableName + ` - WHERE account = $1` - - err := db.SelectContext(ctx, &res, query, account) - if err != nil { - return nil, err - } - return res, nil -} - -func dbDelete(ctx context.Context, db *sqlx.DB, account, messageID string) error { - query := `DELETE FROM ` + tableName + ` - WHERE account = $1 AND message_id = $2;` - _, err := db.ExecContext(ctx, query, account, messageID) - return err -} diff --git a/ocp/data/messaging/postgres/model_test.go b/ocp/data/messaging/postgres/model_test.go deleted file mode 100644 index c8f97ae..0000000 --- a/ocp/data/messaging/postgres/model_test.go +++ /dev/null @@ -1,58 +0,0 @@ -package postgres - -import ( - "crypto/ed25519" - "testing" - - "github.com/google/uuid" - "github.com/mr-tron/base58" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "google.golang.org/protobuf/proto" - - commonpb "github.com/code-payments/ocp-protobuf-api/generated/go/common/v1" - messagingpb "github.com/code-payments/ocp-protobuf-api/generated/go/messaging/v1" - - "github.com/code-payments/ocp-server/ocp/data/messaging" -) - -func TestModelConversion(t *testing.T) { - pub, _, err := ed25519.GenerateKey(nil) - require.NoError(t, err) - requestor, _, err := ed25519.GenerateKey(nil) - require.NoError(t, err) - - account := &commonpb.SolanaAccountId{Value: pub} - messageID := uuid.New() - idBytes, _ := messageID.MarshalBinary() - message := &messagingpb.Message{ - Id: &messagingpb.MessageId{ - Value: idBytes, - }, - Kind: &messagingpb.Message_RequestToGrabBill{ - RequestToGrabBill: &messagingpb.RequestToGrabBill{ - RequestorAccount: &commonpb.SolanaAccountId{ - Value: requestor, - }, - }, - }, - } - messageBytes, err := proto.Marshal(message) - require.NoError(t, err) - - record := &messaging.Record{ - Account: base58.Encode(account.Value), - MessageID: messageID, - Message: messageBytes, - } - - model, err := toModel(record) - require.NoError(t, err) - assert.Equal(t, model.Account, base58.Encode(account.Value)) - assert.Equal(t, model.MessageID, messageID.String()) - assert.Equal(t, model.Message, messageBytes) - - actual, err := fromModel(model) - require.NoError(t, err) - assert.Equal(t, actual, record) -} diff --git a/ocp/data/messaging/postgres/store_test.go b/ocp/data/messaging/postgres/store_test.go deleted file mode 100644 index e3404a9..0000000 --- a/ocp/data/messaging/postgres/store_test.go +++ /dev/null @@ -1,108 +0,0 @@ -package postgres - -import ( - "database/sql" - "os" - "testing" - - "github.com/ory/dockertest/v3" - "go.uber.org/zap" - - "github.com/code-payments/ocp-server/ocp/data/messaging" - "github.com/code-payments/ocp-server/ocp/data/messaging/tests" - - postgrestest "github.com/code-payments/ocp-server/database/postgres/test" - - _ "github.com/jackc/pgx/v4/stdlib" -) - -const ( - // Used for testing ONLY, the table and migrations are external to this repository - tableCreate = ` - CREATE TABLE ocp__core_message ( - id SERIAL NOT NULL PRIMARY KEY, - - account TEXT NOT NULL, - message_id UUID NOT NULL, - message BYTEA NOT NULL, - created_at TIMESTAMP WITH TIME ZONE, - - CONSTRAINT ocp__core_message__uniq__account__and__message_id UNIQUE (account, message_id) - ); - ` - - // Used for testing ONLY, the table and migrations are external to this repository - tableDestroy = ` - DROP TABLE ocp__core_message; - ` -) - -var ( - testStore messaging.Store - teardown func() -) - -func TestMain(m *testing.M) { - log := zap.Must(zap.NewDevelopment()) - - testPool, err := dockertest.NewPool("") - if err != nil { - log.With(zap.Error(err)).Error("Error creating docker pool") - os.Exit(1) - } - - var cleanUpFunc func() - db, cleanUpFunc, err := postgrestest.StartPostgresDB(testPool) - if err != nil { - log.With(zap.Error(err)).Error("Error starting postgres image") - os.Exit(1) - } - defer db.Close() - - if err := createTestTables(log, db); err != nil { - log.With(zap.Error(err)).Error("Error creating test tables") - cleanUpFunc() - os.Exit(1) - } - - testStore = New(db) - teardown = func() { - if pc := recover(); pc != nil { - cleanUpFunc() - panic(pc) - } - - if err := resetTestTables(log, db); err != nil { - log.With(zap.Error(err)).Error("Error resetting test tables") - cleanUpFunc() - os.Exit(1) - } - } - - code := m.Run() - cleanUpFunc() - os.Exit(code) -} - -func TestMessagingPostgresStore(t *testing.T) { - tests.RunTests(t, testStore, teardown) -} - -func createTestTables(log *zap.Logger, db *sql.DB) error { - _, err := db.Exec(tableCreate) - if err != nil { - log.With(zap.Error(err)).Error("could not create test tables") - return err - } - return nil -} - -func resetTestTables(log *zap.Logger, db *sql.DB) error { - _, err := db.Exec(tableDestroy) - if err != nil { - log.With(zap.Error(err)).Error("could not drop test tables") - return err - } - - return createTestTables(log, db) -} diff --git a/ocp/data/provider.go b/ocp/data/provider.go index 5ecf5f4..33efab7 100644 --- a/ocp/data/provider.go +++ b/ocp/data/provider.go @@ -8,6 +8,7 @@ const ( maxCurrencyHistoryReqSize = 1024 ) +// todo: Deprecate Provider in favour of per-store interfaces type Provider interface { BlockchainData DatabaseData diff --git a/ocp/rpc/messaging/internal.go b/ocp/rpc/messaging/internal.go index 63314de..3976086 100644 --- a/ocp/rpc/messaging/internal.go +++ b/ocp/rpc/messaging/internal.go @@ -15,10 +15,10 @@ import ( commonpb "github.com/code-payments/ocp-protobuf-api/generated/go/common/v1" messagingpb "github.com/code-payments/ocp-protobuf-api/generated/go/messaging/v1" + "github.com/code-payments/ocp-server/grpc/headers" "github.com/code-payments/ocp-server/ocp/common" "github.com/code-payments/ocp-server/ocp/data/messaging" "github.com/code-payments/ocp-server/ocp/data/rendezvous" - "github.com/code-payments/ocp-server/grpc/headers" "github.com/code-payments/ocp-server/retry" "github.com/code-payments/ocp-server/retry/backoff" ) @@ -81,7 +81,7 @@ func (s *server) InternallyCreateMessage(ctx context.Context, rendezvousKey *com } // Save the message to the DB - err = s.data.CreateMessage(ctx, record) + err = s.messageStore.Insert(ctx, record) if err != nil { return uuid.Nil, errors.Wrap(err, "error saving message to db") } diff --git a/ocp/rpc/messaging/server.go b/ocp/rpc/messaging/server.go index a7fc008..75282e8 100644 --- a/ocp/rpc/messaging/server.go +++ b/ocp/rpc/messaging/server.go @@ -47,7 +47,9 @@ const ( type server struct { log *zap.Logger conf *conf - data ocp_data.Provider + + data ocp_data.Provider + messageStore messaging.Store streamsMu sync.RWMutex streams map[string]*messageStream @@ -66,10 +68,12 @@ type server struct { func NewMessagingClient( log *zap.Logger, data ocp_data.Provider, + messageStore messaging.Store, ) InternalMessageClient { return &server{ - log: log, - data: data, + log: log, + data: data, + messageStore: messageStore, } } @@ -79,6 +83,7 @@ func NewMessagingClient( func NewMessagingClientAndServer( log *zap.Logger, data ocp_data.Provider, + messageStore messaging.Store, rpcSignatureVerifier *auth.RPCSignatureVerifier, broadcastAddress string, configProvider ConfigProvider, @@ -87,6 +92,7 @@ func NewMessagingClientAndServer( log: log, conf: configProvider(), data: data, + messageStore: messageStore, streams: make(map[string]*messageStream), individualStreamMu: make(map[string]*sync.Mutex), rpcSignatureVerifier: rpcSignatureVerifier, @@ -505,7 +511,7 @@ func (s *server) PollMessages(ctx context.Context, req *messagingpb.PollMessages return nil, err } - records, err := s.data.GetMessages(ctx, rendezvousAccount.PublicKey().ToBase58()) + records, err := s.messageStore.Get(ctx, rendezvousAccount.PublicKey().ToBase58()) if err != nil { log.With(zap.Error(err)).Warn("failed to load undelivered messages") return nil, status.Error(codes.Internal, "") @@ -556,7 +562,7 @@ func (s *server) AckMessages(ctx context.Context, req *messagingpb.AckMessagesRe return nil, status.Error(codes.Internal, "") } - if err := s.data.DeleteMessage(ctx, account, converted); err != nil { + if err := s.messageStore.Delete(ctx, account, converted); err != nil { log.With(zap.Error(err)).Warn("Failed to delete message") return nil, status.Error(codes.Internal, "") } @@ -688,7 +694,7 @@ func (s *server) SendMessage(ctx context.Context, req *messagingpb.SendMessageRe Message: messageWithGeneratedIDAndSignatureBytes, } - err = s.data.CreateMessage(ctx, record) + err = s.messageStore.Insert(ctx, record) if err != nil { log.With(zap.Error(err)).Warn("failed to create message") return err @@ -750,7 +756,7 @@ func (s *server) flush(ctx context.Context, accountID *messagingpb.RendezvousKey zap.String("account_id", accountStr), ) - records, err := s.data.GetMessages(ctx, accountStr) + records, err := s.messageStore.Get(ctx, accountStr) if err != nil { log.With(zap.Error(err)).Warn("Failed to load undelivered messages") return diff --git a/ocp/rpc/messaging/testutil.go b/ocp/rpc/messaging/testutil.go index 405e7b5..4777311 100644 --- a/ocp/rpc/messaging/testutil.go +++ b/ocp/rpc/messaging/testutil.go @@ -26,6 +26,7 @@ import ( "github.com/code-payments/ocp-server/ocp/data/account" "github.com/code-payments/ocp-server/ocp/data/currency" "github.com/code-payments/ocp-server/ocp/data/messaging" + "github.com/code-payments/ocp-server/ocp/data/messaging/memory" "github.com/code-payments/ocp-server/ocp/data/rendezvous" "github.com/code-payments/ocp-server/testutil" ) @@ -47,6 +48,7 @@ func setup(t *testing.T, enableMultiServer bool) (env testEnv, cleanup func()) { require.NoError(t, err) data := ocp_data.NewTestDataProvider() + messageStore := memory.New() env.client1 = &clientEnv{ ctx: context.Background(), @@ -75,14 +77,14 @@ func setup(t *testing.T, enableMultiServer bool) (env testEnv, cleanup func()) { }, })) - s1 := NewMessagingClientAndServer(log, data, auth.NewRPCSignatureVerifier(log, data), conn1.Target(), withManualTestOverrides(&testOverrides{})) + s1 := NewMessagingClientAndServer(log, data, messageStore, auth.NewRPCSignatureVerifier(log, data), conn1.Target(), withManualTestOverrides(&testOverrides{})) env.server1 = &serverEnv{ ctx: context.Background(), server: s1, subsidizer: subsidizer, } - s2 := NewMessagingClientAndServer(log, data, auth.NewRPCSignatureVerifier(log, data), conn2.Target(), withManualTestOverrides(&testOverrides{})) + s2 := NewMessagingClientAndServer(log, data, messageStore, auth.NewRPCSignatureVerifier(log, data), conn2.Target(), withManualTestOverrides(&testOverrides{})) env.server2 = &serverEnv{ ctx: context.Background(), server: s2, @@ -115,13 +117,13 @@ type serverEnv struct { } func (s *serverEnv) getMessages(t *testing.T, rendezvousKey *common.Account) []*messaging.Record { - messages, err := s.server.data.GetMessages(s.ctx, rendezvousKey.PublicKey().ToBase58()) + messages, err := s.server.messageStore.Get(s.ctx, rendezvousKey.PublicKey().ToBase58()) require.NoError(t, err) return messages } func (s *serverEnv) assertNoMessages(t *testing.T, rendezvousKey *common.Account) { - messages, err := s.server.data.GetMessages(s.ctx, rendezvousKey.PublicKey().ToBase58()) + messages, err := s.server.messageStore.Get(s.ctx, rendezvousKey.PublicKey().ToBase58()) require.NoError(t, err) assert.Empty(t, messages) }