diff --git a/go/base/context.go b/go/base/context.go index c6ccb800c..087265815 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -252,6 +252,14 @@ type MigrationContext struct { TriggerSuffix string Triggers []mysql.Trigger + // RowFilterWhereClause is an optional WHERE clause to filter rows during migration. + // Only rows matching this condition will be copied to the ghost table. + // This enables using gh-ost for data deletion/purging. + // Example: "created_at >= '2024-01-01'" keeps only rows from 2024 onwards + RowFilterWhereClause string + // RowFilter is the parsed filter for evaluating binlog events + RowFilter *sql.RowFilter + recentBinlogCoordinates mysql.BinlogCoordinates BinlogSyncerMaxReconnectAttempts int diff --git a/go/cmd/gh-ost/main.go b/go/cmd/gh-ost/main.go index fae519680..7bd1cc68b 100644 --- a/go/cmd/gh-ost/main.go +++ b/go/cmd/gh-ost/main.go @@ -69,6 +69,7 @@ func main() { flag.StringVar(&migrationContext.DatabaseName, "database", "", "database name (mandatory)") flag.StringVar(&migrationContext.OriginalTableName, "table", "", "table name (mandatory)") flag.StringVar(&migrationContext.AlterStatement, "alter", "", "alter statement (mandatory)") + flag.StringVar(&migrationContext.RowFilterWhereClause, "where", "", "WHERE clause to filter rows during copy. Only rows matching this condition are kept in the ghost table. Useful for data purging. Example: --where=\"created_at >= '2024-01-01'\" to keep only recent data") flag.BoolVar(&migrationContext.AttemptInstantDDL, "attempt-instant-ddl", false, "Attempt to use instant DDL for this migration first") storageEngine := flag.String("storage-engine", "innodb", "Specify table storage engine (default: 'innodb'). When 'rocksdb': the session transaction isolation level is changed from REPEATABLE_READ to READ_COMMITTED.") diff --git a/go/logic/applier.go b/go/logic/applier.go index 68d11171b..8742e153d 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -880,7 +880,7 @@ func (this *Applier) ApplyIterationInsertQuery() (chunkSize int64, rowsAffected startTime := time.Now() chunkSize = atomic.LoadInt64(&this.migrationContext.ChunkSize) - query, explodedArgs, err := sql.BuildRangeInsertPreparedQuery( + query, explodedArgs, err := sql.BuildRangeInsertPreparedQueryWithFilter( this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, this.migrationContext.GetGhostTableName(), @@ -894,6 +894,7 @@ func (this *Applier) ApplyIterationInsertQuery() (chunkSize int64, rowsAffected this.migrationContext.IsTransactionalTable(), // TODO: Don't hardcode this strings.HasPrefix(this.migrationContext.ApplierMySQLVersion, "8."), + this.migrationContext.RowFilterWhereClause, ) if err != nil { return chunkSize, rowsAffected, duration, err @@ -1449,19 +1450,64 @@ func (this *Applier) updateModifiesUniqueKeyColumns(dmlEvent *binlog.BinlogDMLEv // buildDMLEventQuery creates a query to operate on the ghost table, based on an intercepted binlog // event entry on the original table. func (this *Applier) buildDMLEventQuery(dmlEvent *binlog.BinlogDMLEvent) []*dmlBuildResult { + // Check if we have a row filter for data purging + rowFilter := this.migrationContext.RowFilter + hasFilter := rowFilter != nil && !rowFilter.IsEmpty() + switch dmlEvent.DML { case binlog.DeleteDML: { + // For DELETE: If we have a filter, the row either matched (and was copied) or didn't match + // (and wasn't copied). Either way, we issue the DELETE - it's idempotent. + // However, if the row doesn't match the filter, it was never in the ghost table, + // so we can skip the DELETE to avoid unnecessary work. + if hasFilter { + oldMatches := rowFilter.Matches(dmlEvent.WhereColumnValues.AbstractValues()) + if !oldMatches { + // Row was never copied to ghost table, skip DELETE + return []*dmlBuildResult{} + } + } query, uniqueKeyArgs, err := this.dmlDeleteQueryBuilder.BuildQuery(dmlEvent.WhereColumnValues.AbstractValues()) return []*dmlBuildResult{newDmlBuildResult(query, uniqueKeyArgs, -1, err)} } case binlog.InsertDML: { + // For INSERT: Only insert if the row matches the filter (or no filter) + if hasFilter { + newMatches := rowFilter.Matches(dmlEvent.NewColumnValues.AbstractValues()) + if !newMatches { + // Row doesn't match filter - don't insert (effectively deleted) + return []*dmlBuildResult{} + } + } query, sharedArgs, err := this.dmlInsertQueryBuilder.BuildQuery(dmlEvent.NewColumnValues.AbstractValues()) return []*dmlBuildResult{newDmlBuildResult(query, sharedArgs, 1, err)} } case binlog.UpdateDML: { + // For UPDATE with filter, check if the row's filter status changed + if hasFilter { + oldMatches := rowFilter.Matches(dmlEvent.WhereColumnValues.AbstractValues()) + newMatches := rowFilter.Matches(dmlEvent.NewColumnValues.AbstractValues()) + + if !oldMatches && !newMatches { + // Row never matched filter - no-op + return []*dmlBuildResult{} + } + if oldMatches && !newMatches { + // Row used to match but no longer does - treat as DELETE + query, uniqueKeyArgs, err := this.dmlDeleteQueryBuilder.BuildQuery(dmlEvent.WhereColumnValues.AbstractValues()) + return []*dmlBuildResult{newDmlBuildResult(query, uniqueKeyArgs, -1, err)} + } + if !oldMatches && newMatches { + // Row now matches but didn't before - treat as INSERT + query, sharedArgs, err := this.dmlInsertQueryBuilder.BuildQuery(dmlEvent.NewColumnValues.AbstractValues()) + return []*dmlBuildResult{newDmlBuildResult(query, sharedArgs, 1, err)} + } + // Both old and new match - proceed with normal UPDATE below + } + if _, isModified := this.updateModifiesUniqueKeyColumns(dmlEvent); isModified { results := make([]*dmlBuildResult, 0, 2) dmlEvent.DML = binlog.DeleteDML @@ -1519,6 +1565,14 @@ func (this *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent)) } } + // If all events were filtered out (e.g., by --where filter), nothing to execute + if len(buildResults) == 0 { + if err := tx.Commit(); err != nil { + return err + } + return nil + } + // We batch together the DML queries into multi-statements to minimize network trips. // We have to use the raw driver connection to access the rows affected // for each statement in the multi-statement. diff --git a/go/logic/migrator.go b/go/logic/migrator.go index 7255fc757..b3bbe564c 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -414,6 +414,16 @@ func (this *Migrator) Migrate() (err error) { return err } + // Initialize row filter for data purging if a WHERE clause was provided + if this.migrationContext.RowFilterWhereClause != "" { + rowFilter, err := sql.NewRowFilter(this.migrationContext.RowFilterWhereClause, this.migrationContext.OriginalTableColumns) + if err != nil { + return this.migrationContext.Log.Errorf("Failed to parse --where clause: %+v", err) + } + this.migrationContext.RowFilter = rowFilter + this.migrationContext.Log.Infof("Row filter enabled: only rows matching '%s' will be migrated", this.migrationContext.RowFilterWhereClause) + } + // We can prepare some of the queries on the applier if err := this.applier.prepareQueries(); err != nil { return err diff --git a/go/sql/builder.go b/go/sql/builder.go index 61dd9706f..dfc4ad680 100644 --- a/go/sql/builder.go +++ b/go/sql/builder.go @@ -261,6 +261,12 @@ func BuildRangePreparedComparison(columns *ColumnList, args []interface{}, compa } func BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName string, sharedColumns []string, mappedSharedColumns []string, uniqueKey string, uniqueKeyColumns *ColumnList, rangeStartValues, rangeEndValues []string, rangeStartArgs, rangeEndArgs []interface{}, includeRangeStartValues bool, transactionalTable bool, noWait bool) (result string, explodedArgs []interface{}, err error) { + return BuildRangeInsertQueryWithFilter(databaseName, originalTableName, ghostTableName, sharedColumns, mappedSharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, includeRangeStartValues, transactionalTable, noWait, "") +} + +// BuildRangeInsertQueryWithFilter builds an INSERT...SELECT query with an optional row filter WHERE clause. +// The rowFilterWhereClause parameter allows filtering rows during copy (for data purging). +func BuildRangeInsertQueryWithFilter(databaseName, originalTableName, ghostTableName string, sharedColumns []string, mappedSharedColumns []string, uniqueKey string, uniqueKeyColumns *ColumnList, rangeStartValues, rangeEndValues []string, rangeStartArgs, rangeEndArgs []interface{}, includeRangeStartValues bool, transactionalTable bool, noWait bool, rowFilterWhereClause string) (result string, explodedArgs []interface{}, err error) { if len(sharedColumns) == 0 { return "", explodedArgs, fmt.Errorf("Got 0 shared columns in BuildRangeInsertQuery") } @@ -303,6 +309,13 @@ func BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName strin return "", explodedArgs, err } explodedArgs = append(explodedArgs, rangeExplodedArgs...) + + // Build optional row filter clause for data purging + rowFilterClause := "" + if rowFilterWhereClause != "" { + rowFilterClause = fmt.Sprintf("and (%s)", rowFilterWhereClause) + } + result = fmt.Sprintf(` insert /* gh-ost %s.%s */ ignore into @@ -314,19 +327,24 @@ func BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName strin %s.%s force index (%s) where - (%s and %s) + (%s and %s) %s %s )`, databaseName, originalTableName, databaseName, ghostTableName, mappedSharedColumnsListing, sharedColumnsListing, databaseName, originalTableName, uniqueKey, - rangeStartComparison, rangeEndComparison, transactionalClause) + rangeStartComparison, rangeEndComparison, rowFilterClause, transactionalClause) return result, explodedArgs, nil } func BuildRangeInsertPreparedQuery(databaseName, originalTableName, ghostTableName string, sharedColumns []string, mappedSharedColumns []string, uniqueKey string, uniqueKeyColumns *ColumnList, rangeStartArgs, rangeEndArgs []interface{}, includeRangeStartValues bool, transactionalTable bool, noWait bool) (result string, explodedArgs []interface{}, err error) { + return BuildRangeInsertPreparedQueryWithFilter(databaseName, originalTableName, ghostTableName, sharedColumns, mappedSharedColumns, uniqueKey, uniqueKeyColumns, rangeStartArgs, rangeEndArgs, includeRangeStartValues, transactionalTable, noWait, "") +} + +// BuildRangeInsertPreparedQueryWithFilter builds a prepared INSERT...SELECT query with an optional row filter. +func BuildRangeInsertPreparedQueryWithFilter(databaseName, originalTableName, ghostTableName string, sharedColumns []string, mappedSharedColumns []string, uniqueKey string, uniqueKeyColumns *ColumnList, rangeStartArgs, rangeEndArgs []interface{}, includeRangeStartValues bool, transactionalTable bool, noWait bool, rowFilterWhereClause string) (result string, explodedArgs []interface{}, err error) { rangeStartValues := buildColumnsPreparedValues(uniqueKeyColumns) rangeEndValues := buildColumnsPreparedValues(uniqueKeyColumns) - return BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, mappedSharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, includeRangeStartValues, transactionalTable, noWait) + return BuildRangeInsertQueryWithFilter(databaseName, originalTableName, ghostTableName, sharedColumns, mappedSharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, includeRangeStartValues, transactionalTable, noWait, rowFilterWhereClause) } func BuildUniqueKeyRangeEndPreparedQueryViaOffset(databaseName, tableName string, uniqueKeyColumns *ColumnList, rangeStartArgs, rangeEndArgs []interface{}, chunkSize int64, includeRangeStartValues bool, hint string) (result string, explodedArgs []interface{}, err error) { diff --git a/go/sql/filter.go b/go/sql/filter.go new file mode 100644 index 000000000..ec57092c3 --- /dev/null +++ b/go/sql/filter.go @@ -0,0 +1,501 @@ +/* + Copyright 2025 GitHub Inc. + See https://github.com/github/gh-ost/blob/master/LICENSE +*/ + +package sql + +import ( + "fmt" + "regexp" + "strconv" + "strings" + "time" +) + +// ComparisonOperator represents a comparison operator in a filter condition +type ComparisonOperator string + +const ( + OpEquals ComparisonOperator = "=" + OpNotEquals ComparisonOperator = "!=" + OpNotEqualsAlt ComparisonOperator = "<>" + OpLessThan ComparisonOperator = "<" + OpLessThanOrEquals ComparisonOperator = "<=" + OpGreaterThan ComparisonOperator = ">" + OpGreaterThanOrEquals ComparisonOperator = ">=" + OpIsNull ComparisonOperator = "IS NULL" + OpIsNotNull ComparisonOperator = "IS NOT NULL" + OpLike ComparisonOperator = "LIKE" + OpIn ComparisonOperator = "IN" +) + +// FilterCondition represents a single condition in a WHERE clause +type FilterCondition struct { + Column string + Operator ComparisonOperator + Value interface{} // string, int64, float64, time.Time, []interface{} for IN, or nil for IS NULL +} + +// LogicalOperator represents AND/OR +type LogicalOperator string + +const ( + LogicalAnd LogicalOperator = "AND" + LogicalOr LogicalOperator = "OR" +) + +// RowFilter represents a parsed WHERE clause that can evaluate rows +type RowFilter struct { + WhereClause string + Conditions []FilterCondition + Operators []LogicalOperator // len = len(Conditions) - 1 + columnMap map[string]int // column name -> ordinal position +} + +// NewRowFilter parses a WHERE clause and creates a RowFilter +func NewRowFilter(whereClause string, columns *ColumnList) (*RowFilter, error) { + if whereClause == "" { + return nil, nil + } + + filter := &RowFilter{ + WhereClause: whereClause, + Conditions: []FilterCondition{}, + Operators: []LogicalOperator{}, + columnMap: make(map[string]int), + } + + // Build column name -> ordinal map + if columns != nil { + for i, col := range columns.Columns() { + filter.columnMap[strings.ToLower(col.Name)] = i + } + } + + // Parse the WHERE clause + if err := filter.parse(whereClause); err != nil { + return nil, err + } + + return filter, nil +} + +// parse parses the WHERE clause into conditions +func (f *RowFilter) parse(whereClause string) error { + // Normalize whitespace + whereClause = strings.TrimSpace(whereClause) + if whereClause == "" { + return nil + } + + // Split by AND/OR (simple parsing - doesn't handle nested parentheses) + // This regex captures AND/OR as delimiters while preserving them + splitRegex := regexp.MustCompile(`(?i)\s+(AND|OR)\s+`) + parts := splitRegex.Split(whereClause, -1) + operators := splitRegex.FindAllStringSubmatch(whereClause, -1) + + for i, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + + condition, err := f.parseCondition(part) + if err != nil { + return fmt.Errorf("failed to parse condition '%s': %w", part, err) + } + f.Conditions = append(f.Conditions, condition) + + if i < len(operators) { + op := strings.ToUpper(strings.TrimSpace(operators[i][1])) + if op == "AND" { + f.Operators = append(f.Operators, LogicalAnd) + } else { + f.Operators = append(f.Operators, LogicalOr) + } + } + } + + return nil +} + +// parseCondition parses a single condition like "column >= 'value'" +func (f *RowFilter) parseCondition(condition string) (FilterCondition, error) { + condition = strings.TrimSpace(condition) + + // Remove surrounding parentheses if present + for strings.HasPrefix(condition, "(") && strings.HasSuffix(condition, ")") { + condition = strings.TrimPrefix(condition, "(") + condition = strings.TrimSuffix(condition, ")") + condition = strings.TrimSpace(condition) + } + + // Check for IS NULL / IS NOT NULL + isNullRegex := regexp.MustCompile(`(?i)^(\w+)\s+IS\s+NULL$`) + isNotNullRegex := regexp.MustCompile(`(?i)^(\w+)\s+IS\s+NOT\s+NULL$`) + + if match := isNullRegex.FindStringSubmatch(condition); match != nil { + return FilterCondition{ + Column: strings.ToLower(match[1]), + Operator: OpIsNull, + Value: nil, + }, nil + } + + if match := isNotNullRegex.FindStringSubmatch(condition); match != nil { + return FilterCondition{ + Column: strings.ToLower(match[1]), + Operator: OpIsNotNull, + Value: nil, + }, nil + } + + // Check for IN clause + inRegex := regexp.MustCompile(`(?i)^(\w+)\s+IN\s*\((.+)\)$`) + if match := inRegex.FindStringSubmatch(condition); match != nil { + column := strings.ToLower(match[1]) + valuesStr := match[2] + values, err := f.parseInValues(valuesStr) + if err != nil { + return FilterCondition{}, err + } + return FilterCondition{ + Column: column, + Operator: OpIn, + Value: values, + }, nil + } + + // Standard comparison operators (order matters - check multi-char first) + operators := []struct { + pattern string + op ComparisonOperator + }{ + {"<>", OpNotEqualsAlt}, + {"!=", OpNotEquals}, + {">=", OpGreaterThanOrEquals}, + {"<=", OpLessThanOrEquals}, + {">", OpGreaterThan}, + {"<", OpLessThan}, + {"=", OpEquals}, + } + + for _, opDef := range operators { + idx := strings.Index(condition, opDef.pattern) + if idx > 0 { + column := strings.TrimSpace(condition[:idx]) + valueStr := strings.TrimSpace(condition[idx+len(opDef.pattern):]) + + // Remove backticks from column name + column = strings.Trim(column, "`") + column = strings.ToLower(column) + + value, err := f.parseValue(valueStr) + if err != nil { + return FilterCondition{}, err + } + + return FilterCondition{ + Column: column, + Operator: opDef.op, + Value: value, + }, nil + } + } + + return FilterCondition{}, fmt.Errorf("could not parse condition: %s", condition) +} + +// parseValue parses a value string into an appropriate Go type +func (f *RowFilter) parseValue(valueStr string) (interface{}, error) { + valueStr = strings.TrimSpace(valueStr) + + // String literal (single or double quoted) + if (strings.HasPrefix(valueStr, "'") && strings.HasSuffix(valueStr, "'")) || + (strings.HasPrefix(valueStr, "\"") && strings.HasSuffix(valueStr, "\"")) { + unquoted := valueStr[1 : len(valueStr)-1] + // Try to parse as date/datetime + if t, err := time.Parse("2006-01-02 15:04:05", unquoted); err == nil { + return t, nil + } + if t, err := time.Parse("2006-01-02", unquoted); err == nil { + return t, nil + } + return unquoted, nil + } + + // NULL + if strings.ToUpper(valueStr) == "NULL" { + return nil, nil + } + + // Integer + if i, err := strconv.ParseInt(valueStr, 10, 64); err == nil { + return i, nil + } + + // Float + if f, err := strconv.ParseFloat(valueStr, 64); err == nil { + return f, nil + } + + // Boolean + if strings.ToUpper(valueStr) == "TRUE" { + return true, nil + } + if strings.ToUpper(valueStr) == "FALSE" { + return false, nil + } + + return valueStr, nil +} + +// parseInValues parses the values inside an IN clause +func (f *RowFilter) parseInValues(valuesStr string) ([]interface{}, error) { + // Simple comma split (doesn't handle commas inside strings) + parts := strings.Split(valuesStr, ",") + values := make([]interface{}, 0, len(parts)) + for _, part := range parts { + v, err := f.parseValue(strings.TrimSpace(part)) + if err != nil { + return nil, err + } + values = append(values, v) + } + return values, nil +} + +// SetColumnMap updates the column name to ordinal mapping +func (f *RowFilter) SetColumnMap(columns *ColumnList) { + f.columnMap = make(map[string]int) + if columns != nil { + for i, col := range columns.Columns() { + f.columnMap[strings.ToLower(col.Name)] = i + } + } +} + +// Matches evaluates whether a row (as a slice of values) matches the filter +func (f *RowFilter) Matches(rowValues []interface{}) bool { + if len(f.Conditions) == 0 { + return true + } + + result := f.evaluateCondition(f.Conditions[0], rowValues) + + for i := 1; i < len(f.Conditions); i++ { + condResult := f.evaluateCondition(f.Conditions[i], rowValues) + + if i-1 < len(f.Operators) { + switch f.Operators[i-1] { + case LogicalAnd: + result = result && condResult + case LogicalOr: + result = result || condResult + } + } + } + + return result +} + +// evaluateCondition evaluates a single condition against row values +func (f *RowFilter) evaluateCondition(cond FilterCondition, rowValues []interface{}) bool { + ordinal, exists := f.columnMap[cond.Column] + if !exists || ordinal >= len(rowValues) { + // Column not found - default to not matching for safety + return false + } + + rowValue := rowValues[ordinal] + + switch cond.Operator { + case OpIsNull: + return rowValue == nil + case OpIsNotNull: + return rowValue != nil + case OpIn: + return f.evaluateIn(rowValue, cond.Value.([]interface{})) + default: + return f.compare(rowValue, cond.Value, cond.Operator) + } +} + +// evaluateIn checks if rowValue is in the list of values +func (f *RowFilter) evaluateIn(rowValue interface{}, values []interface{}) bool { + for _, v := range values { + if f.equals(rowValue, v) { + return true + } + } + return false +} + +// equals checks if two values are equal (with type coercion) +func (f *RowFilter) equals(a, b interface{}) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + + // Convert both to strings for comparison (simple but handles most cases) + aStr := fmt.Sprintf("%v", a) + bStr := fmt.Sprintf("%v", b) + return aStr == bStr +} + +// compare compares two values using the given operator +func (f *RowFilter) compare(rowValue, filterValue interface{}, op ComparisonOperator) bool { + if rowValue == nil || filterValue == nil { + if op == OpEquals || op == OpNotEquals || op == OpNotEqualsAlt { + isEqual := (rowValue == nil && filterValue == nil) + if op == OpEquals { + return isEqual + } + return !isEqual + } + return false + } + + // Try to compare as times first + rowTime := f.toTime(rowValue) + filterTime := f.toTime(filterValue) + if rowTime != nil && filterTime != nil { + return f.compareTimes(*rowTime, *filterTime, op) + } + + // Try to compare as numbers + rowNum, rowIsNum := f.toFloat64(rowValue) + filterNum, filterIsNum := f.toFloat64(filterValue) + if rowIsNum && filterIsNum { + return f.compareNumbers(rowNum, filterNum, op) + } + + // Fall back to string comparison + rowStr := fmt.Sprintf("%v", rowValue) + filterStr := fmt.Sprintf("%v", filterValue) + return f.compareStrings(rowStr, filterStr, op) +} + +// toTime attempts to convert a value to time.Time +func (f *RowFilter) toTime(v interface{}) *time.Time { + switch t := v.(type) { + case time.Time: + return &t + case *time.Time: + return t + case string: + if parsed, err := time.Parse("2006-01-02 15:04:05", t); err == nil { + return &parsed + } + if parsed, err := time.Parse("2006-01-02", t); err == nil { + return &parsed + } + } + return nil +} + +// toFloat64 attempts to convert a value to float64 +func (f *RowFilter) toFloat64(v interface{}) (float64, bool) { + switch n := v.(type) { + case int: + return float64(n), true + case int8: + return float64(n), true + case int16: + return float64(n), true + case int32: + return float64(n), true + case int64: + return float64(n), true + case uint: + return float64(n), true + case uint8: + return float64(n), true + case uint16: + return float64(n), true + case uint32: + return float64(n), true + case uint64: + return float64(n), true + case float32: + return float64(n), true + case float64: + return n, true + case string: + if f, err := strconv.ParseFloat(n, 64); err == nil { + return f, true + } + } + return 0, false +} + +// compareTimes compares two times +func (f *RowFilter) compareTimes(a, b time.Time, op ComparisonOperator) bool { + switch op { + case OpEquals: + return a.Equal(b) + case OpNotEquals, OpNotEqualsAlt: + return !a.Equal(b) + case OpLessThan: + return a.Before(b) + case OpLessThanOrEquals: + return a.Before(b) || a.Equal(b) + case OpGreaterThan: + return a.After(b) + case OpGreaterThanOrEquals: + return a.After(b) || a.Equal(b) + } + return false +} + +// compareNumbers compares two numbers +func (f *RowFilter) compareNumbers(a, b float64, op ComparisonOperator) bool { + switch op { + case OpEquals: + return a == b + case OpNotEquals, OpNotEqualsAlt: + return a != b + case OpLessThan: + return a < b + case OpLessThanOrEquals: + return a <= b + case OpGreaterThan: + return a > b + case OpGreaterThanOrEquals: + return a >= b + } + return false +} + +// compareStrings compares two strings +func (f *RowFilter) compareStrings(a, b string, op ComparisonOperator) bool { + switch op { + case OpEquals: + return a == b + case OpNotEquals, OpNotEqualsAlt: + return a != b + case OpLessThan: + return a < b + case OpLessThanOrEquals: + return a <= b + case OpGreaterThan: + return a > b + case OpGreaterThanOrEquals: + return a >= b + } + return false +} + +// GetWhereClause returns the original WHERE clause +func (f *RowFilter) GetWhereClause() string { + return f.WhereClause +} + +// IsEmpty returns true if the filter has no conditions +func (f *RowFilter) IsEmpty() bool { + return len(f.Conditions) == 0 +} diff --git a/go/sql/filter_test.go b/go/sql/filter_test.go new file mode 100644 index 000000000..fd4579261 --- /dev/null +++ b/go/sql/filter_test.go @@ -0,0 +1,246 @@ +/* + Copyright 2025 GitHub Inc. + See https://github.com/github/gh-ost/blob/master/LICENSE +*/ + +package sql + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestNewRowFilter_Empty(t *testing.T) { + filter, err := NewRowFilter("", nil) + require.NoError(t, err) + require.Nil(t, filter) +} + +func TestNewRowFilter_SimpleEquals(t *testing.T) { + columns := NewColumnList([]string{"id", "name", "status"}) + filter, err := NewRowFilter("status = 'active'", columns) + require.NoError(t, err) + require.NotNil(t, filter) + require.Len(t, filter.Conditions, 1) + require.Equal(t, "status", filter.Conditions[0].Column) + require.Equal(t, OpEquals, filter.Conditions[0].Operator) + require.Equal(t, "active", filter.Conditions[0].Value) +} + +func TestNewRowFilter_NumericComparison(t *testing.T) { + columns := NewColumnList([]string{"id", "age", "score"}) + filter, err := NewRowFilter("age >= 18", columns) + require.NoError(t, err) + require.NotNil(t, filter) + require.Len(t, filter.Conditions, 1) + require.Equal(t, "age", filter.Conditions[0].Column) + require.Equal(t, OpGreaterThanOrEquals, filter.Conditions[0].Operator) + require.Equal(t, int64(18), filter.Conditions[0].Value) +} + +func TestNewRowFilter_DateComparison(t *testing.T) { + columns := NewColumnList([]string{"id", "created_at"}) + filter, err := NewRowFilter("created_at >= '2024-01-01'", columns) + require.NoError(t, err) + require.NotNil(t, filter) + require.Len(t, filter.Conditions, 1) + require.Equal(t, "created_at", filter.Conditions[0].Column) + require.Equal(t, OpGreaterThanOrEquals, filter.Conditions[0].Operator) + + expectedDate, _ := time.Parse("2006-01-02", "2024-01-01") + require.Equal(t, expectedDate, filter.Conditions[0].Value) +} + +func TestNewRowFilter_AndConditions(t *testing.T) { + columns := NewColumnList([]string{"id", "status", "age"}) + filter, err := NewRowFilter("status = 'active' AND age >= 18", columns) + require.NoError(t, err) + require.NotNil(t, filter) + require.Len(t, filter.Conditions, 2) + require.Len(t, filter.Operators, 1) + require.Equal(t, LogicalAnd, filter.Operators[0]) +} + +func TestNewRowFilter_OrConditions(t *testing.T) { + columns := NewColumnList([]string{"id", "status"}) + filter, err := NewRowFilter("status = 'active' OR status = 'pending'", columns) + require.NoError(t, err) + require.NotNil(t, filter) + require.Len(t, filter.Conditions, 2) + require.Len(t, filter.Operators, 1) + require.Equal(t, LogicalOr, filter.Operators[0]) +} + +func TestNewRowFilter_IsNull(t *testing.T) { + columns := NewColumnList([]string{"id", "deleted_at"}) + filter, err := NewRowFilter("deleted_at IS NULL", columns) + require.NoError(t, err) + require.NotNil(t, filter) + require.Len(t, filter.Conditions, 1) + require.Equal(t, "deleted_at", filter.Conditions[0].Column) + require.Equal(t, OpIsNull, filter.Conditions[0].Operator) +} + +func TestNewRowFilter_IsNotNull(t *testing.T) { + columns := NewColumnList([]string{"id", "email"}) + filter, err := NewRowFilter("email IS NOT NULL", columns) + require.NoError(t, err) + require.NotNil(t, filter) + require.Len(t, filter.Conditions, 1) + require.Equal(t, "email", filter.Conditions[0].Column) + require.Equal(t, OpIsNotNull, filter.Conditions[0].Operator) +} + +func TestRowFilter_Matches_SimpleEquals(t *testing.T) { + columns := NewColumnList([]string{"id", "status"}) + filter, err := NewRowFilter("status = 'active'", columns) + require.NoError(t, err) + + // Row matches + require.True(t, filter.Matches([]interface{}{1, "active"})) + + // Row doesn't match + require.False(t, filter.Matches([]interface{}{1, "inactive"})) +} + +func TestRowFilter_Matches_NumericGreaterThan(t *testing.T) { + columns := NewColumnList([]string{"id", "age"}) + filter, err := NewRowFilter("age >= 18", columns) + require.NoError(t, err) + + require.True(t, filter.Matches([]interface{}{1, int64(18)})) + require.True(t, filter.Matches([]interface{}{1, int64(25)})) + require.False(t, filter.Matches([]interface{}{1, int64(17)})) +} + +func TestRowFilter_Matches_DateComparison(t *testing.T) { + columns := NewColumnList([]string{"id", "created_at"}) + filter, err := NewRowFilter("created_at >= '2024-01-01'", columns) + require.NoError(t, err) + + date2024, _ := time.Parse("2006-01-02", "2024-06-15") + date2023, _ := time.Parse("2006-01-02", "2023-06-15") + + require.True(t, filter.Matches([]interface{}{1, date2024})) + require.False(t, filter.Matches([]interface{}{1, date2023})) +} + +func TestRowFilter_Matches_AndConditions(t *testing.T) { + columns := NewColumnList([]string{"id", "status", "age"}) + filter, err := NewRowFilter("status = 'active' AND age >= 18", columns) + require.NoError(t, err) + + // Both conditions match + require.True(t, filter.Matches([]interface{}{1, "active", int64(25)})) + + // Only status matches + require.False(t, filter.Matches([]interface{}{1, "active", int64(15)})) + + // Only age matches + require.False(t, filter.Matches([]interface{}{1, "inactive", int64(25)})) + + // Neither matches + require.False(t, filter.Matches([]interface{}{1, "inactive", int64(15)})) +} + +func TestRowFilter_Matches_OrConditions(t *testing.T) { + columns := NewColumnList([]string{"id", "status"}) + filter, err := NewRowFilter("status = 'active' OR status = 'pending'", columns) + require.NoError(t, err) + + require.True(t, filter.Matches([]interface{}{1, "active"})) + require.True(t, filter.Matches([]interface{}{1, "pending"})) + require.False(t, filter.Matches([]interface{}{1, "deleted"})) +} + +func TestRowFilter_Matches_IsNull(t *testing.T) { + columns := NewColumnList([]string{"id", "deleted_at"}) + filter, err := NewRowFilter("deleted_at IS NULL", columns) + require.NoError(t, err) + + require.True(t, filter.Matches([]interface{}{1, nil})) + require.False(t, filter.Matches([]interface{}{1, time.Now()})) +} + +func TestRowFilter_Matches_IsNotNull(t *testing.T) { + columns := NewColumnList([]string{"id", "email"}) + filter, err := NewRowFilter("email IS NOT NULL", columns) + require.NoError(t, err) + + require.True(t, filter.Matches([]interface{}{1, "test@example.com"})) + require.False(t, filter.Matches([]interface{}{1, nil})) +} + +func TestRowFilter_Matches_NotEquals(t *testing.T) { + columns := NewColumnList([]string{"id", "status"}) + filter, err := NewRowFilter("status != 'deleted'", columns) + require.NoError(t, err) + + require.True(t, filter.Matches([]interface{}{1, "active"})) + require.False(t, filter.Matches([]interface{}{1, "deleted"})) +} + +func TestRowFilter_Matches_LessThan(t *testing.T) { + columns := NewColumnList([]string{"id", "priority"}) + filter, err := NewRowFilter("priority < 5", columns) + require.NoError(t, err) + + require.True(t, filter.Matches([]interface{}{1, int64(3)})) + require.False(t, filter.Matches([]interface{}{1, int64(5)})) + require.False(t, filter.Matches([]interface{}{1, int64(7)})) +} + +func TestRowFilter_IsEmpty(t *testing.T) { + columns := NewColumnList([]string{"id"}) + + filter, _ := NewRowFilter("", columns) + require.Nil(t, filter) + + filter2, _ := NewRowFilter("id = 1", columns) + require.False(t, filter2.IsEmpty()) +} + +func TestRowFilter_Matches_UnknownColumn(t *testing.T) { + columns := NewColumnList([]string{"id", "name"}) + filter, err := NewRowFilter("unknown_column = 'value'", columns) + require.NoError(t, err) + + // Unknown column should result in no match (safe default) + require.False(t, filter.Matches([]interface{}{1, "test"})) +} + +func TestBuildRangeInsertQueryWithFilter(t *testing.T) { + databaseName := "mydb" + originalTableName := "tbl" + ghostTableName := "ghost" + sharedColumns := []string{"id", "name", "created_at"} + uniqueKey := "PRIMARY" + uniqueKeyColumns := NewColumnList([]string{"id"}) + rangeStartArgs := []interface{}{1} + rangeEndArgs := []interface{}{100} + + // Test with filter + query, _, err := BuildRangeInsertPreparedQueryWithFilter( + databaseName, originalTableName, ghostTableName, + sharedColumns, sharedColumns, + uniqueKey, uniqueKeyColumns, + rangeStartArgs, rangeEndArgs, + true, true, false, + "created_at >= '2024-01-01'", + ) + require.NoError(t, err) + require.Contains(t, query, "and (created_at >= '2024-01-01')") + + // Test without filter (backward compatibility) + query2, _, err := BuildRangeInsertPreparedQuery( + databaseName, originalTableName, ghostTableName, + sharedColumns, sharedColumns, + uniqueKey, uniqueKeyColumns, + rangeStartArgs, rangeEndArgs, + true, true, false, + ) + require.NoError(t, err) + require.NotContains(t, query2, "and (created_at") +}