From dcb30c5ae51c900b7b2d31b94339c978d58fbf17 Mon Sep 17 00:00:00 2001 From: Sergei Sadov Date: Tue, 1 Apr 2025 13:04:24 +0200 Subject: [PATCH] add ConditionalUpdateWithUpdateExpressions --- dynamo_repository.go | 67 ++++++++---- dynamo_repository_interface.go | 4 + mock/dynamo_api_mock_helper.go | 93 +++++++++++++++-- repository_save_test.go | 4 +- repository_update_test.go | 185 ++++++++++++++++++++++----------- update_option.go | 15 +++ 6 files changed, 281 insertions(+), 87 deletions(-) create mode 100644 update_option.go diff --git a/dynamo_repository.go b/dynamo_repository.go index dad682d..e5565a2 100644 --- a/dynamo_repository.go +++ b/dynamo_repository.go @@ -191,6 +191,18 @@ func (repository *Repository) prepareUpdateWithUpdateExpressions( return update, nil } +// ConditionalUpdateWithUpdateExpressions updates an item with update expressions and optional conditions defined at field level, enabling you to set +// different update expressions for each field. The first key of the updateMap specifies the Update expression to use +// for the expressions in the map +func (repository *Repository) ConditionalUpdateWithUpdateExpressions( + ctx context.Context, + key KeyInterface, + updateExpressions UpdateExpressions, + updateOptions ...UpdateOption, +) (bool, error) { + return repository.updateWithUpdateExpressions(ctx, key, updateExpressions, updateOptions...) +} + // UpdateWithUpdateExpressions updates an item with update expressions defined at field level, enabling you to set // different update expressions for each field. The first key of the updateMap specifies the Update expression to use // for the expressions in the map @@ -199,24 +211,8 @@ func (repository *Repository) UpdateWithUpdateExpressions( key KeyInterface, updateExpressions UpdateExpressions, ) error { - update, err := repository.prepareUpdateWithUpdateExpressions(ctx, key, updateExpressions) - if err != nil { - repository.log.error(ctx, key.TableName(), err.Error()) - return err - } - - err = update.RunWithContext(ctx) - if err != nil { - repository.log.error(ctx, key.TableName(), err.Error()) - return err - } - - err = repository.metrics.Publish(ctx, key.TableName(), MetricNameUpdatedItemsCount, float64(1)) - if err != nil { - repository.log.error(ctx, key.TableName(), err.Error()) - } - - return nil + _, err := repository.updateWithUpdateExpressions(ctx, key, updateExpressions) + return err } // UpdateWithUpdateExpressionsAndReturnValue updates an item with update expressions defined at field level and returns @@ -593,3 +589,38 @@ func (repository Repository) ScanIteratorWithContext(ctx context.Context, key Ke return itr, nil } + +func (repository *Repository) updateWithUpdateExpressions( + ctx context.Context, + key KeyInterface, + updateExpressions UpdateExpressions, + updateOptions ...UpdateOption, +) (bool, error) { + update, err := repository.prepareUpdateWithUpdateExpressions(ctx, key, updateExpressions) + if err != nil { + repository.log.error(ctx, key.TableName(), err.Error()) + return false, err + } + + for i := range updateOptions { + updateOptions[i](update) + } + + err = update.RunWithContext(ctx) + if err != nil { + if awserr, ok := err.(awserr.Error); ok && awserr.Code() == dynamodb.ErrCodeConditionalCheckFailedException && len(updateOptions) > 0 { + repository.log.info(ctx, key.TableName(), dynamodb.ErrCodeConditionalCheckFailedException) + return false, nil + } + + repository.log.error(ctx, key.TableName(), err.Error()) + return false, err + } + + err = repository.metrics.Publish(ctx, key.TableName(), MetricNameUpdatedItemsCount, float64(1)) + if err != nil { + repository.log.error(ctx, key.TableName(), err.Error()) + } + + return true, nil +} diff --git a/dynamo_repository_interface.go b/dynamo_repository_interface.go index 54a24bf..f5ec53e 100644 --- a/dynamo_repository_interface.go +++ b/dynamo_repository_interface.go @@ -57,6 +57,10 @@ type RepositoryInterface interface { // for the expressions in the map UpdateWithUpdateExpressions(ctx context.Context, key KeyInterface, updateExpressions UpdateExpressions) error + // ConditionalUpdateWithUpdateExpressions updates an item with update expressions and optional conditions defined at field level + // if no conditions were provided within UpdateOption, a normal update will be performed + ConditionalUpdateWithUpdateExpressions(ctx context.Context, key KeyInterface, updateExpressions UpdateExpressions, updateOptions ...UpdateOption) (bool, error) + // UpdateWithUpdateExpressionsAndReturnValue updates an item with update expressions defined at field level and returns // the item, as it appears after the update, enabling you to set different update expressions for each field. The first // key of the updateMap specifies the Update expression to use for the expressions in the map diff --git a/mock/dynamo_api_mock_helper.go b/mock/dynamo_api_mock_helper.go index e8c26a2..5fec21d 100644 --- a/mock/dynamo_api_mock_helper.go +++ b/mock/dynamo_api_mock_helper.go @@ -23,6 +23,7 @@ type DynamoMock struct { QueryOutput *dynamodb.QueryOutput ScanAllOutput *dynamodb.ScanOutput Input *dynamodb.PutItemInput + UpdateItemInput *dynamodb.UpdateItemInput DeleteItemInput *dynamodb.DeleteItemInput Inputs *dynamodb.BatchWriteItemInput DeleteInputs *dynamodb.BatchWriteItemInput @@ -213,6 +214,37 @@ func (d *DynamoMock) WithInput(value map[string]interface{}) DynamoDBOption { } } +// WithUpdateItemInput register option dynamodb UpdateItemInput +func (d *DynamoMock) WithUpdateItemInput(updateExpr string, value interface{}, opts ...UpdateOption) DynamoDBOption { + return func(args *DynamoMock) { + if d.ExpressionAttributeValues == nil { + d.ExpressionAttributeValues = make(map[string]*dynamodb.AttributeValue) + } + + expr := d.prepareExpression(updateExpr, value) + for i := range expr.avFields { + d.ExpressionAttributeValues[expr.avFields[i]] = expr.marshaledAVs[i] + } + + args.UpdateItemInput = &dynamodb.UpdateItemInput{ + Key: d.Hash, + UpdateExpression: &expr.preparedExpr, + TableName: aws.String(d.TableName), + ReturnValues: aws.String("NONE"), + } + for i := range opts { + opts[i](d) + } + + args.UpdateItemInput.ExpressionAttributeValues = d.ExpressionAttributeValues + if d.ConditionExpression != nil { + args.UpdateItemInput.ConditionExpression = d.ConditionExpression + } + + args.InputMatcher = gomock.Eq(args.UpdateItemInput) + } +} + // WithInput register option dynamodb PutItemInput func (d *DynamoMock) WithDeleteInput(value map[string]interface{}) DynamoDBOption { return func(args *DynamoMock) { @@ -313,12 +345,16 @@ func (d *DynamoMock) WithCondition(field string, value interface{}, operator str // WithConditionExpression register option dynamodb GetItemOutput func (d *DynamoMock) WithConditionExpression(expression string, value interface{}) DynamoDBOption { return func(args *DynamoMock) { - d.ExpressionAttributeValues = make(map[string]*dynamodb.AttributeValue) - expressionAttributeValueField := ":v0" - expression = strings.Replace(expression, "?", expressionAttributeValueField, 1) - av, _ := dynamodbattribute.Marshal(value) - d.ExpressionAttributeValues[expressionAttributeValueField] = av - d.ConditionExpression = &expression + if d.ExpressionAttributeValues == nil { + d.ExpressionAttributeValues = make(map[string]*dynamodb.AttributeValue) + } + + expr := d.prepareExpression(expression, value) + for i := range expr.avFields { + d.ExpressionAttributeValues[expr.avFields[i]] = expr.marshaledAVs[i] + } + + d.ConditionExpression = &expr.preparedExpr } } @@ -521,6 +557,28 @@ func (d *DynamoMock) addCall(method string, input interface{}, output interface{ return d } +func (d *DynamoMock) prepareExpression(expr string, values ...interface{}) preparedExpression { + currentAttributeStartIndex := len(d.ExpressionAttributeValues) + + avFields := make([]string, 0, len(values)) + marshaledAVs := make([]*dynamodb.AttributeValue, 0, len(values)) + + for i, val := range values { + expressionAttributeValueField := ":v" + strconv.Itoa(currentAttributeStartIndex+i) + expr = strings.Replace(expr, "?", expressionAttributeValueField, 1) + + avFields = append(avFields, expressionAttributeValueField) + av, _ := dynamodbattribute.Marshal(val) + marshaledAVs = append(marshaledAVs, av) + } + + return preparedExpression{ + avFields: avFields, + preparedExpr: expr, + marshaledAVs: marshaledAVs, + } +} + // getAttributeValue return dynamodb.AttributeValue from interface type func getAttributeValue(value interface{}) *dynamodb.AttributeValue { attributeValue := dynamodb.AttributeValue{} @@ -549,3 +607,26 @@ type call struct { output interface{} err interface{} } + +type preparedExpression struct { + avFields []string + preparedExpr string + marshaledAVs []*dynamodb.AttributeValue +} + +type UpdateOption func(m *DynamoMock) + +func WithCondition(conditionExpression string, conditionArgs ...interface{}) func(m *DynamoMock) { + return func(m *DynamoMock) { + if m.ExpressionAttributeValues == nil { + m.ExpressionAttributeValues = make(map[string]*dynamodb.AttributeValue) + } + + expr := m.prepareExpression(conditionExpression, conditionArgs...) + for i := range expr.avFields { + m.ExpressionAttributeValues[expr.avFields[i]] = expr.marshaledAVs[i] + } + + m.ConditionExpression = &expr.preparedExpr + } +} diff --git a/repository_save_test.go b/repository_save_test.go index 9466f29..abf338a 100644 --- a/repository_save_test.go +++ b/repository_save_test.go @@ -4,11 +4,12 @@ import ( "context" "time" - "github.com/adjoeio/djoemo/mock" "github.com/bouk/monkey" "github.com/pkg/errors" "go.uber.org/mock/gomock" + "github.com/adjoeio/djoemo/mock" + . "github.com/adjoeio/djoemo" ) @@ -341,7 +342,6 @@ var _ = Describe("Repository", func() { metricsMock.EXPECT().WithContext(context.TODO()).Return(metricsMock) metricsMock.EXPECT().Publish(key.TableName(), MetricNameSavedItemsCount, float64(1)).Return(nil) - logMock.EXPECT().WithContext(context.TODO()).Return(logMock) err := repository.SaveItem(key, user) Expect(err).To(BeNil()) }) diff --git a/repository_update_test.go b/repository_update_test.go index 3579957..c9443da 100644 --- a/repository_update_test.go +++ b/repository_update_test.go @@ -4,10 +4,12 @@ import ( "context" "errors" - . "github.com/adjoeio/djoemo" - "github.com/adjoeio/djoemo/mock" + "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/service/dynamodb" "go.uber.org/mock/gomock" + + . "github.com/adjoeio/djoemo" + "github.com/adjoeio/djoemo/mock" ) var _ = Describe("Repository", func() { @@ -245,88 +247,149 @@ var _ = Describe("Repository", func() { }) }) - Describe("Log", func() { - It("should log with extra fields if log is supported", func() { + Describe("ConditionalUpdateWithUpdateExpressions", func() { + It("should update an item if the condition is met", func() { + hashKey := "uuid" + oldUsername := "username" + newUsername := "username2" + key := Key().WithTableName(UserTableName). WithHashKeyName("UUID"). - WithHashKey("uuid"). - WithRangeKeyName("email"). - WithRangeKey("mail@adjoe.io") - err := errors.New("failed to update item") - dMock.Should().Update( - dMock.WithTable(key.TableName()), - dMock.WithError(err), - ).Exec() + WithHashKey(hashKey) updates := map[string]interface{}{ - "UserName": "name2", - "TraceID": "name4", + "UserName": newUsername, } - repository.WithLog(logMock) - logMock.EXPECT().WithFields(map[string]interface{}{"TableName": key.TableName()}).Return(logMock) - logMock.EXPECT().WithContext(context.TODO()).Return(logMock) - logMock.EXPECT().Error(err.Error()) - ret := repository.Update(Set, key, updates) - Expect(ret).To(BeEquivalentTo(err)) - - }) - }) - - Describe("Metrics", func() { - It("should publish metrics if metric is supported", func() { - key := Key().WithTableName(UserTableName). - WithHashKeyName("UUID"). - WithHashKey("uuid") - dMock.Should().Update( - dMock.WithTable(key.TableName()), - dMock.WithMatch( - mock.InputExpect(). - FieldEq("UserName", "name2").FieldEq("TraceID", "name4"), + dMock.WithTable(UserTableName), + dMock.WithHash( + "UUID", hashKey, ), + dMock.WithUpdateItemInput("SET UserName = ?", newUsername, mock.WithCondition("(UserName = ?) AND (UUID = ?)", oldUsername, hashKey)), ).Exec() - updates := map[string]interface{}{ - "UserName": "name2", - "TraceID": "name4", - } + updated, err := repository.ConditionalUpdateWithUpdateExpressions(context.Background(), key, UpdateExpressions{ + Set: updates, + }, WithCondition("UserName = ?", oldUsername), WithCondition("UUID = ?", hashKey)) - repository.WithMetrics(metricsMock) - metricsMock.EXPECT().WithContext(context.TODO()).Return(metricsMock) - metricsMock.EXPECT().Publish(key.TableName(), MetricNameUpdatedItemsCount, float64(1)).Return(nil) - err := repository.Update(SetSet, key, updates) + Expect(updated).To(BeTrue()) Expect(err).To(BeNil()) }) - It("should not affect update and log error if publish failed", func() { + It("should skip an update if the condition were not met", func() { + hashKey := "uuid" + oldUsername := "username" + newUsername := "username2" + key := Key().WithTableName(UserTableName). WithHashKeyName("UUID"). - WithHashKey("uuid") + WithHashKey(hashKey) + + updates := map[string]interface{}{ + "UserName": newUsername, + } dMock.Should().Update( - dMock.WithTable(key.TableName()), - dMock.WithMatch( - mock.InputExpect(). - FieldEq("UserName", "name2").FieldEq("TraceID", "name4"), + dMock.WithTable(UserTableName), + dMock.WithHash( + "UUID", hashKey, ), + dMock.WithUpdateItemInput("SET UserName = ?", newUsername, mock.WithCondition("(UserName = ?)", oldUsername)), + dMock.WithError(awserr.New(dynamodb.ErrCodeConditionalCheckFailedException, "", errors.New(dynamodb.ErrCodeConditionalCheckFailedException))), ).Exec() - updates := map[string]interface{}{ - "UserName": "name2", - "TraceID": "name4", - } + updated, err := repository.ConditionalUpdateWithUpdateExpressions(context.Background(), key, UpdateExpressions{ + Set: updates, + }, WithCondition("UserName = ?", oldUsername)) - repository.WithMetrics(metricsMock) - repository.WithLog(logMock) - metricsMock.EXPECT().WithContext(context.TODO()).Return(metricsMock) - metricsMock.EXPECT().Publish(key.TableName(), MetricNameUpdatedItemsCount, float64(1)). - Return(errors.New("failed to publish")) - logMock.EXPECT().WithFields(map[string]interface{}{"TableName": key.TableName()}).Return(logMock) - logMock.EXPECT().WithContext(context.TODO()).Return(logMock) - logMock.EXPECT().Error("failed to publish") - err := repository.Update(SetSet, key, updates) + Expect(updated).To(BeFalse()) Expect(err).To(BeNil()) }) + + Describe("Log", func() { + It("should log with extra fields if log is supported", func() { + key := Key().WithTableName(UserTableName). + WithHashKeyName("UUID"). + WithHashKey("uuid"). + WithRangeKeyName("email"). + WithRangeKey("mail@adjoe.io") + err := errors.New("failed to update item") + dMock.Should().Update( + dMock.WithTable(key.TableName()), + dMock.WithError(err), + ).Exec() + + updates := map[string]interface{}{ + "UserName": "name2", + "TraceID": "name4", + } + + repository.WithLog(logMock) + logMock.EXPECT().WithFields(map[string]interface{}{"TableName": key.TableName()}).Return(logMock) + logMock.EXPECT().WithContext(context.TODO()).Return(logMock) + logMock.EXPECT().Error(err.Error()) + ret := repository.Update(Set, key, updates) + Expect(ret).To(BeEquivalentTo(err)) + + }) + }) + + Describe("Metrics", func() { + It("should publish metrics if metric is supported", func() { + key := Key().WithTableName(UserTableName). + WithHashKeyName("UUID"). + WithHashKey("uuid") + + dMock.Should().Update( + dMock.WithTable(key.TableName()), + dMock.WithMatch( + mock.InputExpect(). + FieldEq("UserName", "name2").FieldEq("TraceID", "name4"), + ), + ).Exec() + + updates := map[string]interface{}{ + "UserName": "name2", + "TraceID": "name4", + } + + repository.WithMetrics(metricsMock) + metricsMock.EXPECT().WithContext(context.TODO()).Return(metricsMock) + metricsMock.EXPECT().Publish(key.TableName(), MetricNameUpdatedItemsCount, float64(1)).Return(nil) + err := repository.Update(SetSet, key, updates) + Expect(err).To(BeNil()) + }) + + It("should not affect update and log error if publish failed", func() { + key := Key().WithTableName(UserTableName). + WithHashKeyName("UUID"). + WithHashKey("uuid") + + dMock.Should().Update( + dMock.WithTable(key.TableName()), + dMock.WithMatch( + mock.InputExpect(). + FieldEq("UserName", "name2").FieldEq("TraceID", "name4"), + ), + ).Exec() + + updates := map[string]interface{}{ + "UserName": "name2", + "TraceID": "name4", + } + + repository.WithMetrics(metricsMock) + repository.WithLog(logMock) + metricsMock.EXPECT().WithContext(context.TODO()).Return(metricsMock) + metricsMock.EXPECT().Publish(key.TableName(), MetricNameUpdatedItemsCount, float64(1)). + Return(errors.New("failed to publish")) + logMock.EXPECT().WithFields(map[string]interface{}{"TableName": key.TableName()}).Return(logMock) + logMock.EXPECT().WithContext(context.TODO()).Return(logMock) + logMock.EXPECT().Error("failed to publish") + err := repository.Update(SetSet, key, updates) + Expect(err).To(BeNil()) + }) + }) }) }) diff --git a/update_option.go b/update_option.go new file mode 100644 index 0000000..e7fa12a --- /dev/null +++ b/update_option.go @@ -0,0 +1,15 @@ +package djoemo + +import "github.com/guregu/dynamo" + +type UpdateOption func(update *dynamo.Update) + +func WithCondition(conditionExpression string, conditionArgs ...any) func(update *dynamo.Update) { + return func(update *dynamo.Update) { + if update == nil { + update = &dynamo.Update{} + } + + update.If(conditionExpression, conditionArgs...) + } +}