Skip to content

Commit 14f1e99

Browse files
committed
Improve filtering logic, fix integration tests.
1 parent f733c56 commit 14f1e99

9 files changed

Lines changed: 192 additions & 174 deletions

dml_events.go

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -168,14 +168,15 @@ func (e *BinlogInsertEvent) NewValues() RowData {
168168
}
169169

170170
func (e *BinlogInsertEvent) AsSQLString(schemaName, tableName string) (string, error) {
171-
if err := verifyValuesHasTheSameLengthAsColumns(e.table, e.newValues); err != nil {
171+
filteredNewValues, err := e.table.FilterGeneratedColumnsOnRowData(e.newValues)
172+
if err != nil {
172173
return "", err
173174
}
174175

175176
query := "INSERT IGNORE INTO " +
176177
QuotedTableNameFromString(schemaName, tableName) +
177-
" (" + strings.Join(quotedColumnNamesForInsert(e.table), ",") + ")" +
178-
" VALUES (" + buildStringListForInsertValues(e.table, e.newValues) + ")"
178+
" (" + strings.Join(quotedColumnNames(e.table), ",") + ")" +
179+
" VALUES (" + buildStringListForValues(e.table, filteredNewValues) + ")"
179180

180181
return query, nil
181182
}
@@ -227,8 +228,8 @@ func (e *BinlogUpdateEvent) AsSQLString(schemaName, tableName string) (string, e
227228
}
228229

229230
query := "UPDATE " + QuotedTableNameFromString(schemaName, tableName) +
230-
" SET " + buildStringMapForSet(e.table.Columns, e.newValues) +
231-
" WHERE " + buildStringMapForWhere(e.table.Columns, e.oldValues)
231+
" SET " + buildStringMapForSet(e.table, e.newValues) +
232+
" WHERE " + buildStringMapForWhere(e.table, e.oldValues)
232233

233234
return query, nil
234235
}
@@ -269,7 +270,7 @@ func (e *BinlogDeleteEvent) AsSQLString(schemaName, tableName string) (string, e
269270
}
270271

271272
query := "DELETE FROM " + QuotedTableNameFromString(schemaName, tableName) +
272-
" WHERE " + buildStringMapForWhere(e.table.Columns, e.oldValues)
273+
" WHERE " + buildStringMapForWhere(e.table, e.oldValues)
273274

274275
return query, nil
275276
}
@@ -281,16 +282,22 @@ func (e *BinlogDeleteEvent) PaginationKey() (string, error) {
281282
func NewBinlogDMLEvents(table *TableSchema, ev *replication.BinlogEvent, pos, resumablePos mysql.Position, query []byte) ([]DMLEvent, error) {
282283
rowsEvent := ev.Event.(*replication.RowsEvent)
283284

284-
for _, row := range rowsEvent.Rows {
285-
if len(row) != len(table.Columns) {
285+
for _, rawRow := range rowsEvent.Rows {
286+
if len(rawRow) != len(table.Columns) {
286287
return nil, fmt.Errorf(
287288
"table %s.%s has %d columns but event has %d columns instead",
288289
table.Schema,
289290
table.Name,
290291
len(table.Columns),
291-
len(row),
292+
len(rawRow),
292293
)
293294
}
295+
296+
row, err := table.FilterGeneratedColumnsOnRowData(rawRow)
297+
if err != nil {
298+
return nil, err
299+
}
300+
294301
for i, col := range table.Columns {
295302
if col.IsUnsigned {
296303
switch v := row[i].(type) {
@@ -323,14 +330,10 @@ func NewBinlogDMLEvents(table *TableSchema, ev *replication.BinlogEvent, pos, re
323330
}
324331
}
325332

326-
func quotedColumnNamesForInsert(table *TableSchema) []string {
327-
cols := []string{}
328-
329-
for _, c := range table.Columns {
330-
if c.IsVirtual {
331-
continue
332-
}
333-
cols = append(cols, QuoteField(c.Name))
333+
func quotedColumnNames(table *TableSchema) []string {
334+
cols := make([]string, 0, len(table.Columns))
335+
for _, name := range table.NonGeneratedColumnNames() {
336+
cols = append(cols, QuoteField(name))
334337
}
335338

336339
return cols
@@ -351,56 +354,62 @@ func verifyValuesHasTheSameLengthAsColumns(table *TableSchema, values ...RowData
351354
return nil
352355
}
353356

354-
func buildStringListForInsertValues(table *TableSchema, values []interface{}) string {
357+
func buildStringListForValues(table *TableSchema, values []interface{}) string {
355358
var buffer []byte
356359

357360
for i, value := range values {
358-
if table.Columns[i].IsVirtual {
361+
if table.IsColumnIndexGenerated(i) {
359362
continue
360363
}
361-
362-
if len(buffer) != 0 {
364+
if len(buffer) > 0 {
363365
buffer = append(buffer, ',')
364366
}
367+
365368
buffer = appendEscapedValue(buffer, value, table.Columns[i])
366369
}
367370

368371
return string(buffer)
369372
}
370373

371-
func buildStringMapForWhere(columns []schema.TableColumn, values []interface{}) string {
374+
func buildStringMapForWhere(table *TableSchema, values []interface{}) string {
372375
var buffer []byte
373376

374377
for i, value := range values {
375-
if i > 0 {
378+
if table.IsColumnIndexGenerated(i) {
379+
continue
380+
}
381+
if len(buffer) > 0 {
376382
buffer = append(buffer, " AND "...)
377383
}
378384

379-
buffer = append(buffer, QuoteField(columns[i].Name)...)
385+
buffer = append(buffer, QuoteField(table.Columns[i].Name)...)
380386

381387
if isNilValue(value) {
382388
// "WHERE value = NULL" will never match rows.
383389
buffer = append(buffer, " IS NULL"...)
384390
} else {
385391
buffer = append(buffer, '=')
386-
buffer = appendEscapedValue(buffer, value, columns[i])
392+
buffer = appendEscapedValue(buffer, value, table.Columns[i])
387393
}
388394
}
389395

390396
return string(buffer)
391397
}
392398

393-
func buildStringMapForSet(columns []schema.TableColumn, values []interface{}) string {
399+
func buildStringMapForSet(table *TableSchema, values []interface{}) string {
394400
var buffer []byte
395401

396402
for i, value := range values {
397-
if i > 0 {
403+
if table.IsColumnIndexGenerated(i) {
404+
continue
405+
}
406+
if len(buffer) > 0 {
398407
buffer = append(buffer, ',')
399408
}
400409

401-
buffer = append(buffer, QuoteField(columns[i].Name)...)
410+
buffer = append(buffer, QuoteField(table.Columns[i].Name)...)
402411
buffer = append(buffer, '=')
403-
buffer = appendEscapedValue(buffer, value, columns[i])
412+
buffer = appendEscapedValue(buffer, value, table.Columns[i])
404413
}
405414

406415
return string(buffer)

row_batch.go

Lines changed: 10 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -64,77 +64,27 @@ func (e *RowBatch) AsSQLQuery(schemaName, tableName string) (string, []interface
6464
return "", nil, err
6565
}
6666

67-
vcm := e.virtualColumnsMap()
68-
valuesStr := "(" + strings.Repeat("?,", e.activeColumnCount(vcm)-1) + "?)"
67+
filteredColumns := e.table.NonGeneratedColumnNames()
68+
69+
valuesStr := "(" + strings.Repeat("?,", len(filteredColumns)-1) + "?)"
6970
valuesStr = strings.Repeat(valuesStr+",", len(e.values)-1) + valuesStr
7071

7172
query := "INSERT IGNORE INTO " +
7273
QuotedTableNameFromString(schemaName, tableName) +
73-
" (" + e.quotedFields(vcm) + ") VALUES " + valuesStr
74-
75-
return query, e.flattenRowData(vcm), nil
76-
}
77-
78-
// virtualColumnsMap returns a map of given columns (by index) -> whether the column is virtual (i.e. generated).
79-
func (e *RowBatch) virtualColumnsMap() map[int]bool {
80-
res := map[int]bool{}
81-
82-
for i, name := range e.columns {
83-
isVirtual := false
84-
for _, c := range e.table.Columns {
85-
if name == c.Name && c.IsVirtual {
86-
isVirtual = true
87-
break
88-
}
89-
}
90-
91-
res[i] = isVirtual
92-
}
93-
94-
return res
95-
}
96-
97-
// activeColumnCount returns the number of active (non-virtual) columns for this RowBatch.
98-
func (e *RowBatch) activeColumnCount(vcm map[int]bool) int {
99-
if vcm == nil {
100-
return len(e.columns)
101-
}
102-
103-
count := 0
104-
for _, isVirtual := range vcm {
105-
if !isVirtual {
106-
count++
107-
}
108-
}
109-
return count
110-
}
111-
112-
// quotedFields returns a string with comma-separated quoted field names for INSERTs.
113-
func (e *RowBatch) quotedFields(vcm map[int]bool) string {
114-
cols := []string{}
115-
for i, name := range e.columns {
116-
if vcm != nil && vcm[i] {
117-
continue
118-
}
119-
cols = append(cols, name)
120-
}
74+
" (" + strings.Join(filteredColumns, ",") + ") VALUES " + valuesStr
12175

122-
return strings.Join(QuoteFields(cols), ",")
76+
return query, e.flattenRowData(), nil
12377
}
12478

125-
// flattenRowData flattens RowData values into a single array for INSERTs.
126-
func (e *RowBatch) flattenRowData(vcm map[int]bool) []interface{} {
127-
rowSize := e.activeColumnCount(vcm)
128-
flattened := make([]interface{}, rowSize*len(e.values))
79+
func (e *RowBatch) flattenRowData() []interface{} {
80+
flattened := make([]interface{}, 0, len(e.values))
12981

130-
for rowIdx, row := range e.values {
131-
i := 0
82+
for _, row := range e.values {
13283
for colIdx, col := range row {
133-
if vcm != nil && vcm[colIdx] {
84+
if e.table.IsColumnIndexGenerated(colIdx) {
13485
continue
13586
}
136-
flattened[rowIdx*rowSize+i] = col
137-
i++
87+
flattened = append(flattened, col)
13888
}
13989
}
14090

0 commit comments

Comments
 (0)