diff --git a/codegen/go.mod b/codegen/go.mod index 3a7a506ed..ad8c19c9b 100644 --- a/codegen/go.mod +++ b/codegen/go.mod @@ -5,11 +5,11 @@ go 1.24.0 toolchain go1.25.0 require ( - golang.org/x/text v0.32.0 - golang.org/x/tools v0.40.0 + golang.org/x/text v0.34.0 + golang.org/x/tools v0.42.0 ) require ( - golang.org/x/mod v0.31.0 // indirect + golang.org/x/mod v0.33.0 // indirect golang.org/x/sync v0.19.0 // indirect ) diff --git a/codegen/go.sum b/codegen/go.sum index efb77b22c..791de3672 100644 --- a/codegen/go.sum +++ b/codegen/go.sum @@ -1,10 +1,10 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI= -golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg= +golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= +golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= -golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= -golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= -golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA= -golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc= +golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= +golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= +golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= +golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= diff --git a/docs/doc-site/project/maintainers/ROADMAP.md b/docs/doc-site/project/maintainers/ROADMAP.md index 88d9ccfad..0364a6f60 100644 --- a/docs/doc-site/project/maintainers/ROADMAP.md +++ b/docs/doc-site/project/maintainers/ROADMAP.md @@ -58,7 +58,7 @@ timeline 1. [x] Jan 2026: all go-openapi projects adopts the forked testify 2. [ ] Feb 2026: all go-openapi projects transition to generics -3. [ ] Mar 2026: go-swagger transitions to the forked testify +3. [x] Mar 2026: go-swagger transitions to the forked testify ### What won't come anytime soon diff --git a/docs/doc-site/usage/MIGRATION.md b/docs/doc-site/usage/MIGRATION.md index b6bd8e020..d97f95a02 100644 --- a/docs/doc-site/usage/MIGRATION.md +++ b/docs/doc-site/usage/MIGRATION.md @@ -6,18 +6,127 @@ weight: 20 ## Migration Guide from stretchr/testify v1 -### 1. Update Import Path +This guide covers migrating from `stretchr/testify` to `go-openapi/testify/v2`. +You can use the [automated migration tool](#automated-migration-tool) or migrate [manually](#manual-migration). + +### Automated Migration Tool + +`migrate-testify` automates both the import migration (pass 1) and the generic +upgrade (pass 2). It uses `go/packages` and `go/types` for type-checked, +semantics-preserving transformations. + +#### Installation + +```bash +go install github.com/go-openapi/testify/hack/migrate-testify/v2@latest +``` + +This installs the `migrate-testify` binary into your `$GOBIN`. + +#### Quick Start + +```bash +# Run both passes on the current directory (preview first, then apply) +migrate-testify --all --dry-run . +migrate-testify --all . + +# Or run each pass separately +migrate-testify --migrate . +migrate-testify --upgrade-generics . +``` + +#### Pass 1: Import Migration (`--migrate`) + +Rewrites `stretchr/testify` imports to `go-openapi/testify/v2`: + +```bash +# Dry-run to preview changes +migrate-testify --migrate --dry-run . + +# Apply changes +migrate-testify --migrate . +``` + +This pass handles: +- Import path rewriting (`assert`, `require`, root package) +- Function renames (`EventuallyWithT` to `EventuallyWith`, `NoDirExists` to `DirNotExists`, etc.) +- Type replacement (`PanicTestFunc` to `func()`) +- YAML enable import injection (adds `_ "github.com/go-openapi/testify/v2/enable/yaml"` when `YAMLEq` is used) +- Incompatible import detection (`mock`, `suite`, `http` packages emit warnings with guidance) +- `go.mod` update (drops `stretchr/testify`, adds `go-openapi/testify/v2`) + +#### Pass 2: Generic Upgrade (`--upgrade-generics`) + +Upgrades reflection-based assertions to generic variants where types are statically +resolvable and the semantics are preserved: + +```bash +# Dry-run to preview changes +migrate-testify --upgrade-generics --dry-run . + +# Apply changes +migrate-testify --upgrade-generics . +``` + +The tool is conservative: it only upgrades when: +- Argument types are statically known (no `any`, no `interface{}`) +- Types satisfy the required constraint (`comparable`, `Ordered`, `Text`, etc.) +- For `Equal`/`NotEqual`: types are "deeply comparable" (no pointers or structs with pointer fields) +- For `Contains`: the container type disambiguates to `StringContainsT`, `SliceContainsT`, or `MapContainsT` +- `IsType` is flagged for manual review (argument count changes) + +Assertions that cannot be safely upgraded are tracked and reported in the summary with +a specific reason (e.g., "pointer type", "interface{}/any", "type mismatch"). +Use `--verbose` to see the file and line of each skipped assertion. + +#### Reference + +``` +Usage: migrate-testify [flags] [directory] + +Migrate stretchr/testify to go-openapi/testify/v2 and upgrade to generic assertions. + +Flags: + -all Run both passes sequentially + -dry-run Show diffs without modifying files + -migrate Run pass 1: stretchr/testify -> go-openapi/testify/v2 + -upgrade-generics Run pass 2: reflection -> generic assertions + -verbose Print detailed transformation info + -skip-gomod Skip go.mod changes + -skip-vendor Skip vendor/ directory (default true) + -version string Target testify version (default "v2.3.0") + +At least one of --migrate, --upgrade-generics, or --all is required. + +Mono-repo support: + Pass 1 walks the filesystem and works across module boundaries. + Pass 2 requires type information and uses go/packages to load code. + For multi-module repos, a go.work file must be present so that pass 2 + can load all workspace modules. Create one with: + go work init . ./sub/module1 ./sub/module2 ... + +Post-migration checklist: + - Run your linter: the migration may surface pre-existing unchecked linting issues. + - Run your test suite to verify all tests still pass. +``` + +--- + +### Manual Migration + +#### 1. Update Import Paths ```go // Old -```go -import "github.com/stretchr/testify/v2" +import "github.com/stretchr/testify/assert" +import "github.com/stretchr/testify/require" // New -import "github.com/go-openapi/testify/v2" +import "github.com/go-openapi/testify/v2/assert" +import "github.com/go-openapi/testify/v2/require" ``` -### 2. Optional: Enable YAML Support +#### 2. Optional: Enable YAML Support If you use `YAMLEq` assertions: this feature is now opt-in. @@ -27,7 +136,7 @@ import _ "github.com/go-openapi/testify/enable/yaml/v2" Without this import, YAML assertions will panic with a helpful error message. -### 3. Optional: Enable Colorized Output +#### 3. Optional: Enable Colorized Output ```go import _ "github.com/go-openapi/testify/enable/colors/v2" @@ -43,11 +152,11 @@ go test -v -testify.colorized -testify.theme=light . ![Colorized Test](colorized.png) -### 4. Optional: Adopt Generic Assertions +#### 4. Optional: Adopt Generic Assertions For better type safety and performance, consider migrating to generic assertion variants. This is entirely optional—reflection-based assertions continue to work as before. -#### Step 1: Identify Generic-Capable Assertions +##### Identify Generic-Capable Assertions Look for these common assertions in your tests: @@ -75,8 +184,6 @@ assert.IsDecreasing → assert.IsDecreasingT assert.IsType(t, User{}, v) → assert.IsOfTypeT[User](t, v) // No dummy value! ``` -#### Step 2: Add Type Suffix - Simply add `T` to the function name. The compiler will check types automatically: ```go @@ -89,14 +196,14 @@ assert.EqualT(t, expected, actual) assert.ElementsMatchT(t, slice1, slice2) ``` -#### Step 3: Fix Type Mismatches +##### Fix Type Mismatches The compiler will now catch type errors. This is a feature—it reveals bugs: ```go // Compiler catches this assert.EqualT(t, int64(42), int32(42)) -// ❌ Error: mismatched types int64 and int32 +// Error: mismatched types int64 and int32 // Fix: Use same type assert.EqualT(t, int64(42), int64(actual)) @@ -105,7 +212,29 @@ assert.EqualT(t, int64(42), int64(actual)) assert.Equal(t, int64(42), int32(42)) // Still works ``` -#### Benefits of Migration +##### Pointer Semantics: When NOT to Upgrade + +Generic assertions use Go's `==` operator, while reflection-based assertions use `reflect.DeepEqual`. +For most types these are equivalent, but **they differ for pointers and structs containing pointers**: + +```go +a := &MyStruct{Name: "alice"} +b := &MyStruct{Name: "alice"} + +assert.Equal(t, a, b) // PASSES (reflect.DeepEqual compares pointed-to values) +assert.EqualT(t, a, b) // FAILS (== compares pointer addresses) +``` + +**Do not upgrade to generic variants when:** +- Arguments are pointer types (`*T`) — `EqualT` compares addresses, not values +- Arguments are structs with pointer fields — `==` compares field addresses, `DeepEqual` compares field values +- You intentionally rely on cross-type comparison (`int64` vs `int32`) + +The automated migration tool handles this automatically by only upgrading +assertions where the argument types are "deeply comparable" — types where `==` and +`reflect.DeepEqual` produce the same result. + +##### Benefits of Generic Assertions - **Compile-time type safety**: Catch errors when writing tests - **Performance**: 1.2x to 81x faster (see [Benchmarks](../project/maintainers/BENCHMARKS.md)) @@ -114,33 +243,62 @@ assert.Equal(t, int64(42), int32(42)) // Still works See the [Generics Guide](./GENERICS.md) for detailed usage patterns and best practices. -### 5. Remove Suite/Mock Usage +#### 5. Remove Suite/Mock Usage -Replace testify mocks with: +Replace testify mocks with: - [mockery](https://github.com/vektra/mockery) for mocking Replace testify suites with: - Standard Go subtests for test organization - or wait until we reintroduce this feature (possible, but not certain) -### 6. Remove use of the `testify/http` package +#### 6. Replace `go.uber.org/goleak` with `NoGoRoutineLeak` + +If you use `go.uber.org/goleak` to detect goroutine leaks in tests, consider replacing it +with `assert.NoGoRoutineLeak` (or `require.NoGoRoutineLeak`), which is built into testify v2. + +```go +// Before (with goleak) +import "go.uber.org/goleak" + +func TestNoLeak(t *testing.T) { + defer goleak.VerifyNone(t) + // ... test code ... +} + +// After (with testify v2) +import "github.com/go-openapi/testify/v2/assert" + +func TestNoLeak(t *testing.T) { + assert.NoGoRoutineLeak(t, func() { + // ... test code ... + }) +} +``` + +This removes the `go.uber.org/goleak` dependency. This step is not automated by the +migration tool. + +#### 7. Remove use of the `testify/http` package If you were still using the deprecated package `github.com/stretchr/testitfy/http`, you'll need to replace it by the standard `net/http/httptest` package. We won't reintroduce this package ever. +--- + ## Breaking Changes Summary ### Removed Packages -- ❌ `suite` - Use standard Go subtests -- ❌ `mock` - Use [mockery](https://github.com/vektra/mockery) -- ❌ `http` - May be reintroduced later +- `suite` - Use standard Go subtests +- `mock` - Use [mockery](https://github.com/vektra/mockery) +- `http` - May be reintroduced later ### Removed Functions and Types -- ❌ All deprecated functions from v1 removed -- ❌ Removed extraneous "helper" types: `PanicTestFunc` (`func()`) +- All deprecated functions from v1 removed +- Removed extraneous "helper" types: `PanicTestFunc` (`func()`) ### Behavior Changes @@ -156,4 +314,3 @@ Make sure to check the [behavior changes](./CHANGES.md) as we have fixed a few q - [Generics Guide](./GENERICS.md) - Learn about the 38 new type-safe generic assertions - [Usage Guide](./USAGE.md) - API conventions and how to navigate the documentation - [Tutorial](./TUTORIAL.md) - Best practices for writing tests with testify v2 - diff --git a/go.work b/go.work index ba6bd80e3..8db6c6486 100644 --- a/go.work +++ b/go.work @@ -3,6 +3,7 @@ use ( ./codegen ./enable/colors ./enable/yaml + ./hack/migrate-testify ./internal/testintegration ) diff --git a/go.work.sum b/go.work.sum index 8b205816e..d6f260c35 100644 --- a/go.work.sum +++ b/go.work.sum @@ -1,9 +1,21 @@ github.com/go-openapi/testify/v2 v2.0.1/go.mod h1:HCPmvFFnheKK2BuwSA0TbbdxJ3I16pjwMkYkP4Ywn54= github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= +golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU= golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= +golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= +golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM= golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/telemetry v0.0.0-20251203150158-8fff8a5912fc h1:bH6xUXay0AIFMElXG2rQ4uiE+7ncwtiOdPfYK1NK2XA= golang.org/x/telemetry v0.0.0-20251203150158-8fff8a5912fc/go.mod h1:hKdjCMrbv9skySur+Nek8Hd0uJ0GuxJIoIX2payrIdQ= +golang.org/x/telemetry v0.0.0-20260109210033-bd525da824e2/go.mod h1:b7fPSJ0pKZ3ccUh8gnTONJxhn3c/PS6tyzQvyqw4iA8= +golang.org/x/telemetry v0.0.0-20260209163413-e7419c687ee4/go.mod h1:g5NllXBEermZrmR51cJDQxmJUHUOfRAaNyWBM+R+548= +golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= +golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= +golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= +golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ= diff --git a/hack/migrate-testify/.gitignore b/hack/migrate-testify/.gitignore new file mode 100644 index 000000000..a4f000e3a --- /dev/null +++ b/hack/migrate-testify/.gitignore @@ -0,0 +1 @@ +migrate-testify diff --git a/hack/migrate-testify/constraints.go b/hack/migrate-testify/constraints.go new file mode 100644 index 000000000..26c16ca42 --- /dev/null +++ b/hack/migrate-testify/constraints.go @@ -0,0 +1,269 @@ +// SPDX-FileCopyrightText: Copyright 2025 go-swagger maintainers +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "go/types" +) + +// isComparable reports whether typ satisfies the comparable constraint. +func isComparable(typ types.Type) bool { + return types.Comparable(typ) +} + +// isDeepComparable reports whether typ is comparable AND the == operator +// has the same semantics as reflect.DeepEqual. This is false for: +// - Pointer types (== compares addresses, DeepEqual compares targets) +// - Struct types containing pointer or interface fields +// - Interface types +// - Array types of non-deep-comparable elements +// +// This is critical for safe Equal→EqualT upgrades: EqualT uses == while +// Equal uses reflect.DeepEqual. +func isDeepComparable(typ types.Type) bool { + if !types.Comparable(typ) { + return false + } + return isDeepComparableUnderlying(typ, make(map[types.Type]bool)) +} + +func isDeepComparableUnderlying(typ types.Type, seen map[types.Type]bool) bool { + // Prevent infinite recursion on recursive types. + if seen[typ] { + return true + } + seen[typ] = true + + under := typ.Underlying() + + switch t := under.(type) { + case *types.Basic: + return true + case *types.Pointer: + return false + case *types.Interface: + return false + case *types.Struct: + for field := range t.Fields() { + if !isDeepComparableUnderlying(field.Type(), seen) { + return false + } + } + return true + case *types.Array: + return isDeepComparableUnderlying(t.Elem(), seen) + default: + return false + } +} + +// isOrdered reports whether typ satisfies the Ordered constraint +// (cmp.Ordered | []byte | time.Time). +func isOrdered(typ types.Type) bool { + // Check for time.Time (struct, not ~struct, so check named type). + if isTimeTime(typ) { + return true + } + + // Check for []byte. + if isByteSlice(typ) { + return true + } + + // Check if the underlying type is in cmp.Ordered (string, int*, uint*, float*). + return isCmpOrdered(typ) +} + +// isCmpOrdered checks if a type satisfies cmp.Ordered (basic ordered types). +func isCmpOrdered(typ types.Type) bool { + under := typ.Underlying() + basic, ok := under.(*types.Basic) + if !ok { + return false + } + switch basic.Kind() { + case types.Int, types.Int8, types.Int16, types.Int32, types.Int64, + types.Uint, types.Uint8, types.Uint16, types.Uint32, types.Uint64, + types.Uintptr, + types.Float32, types.Float64, + types.String: + return true + default: + return false + } +} + +// isText reports whether typ satisfies the Text constraint (~string | ~[]byte). +func isText(typ types.Type) bool { + under := typ.Underlying() + + // ~string + if basic, ok := under.(*types.Basic); ok && basic.Kind() == types.String { + return true + } + + // ~[]byte + return isByteSlice(typ) +} + +// isSignedNumeric reports whether typ satisfies the SignedNumeric constraint. +func isSignedNumeric(typ types.Type) bool { + under := typ.Underlying() + basic, ok := under.(*types.Basic) + if !ok { + return false + } + switch basic.Kind() { + case types.Int, types.Int8, types.Int16, types.Int32, types.Int64, + types.Float32, types.Float64: + return true + default: + return false + } +} + +// isMeasurable reports whether typ satisfies the Measurable constraint +// (SignedNumeric | UnsignedNumeric | ~float32 | ~float64). +func isMeasurable(typ types.Type) bool { + under := typ.Underlying() + basic, ok := under.(*types.Basic) + if !ok { + return false + } + switch basic.Kind() { + case types.Int, types.Int8, types.Int16, types.Int32, types.Int64, + types.Uint, types.Uint8, types.Uint16, types.Uint32, types.Uint64, + types.Float32, types.Float64: + return true + default: + return false + } +} + +// isBoolean reports whether typ satisfies the Boolean constraint (~bool). +func isBoolean(typ types.Type) bool { + under := typ.Underlying() + basic, ok := under.(*types.Basic) + if !ok { + return false + } + return basic.Kind() == types.Bool +} + +// isRegExp reports whether typ satisfies the RegExp constraint +// (Text | *regexp.Regexp). +func isRegExp(typ types.Type) bool { + if isText(typ) { + return true + } + return isRegexpPointer(typ) +} + +// isPointerType reports whether typ is a pointer type. Returns the element type. +func isPointerType(typ types.Type) (elem types.Type, ok bool) { + ptr, ok := typ.Underlying().(*types.Pointer) + if !ok { + return nil, false + } + return ptr.Elem(), true +} + +// isSliceType reports whether typ is a slice type. Returns the element type. +func isSliceType(typ types.Type) (elem types.Type, ok bool) { + sl, ok := typ.Underlying().(*types.Slice) + if !ok { + return nil, false + } + return sl.Elem(), true +} + +// isMapType reports whether typ is a map type. Returns key and value types. +func isMapType(typ types.Type) (key, val types.Type, ok bool) { + m, ok := typ.Underlying().(*types.Map) + if !ok { + return nil, nil, false + } + return m.Key(), m.Elem(), true +} + +// sameType reports whether two types are identical. +func sameType(a, b types.Type) bool { + return types.Identical(a, b) +} + +// isAnyOrInterface reports whether typ is interface{} or any, which means +// we cannot determine the concrete type statically. +func isAnyOrInterface(typ types.Type) bool { + if typ == nil { + return true + } + iface, ok := typ.Underlying().(*types.Interface) + if !ok { + return false + } + return iface.Empty() +} + +// satisfiesConstraint checks whether a type satisfies the given constraint kind. +func satisfiesConstraint(typ types.Type, c constraintKind) bool { + if isAnyOrInterface(typ) { + return false + } + switch c { + case constraintComparable: + return isComparable(typ) + case constraintDeepComparable: + return isDeepComparable(typ) + case constraintOrdered: + return isOrdered(typ) + case constraintText: + return isText(typ) + case constraintSignedNumeric: + return isSignedNumeric(typ) + case constraintMeasurable: + return isMeasurable(typ) + case constraintPointer: + _, ok := isPointerType(typ) + return ok + case constraintBoolean: + return isBoolean(typ) + case constraintRegExp: + return isRegExp(typ) + default: + return false + } +} + +// Helper functions + +func isByteSlice(typ types.Type) bool { + sl, ok := typ.Underlying().(*types.Slice) + if !ok { + return false + } + basic, ok := sl.Elem().(*types.Basic) + return ok && basic.Kind() == types.Byte +} + +func isTimeTime(typ types.Type) bool { + named, ok := typ.(*types.Named) + if !ok { + return false + } + obj := named.Obj() + return obj.Name() == "Time" && obj.Pkg() != nil && obj.Pkg().Path() == "time" +} + +func isRegexpPointer(typ types.Type) bool { + ptr, ok := typ.Underlying().(*types.Pointer) + if !ok { + return false + } + named, ok := ptr.Elem().(*types.Named) + if !ok { + return false + } + obj := named.Obj() + return obj.Name() == "Regexp" && obj.Pkg() != nil && obj.Pkg().Path() == "regexp" +} diff --git a/hack/migrate-testify/constraints_test.go b/hack/migrate-testify/constraints_test.go new file mode 100644 index 000000000..006da7db9 --- /dev/null +++ b/hack/migrate-testify/constraints_test.go @@ -0,0 +1,200 @@ +// SPDX-FileCopyrightText: Copyright 2025 go-swagger maintainers +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "go/token" + "go/types" + "iter" + "slices" + "testing" +) + +type constraintTestCase struct { + name string + typ types.Type + constraint constraintKind + expected bool +} + +func constraintTestCases() iter.Seq[constraintTestCase] { + // Create some test types. + intType := types.Typ[types.Int] + int64Type := types.Typ[types.Int64] + float64Type := types.Typ[types.Float64] + stringType := types.Typ[types.String] + boolType := types.Typ[types.Bool] + uint8Type := types.Typ[types.Uint8] + complex128Type := types.Typ[types.Complex128] + + // []byte + byteSlice := types.NewSlice(types.Typ[types.Byte]) + + // *int + intPtr := types.NewPointer(intType) + + // interface{} + emptyIface := types.NewInterfaceType(nil, nil) + + return slices.Values([]constraintTestCase{ + // comparable + {name: "int is comparable", typ: intType, constraint: constraintComparable, expected: true}, + {name: "string is comparable", typ: stringType, constraint: constraintComparable, expected: true}, + {name: "bool is comparable", typ: boolType, constraint: constraintComparable, expected: true}, + + // ordered + {name: "int is ordered", typ: intType, constraint: constraintOrdered, expected: true}, + {name: "float64 is ordered", typ: float64Type, constraint: constraintOrdered, expected: true}, + {name: "string is ordered", typ: stringType, constraint: constraintOrdered, expected: true}, + {name: "[]byte is ordered", typ: byteSlice, constraint: constraintOrdered, expected: true}, + {name: "bool is NOT ordered", typ: boolType, constraint: constraintOrdered, expected: false}, + + // text + {name: "string is text", typ: stringType, constraint: constraintText, expected: true}, + {name: "[]byte is text", typ: byteSlice, constraint: constraintText, expected: true}, + {name: "int is NOT text", typ: intType, constraint: constraintText, expected: false}, + + // signedNumeric + {name: "int is signedNumeric", typ: intType, constraint: constraintSignedNumeric, expected: true}, + {name: "int64 is signedNumeric", typ: int64Type, constraint: constraintSignedNumeric, expected: true}, + {name: "float64 is signedNumeric", typ: float64Type, constraint: constraintSignedNumeric, expected: true}, + {name: "uint8 is NOT signedNumeric", typ: uint8Type, constraint: constraintSignedNumeric, expected: false}, + + // measurable + {name: "int is measurable", typ: intType, constraint: constraintMeasurable, expected: true}, + {name: "uint8 is measurable", typ: uint8Type, constraint: constraintMeasurable, expected: true}, + {name: "float64 is measurable", typ: float64Type, constraint: constraintMeasurable, expected: true}, + {name: "complex128 is NOT measurable", typ: complex128Type, constraint: constraintMeasurable, expected: false}, + {name: "string is NOT measurable", typ: stringType, constraint: constraintMeasurable, expected: false}, + + // boolean + {name: "bool is boolean", typ: boolType, constraint: constraintBoolean, expected: true}, + {name: "int is NOT boolean", typ: intType, constraint: constraintBoolean, expected: false}, + + // pointer + {name: "*int is pointer", typ: intPtr, constraint: constraintPointer, expected: true}, + {name: "int is NOT pointer", typ: intType, constraint: constraintPointer, expected: false}, + + // deepComparable — same as comparable but excludes pointers and structs with pointers + {name: "int is deepComparable", typ: intType, constraint: constraintDeepComparable, expected: true}, + {name: "string is deepComparable", typ: stringType, constraint: constraintDeepComparable, expected: true}, + {name: "bool is deepComparable", typ: boolType, constraint: constraintDeepComparable, expected: true}, + {name: "*int is NOT deepComparable", typ: intPtr, constraint: constraintDeepComparable, expected: false}, + + // any/interface{} should not satisfy anything + {name: "interface{} is NOT comparable", typ: emptyIface, constraint: constraintComparable, expected: false}, + {name: "interface{} is NOT ordered", typ: emptyIface, constraint: constraintOrdered, expected: false}, + {name: "interface{} is NOT deepComparable", typ: emptyIface, constraint: constraintDeepComparable, expected: false}, + }) +} + +func TestSatisfiesConstraint(t *testing.T) { + t.Parallel() + + for c := range constraintTestCases() { + t.Run(c.name, func(t *testing.T) { + t.Parallel() + + got := satisfiesConstraint(c.typ, c.constraint) + if got != c.expected { + t.Errorf("satisfiesConstraint(%v, %v) = %v, want %v", c.typ, c.constraint, got, c.expected) + } + }) + } +} + +func TestSameType(t *testing.T) { + t.Parallel() + + intType := types.Typ[types.Int] + int64Type := types.Typ[types.Int64] + stringType := types.Typ[types.String] + + if !sameType(intType, intType) { + t.Error("int should be same type as int") + } + if sameType(intType, int64Type) { + t.Error("int should NOT be same type as int64") + } + if sameType(intType, stringType) { + t.Error("int should NOT be same type as string") + } +} + +func TestDeepComparable(t *testing.T) { + t.Parallel() + + // Struct with no pointers — deep-comparable. + plainStruct := types.NewStruct([]*types.Var{ + types.NewVar(token.NoPos, nil, "X", types.Typ[types.Int]), + types.NewVar(token.NoPos, nil, "Y", types.Typ[types.String]), + }, nil) + if !isDeepComparable(plainStruct) { + t.Error("struct{X int; Y string} should be deep-comparable") + } + + // Struct with a pointer field — NOT deep-comparable. + ptrStruct := types.NewStruct([]*types.Var{ + types.NewVar(token.NoPos, nil, "X", types.Typ[types.Int]), + types.NewVar(token.NoPos, nil, "P", types.NewPointer(types.Typ[types.Int])), + }, nil) + if isDeepComparable(ptrStruct) { + t.Error("struct{X int; P *int} should NOT be deep-comparable") + } + + // Struct with an interface field — NOT deep-comparable. + ifaceStruct := types.NewStruct([]*types.Var{ + types.NewVar(token.NoPos, nil, "X", types.Typ[types.Int]), + types.NewVar(token.NoPos, nil, "E", types.NewInterfaceType(nil, nil)), + }, nil) + if isDeepComparable(ifaceStruct) { + t.Error("struct{X int; E interface{}} should NOT be deep-comparable") + } + + // Array of int — deep-comparable. + intArray := types.NewArray(types.Typ[types.Int], 3) + if !isDeepComparable(intArray) { + t.Error("[3]int should be deep-comparable") + } + + // Array of *int — NOT deep-comparable. + ptrArray := types.NewArray(types.NewPointer(types.Typ[types.Int]), 3) + if isDeepComparable(ptrArray) { + t.Error("[3]*int should NOT be deep-comparable") + } +} + +func TestIsSliceType(t *testing.T) { + t.Parallel() + + intSlice := types.NewSlice(types.Typ[types.Int]) + elem, ok := isSliceType(intSlice) + if !ok { + t.Fatal("expected []int to be a slice type") + } + if !sameType(elem, types.Typ[types.Int]) { + t.Errorf("expected slice element to be int, got %v", elem) + } + + _, ok = isSliceType(types.Typ[types.Int]) + if ok { + t.Error("int should not be a slice type") + } +} + +func TestIsMapType(t *testing.T) { + t.Parallel() + + m := types.NewMap(types.Typ[types.String], types.Typ[types.Int]) + key, val, ok := isMapType(m) + if !ok { + t.Fatal("expected map[string]int to be a map type") + } + if !sameType(key, types.Typ[types.String]) { + t.Errorf("expected key to be string, got %v", key) + } + if !sameType(val, types.Typ[types.Int]) { + t.Errorf("expected val to be int, got %v", val) + } +} diff --git a/hack/migrate-testify/generics.go b/hack/migrate-testify/generics.go new file mode 100644 index 000000000..1e6b29eb1 --- /dev/null +++ b/hack/migrate-testify/generics.go @@ -0,0 +1,536 @@ +// SPDX-FileCopyrightText: Copyright 2025 go-swagger maintainers +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "fmt" + "go/ast" + "go/token" + "go/types" + "strings" + + "golang.org/x/tools/go/packages" +) + +// minPairArgs is the minimum number of assertion arguments (after t) needed +// for binary assertions like Equal(t, expected, actual). +const minPairArgs = 2 + +// runGenericUpgrade executes pass 2: reflection → generic assertion upgrade. +func runGenericUpgrade(dir string, opts *options) error { + pkgs, fset, err := loadPackages(dir) + if err != nil { + return err + } + + rpt := &report{} + + for _, pkg := range pkgs { + if pkg.TypesInfo == nil { + continue + } + + for i, f := range pkg.Syntax { + if i >= len(pkg.GoFiles) { + continue + } + filename := pkg.GoFiles[i] + + if opts.skipVendor && isVendorPath(filename) { + continue + } + + if !fileImportsAny(f, "github.com/go-openapi/testify") { + continue + } + + rpt.filesScanned++ + changes := upgradeFile(f, pkg, fset, rpt, filename, opts.verbose) + if changes == 0 { + continue + } + + rpt.filesChanged++ + rpt.totalChanges += changes + + if opts.dryRun { + if err := showDiff(fset, f, filename); err != nil { + rpt.errorf(filename, 0, err.Error()) + } + } else { + if err := writeFile(fset, f, filename); err != nil { + rpt.errorf(filename, 0, err.Error()) + } + } + } + } + + rpt.print(opts.verbose) + rpt.printPass2Summary() + return nil +} + +// upgradeFile processes a single file for generic upgrades. +func upgradeFile(f *ast.File, pkg *packages.Package, fset *token.FileSet, rpt *report, filename string, verbose bool) int { + aliases := buildGoOpenapiAliasMap(f) + if len(aliases) == 0 { + return 0 + } + + changes := 0 + + ast.Inspect(f, func(n ast.Node) bool { + call, ok := n.(*ast.CallExpr) + if !ok { + return true + } + + sel, funcName, isTestifyCall := extractTestifyCall(call, aliases) + if !isTestifyCall { + return true + } + + // Strip trailing "f" for format variants. + baseName := funcName + isFormat := false + if strings.HasSuffix(baseName, "f") && baseName != "Equalf" { + // Check if baseName without 'f' is in the upgrade table. + candidate := baseName[:len(baseName)-1] + if _, ok := genericUpgrades[candidate]; ok { + baseName = candidate + isFormat = true + } + } + // Also check Equalf directly. + if baseName == "Equalf" { + baseName = "Equal" + isFormat = true + } + + rule, ok := genericUpgrades[baseName] + if !ok { + return true + } + + if rule.manualReview { + pos := fset.Position(call.Pos()) + rpt.warn(filename, pos.Line, + fmt.Sprintf("%s → %s requires manual review (argument count changes)", funcName, rule.target)) + return true + } + + if rule.containerUpgrade { + if upgraded := tryContainerUpgrade(call, sel, funcName, baseName, isFormat, pkg, fset, rpt, filename, verbose); upgraded { + changes++ + } + return true + } + + if upgraded := trySimpleUpgrade(call, sel, funcName, baseName, isFormat, rule, pkg, fset, rpt, filename, verbose); upgraded { + changes++ + } + + return true + }) + + return changes +} + +// buildGoOpenapiAliasMap builds an alias map for go-openapi/testify imports. +func buildGoOpenapiAliasMap(f *ast.File) map[string]string { + aliases := make(map[string]string) + for _, imp := range f.Imports { + path := importPath(imp) + if !strings.HasPrefix(path, "github.com/go-openapi/testify") { + continue + } + // Skip enable imports. + if strings.Contains(path, "/enable/") { + continue + } + + var localName string + if imp.Name != nil { + localName = imp.Name.Name + } else { + parts := strings.Split(path, "/") + localName = parts[len(parts)-1] + } + if localName != "_" && localName != "." { + aliases[localName] = path + } + } + return aliases +} + +// extractTestifyCall checks if a call expression is a testify assertion call. +// Returns the selector, function name, and whether it's a testify call. +func extractTestifyCall(call *ast.CallExpr, aliases map[string]string) (*ast.SelectorExpr, string, bool) { + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return nil, "", false + } + + funcName := sel.Sel.Name + + // Package-level call: assert.Equal(t, ...) + if ident, ok := sel.X.(*ast.Ident); ok { + if _, exists := aliases[ident.Name]; exists { + return sel, funcName, true + } + } + + return nil, "", false +} + +// checkResult is the outcome of a constraint check for a generic upgrade. +type checkResult struct { + ok bool // constraints satisfied — upgrade is safe + reason skipReason // non-empty: skip with this reason + typeInfo string // context for skip message (e.g. type name) +} + +var ( + checkOK = checkResult{ok: true} //nolint:gochecknoglobals // convenience value + checkInsufficient = checkResult{} //nolint:gochecknoglobals // not enough args, silent skip +) + +func checkSkip(r skipReason, info string) checkResult { + return checkResult{reason: r, typeInfo: info} +} + +// trySimpleUpgrade attempts to upgrade a simple (non-container) assertion. +func trySimpleUpgrade( + call *ast.CallExpr, + sel *ast.SelectorExpr, + funcName, baseName string, + isFormat bool, + rule upgradeRule, + pkg *packages.Package, + fset *token.FileSet, + rpt *report, + filename string, + verbose bool, +) bool { + pos := fset.Position(call.Pos()) + + skip := func(reason skipReason, typeInfo string) bool { + rpt.trackSkip(filename, pos.Line, funcName, reason, verbose, typeInfo) + return false + } + + argTypes, argSkipReason := getArgTypesWithReason(call, pkg, 1) + if argTypes == nil { + return skip(argSkipReason, "") + } + + var result checkResult + switch baseName { + case "Equal", "NotEqual": + result = checkDeepComparablePair(argTypes, rule) + case "Greater", "GreaterOrEqual", "Less", "LessOrEqual": + result = checkPairConstraint(argTypes, constraintOrdered, skipNotOrdered, rule) + case "InDelta", "InEpsilon": + result = checkPairConstraint(argTypes, constraintMeasurable, skipNotMeasurable, rule) + case "Positive", "Negative": + result = checkSingleConstraint(argTypes, constraintSignedNumeric, skipNotSignedNumeric) + case "True", "False": + result = checkSingleConstraint(argTypes, constraintBoolean, skipNotBoolean) + case "Same", "NotSame": + result = checkPairConstraint(argTypes, constraintPointer, skipNotPointer, rule) + case "ElementsMatch", "Subset": + result = checkSlicePairConstraint(argTypes, rule) + case "IsIncreasing", "IsDecreasing", "IsNonIncreasing", "IsNonDecreasing": + result = checkOrderedSliceConstraint(argTypes) + case "Regexp", "NotRegexp": + result = checkRegexpConstraint(argTypes) + case "JSONEq", "YAMLEq": + result = checkPairConstraint(argTypes, constraintText, skipNotText, rule) + default: + return false + } + + if !result.ok { + if result.reason != "" { + return skip(result.reason, result.typeInfo) + } + return false + } + + newName := rule.target + if isFormat { + newName += "f" + } + + if verbose { + rpt.info(filename, pos.Line, fmt.Sprintf("upgraded %s → %s", funcName, newName)) + } + rpt.trackUpgrade(funcName, newName) + + sel.Sel.Name = newName + return true +} + +// checkDeepComparablePair checks that both arguments are deeply comparable and have matching types. +func checkDeepComparablePair(argTypes []types.Type, rule upgradeRule) checkResult { + if len(argTypes) < minPairArgs { + return checkInsufficient + } + if reason := deepComparableSkipReason(argTypes[0]); reason != "" { + return checkSkip(reason, argTypes[0].String()) + } + if reason := deepComparableSkipReason(argTypes[1]); reason != "" { + return checkSkip(reason, argTypes[1].String()) + } + if rule.sameType && !sameType(argTypes[0], argTypes[1]) { + return checkSkip(skipTypeMismatch, argTypes[0].String()+" vs "+argTypes[1].String()) + } + return checkOK +} + +// checkPairConstraint checks that both arguments satisfy a constraint and have matching types. +func checkPairConstraint(argTypes []types.Type, c constraintKind, failReason skipReason, rule upgradeRule) checkResult { + if len(argTypes) < minPairArgs { + return checkInsufficient + } + if !satisfiesConstraint(argTypes[0], c) || !satisfiesConstraint(argTypes[1], c) { + return checkSkip(failReason, argTypes[0].String()) + } + if rule.sameType && !sameType(argTypes[0], argTypes[1]) { + return checkSkip(skipTypeMismatch, argTypes[0].String()+" vs "+argTypes[1].String()) + } + return checkOK +} + +// checkSingleConstraint checks that the first argument satisfies a constraint. +func checkSingleConstraint(argTypes []types.Type, c constraintKind, failReason skipReason) checkResult { + if len(argTypes) < 1 { + return checkInsufficient + } + if !satisfiesConstraint(argTypes[0], c) { + return checkSkip(failReason, argTypes[0].String()) + } + return checkOK +} + +// checkSlicePairConstraint checks that both arguments are slices with deep-comparable elements. +func checkSlicePairConstraint(argTypes []types.Type, rule upgradeRule) checkResult { + if len(argTypes) < minPairArgs { + return checkInsufficient + } + elem0, ok0 := isSliceType(argTypes[0]) + elem1, ok1 := isSliceType(argTypes[1]) + if !ok0 || !ok1 { + return checkSkip(skipNotSlice, argTypes[0].String()) + } + if !isDeepComparable(elem0) || !isDeepComparable(elem1) { + return checkSkip(skipSliceElemNotDeepComparable, elem0.String()) + } + if rule.sameType && !sameType(argTypes[0], argTypes[1]) { + return checkSkip(skipTypeMismatch, argTypes[0].String()+" vs "+argTypes[1].String()) + } + return checkOK +} + +// checkOrderedSliceConstraint checks that the argument is a slice of ordered elements. +func checkOrderedSliceConstraint(argTypes []types.Type) checkResult { + if len(argTypes) < 1 { + return checkInsufficient + } + elem, ok := isSliceType(argTypes[0]) + if !ok { + return checkSkip(skipNotSlice, argTypes[0].String()) + } + if !isOrdered(elem) { + return checkSkip(skipSliceElemNotOrdered, elem.String()) + } + return checkOK +} + +// checkRegexpConstraint checks that the first arg is a RegExp and the second is Text. +func checkRegexpConstraint(argTypes []types.Type) checkResult { + if len(argTypes) < minPairArgs { + return checkInsufficient + } + if !satisfiesConstraint(argTypes[0], constraintRegExp) { + return checkSkip(skipNotRegExp, argTypes[0].String()) + } + if !satisfiesConstraint(argTypes[1], constraintText) { + return checkSkip(skipNotText, argTypes[1].String()) + } + return checkOK +} + +// deepComparableSkipReason returns the specific skip reason if a type is not deep-comparable, +// or an empty string if it is. +func deepComparableSkipReason(typ types.Type) skipReason { + if !types.Comparable(typ) { + return skipNotComparable + } + under := typ.Underlying() + switch under.(type) { + case *types.Pointer: + return skipPointerSemantics + case *types.Interface: + return skipInterfaceField + case *types.Struct: + if !isDeepComparable(typ) { + return skipInterfaceField + } + } + if !isDeepComparable(typ) { + return skipNotComparable + } + return "" +} + +// containerCheckResult extends checkResult with the resolved target function name. +type containerCheckResult struct { + checkResult + + target string +} + +// tryContainerUpgrade handles Contains/NotContains which dispatch to different +// generic variants based on the container type. +func tryContainerUpgrade( + call *ast.CallExpr, + sel *ast.SelectorExpr, + funcName, baseName string, + isFormat bool, + pkg *packages.Package, + fset *token.FileSet, + rpt *report, + filename string, + verbose bool, +) bool { + pos := fset.Position(call.Pos()) + + skip := func(reason skipReason, typeInfo string) bool { + rpt.trackSkip(filename, pos.Line, funcName, reason, verbose, typeInfo) + return false + } + + argTypes, argSkipReason := getArgTypesWithReason(call, pkg, 1) + if argTypes == nil { + return skip(argSkipReason, "") + } + if len(argTypes) < minPairArgs { + return false + } + + isNot := baseName == "NotContains" + + var result containerCheckResult + switch { + case isText(argTypes[0]): + result = checkStringContains(argTypes, isNot) + case isSliceType2(argTypes[0]): + result = checkSliceContains(argTypes, isNot) + case isMapType2(argTypes[0]): + result = checkMapContains(argTypes, isNot) + default: + return skip(skipContainerTypeUnknown, argTypes[0].String()) + } + + if !result.ok { + if result.reason != "" { + return skip(result.reason, result.typeInfo) + } + return false + } + + target := result.target + if isFormat { + target += "f" + } + + if verbose { + rpt.info(filename, pos.Line, fmt.Sprintf("upgraded %s → %s", funcName, target)) + } + rpt.trackUpgrade(funcName, target) + + sel.Sel.Name = target + return true +} + +// isSliceType2 reports whether typ is a slice type (bool-only variant for switch). +func isSliceType2(typ types.Type) bool { + _, ok := typ.Underlying().(*types.Slice) + return ok +} + +// isMapType2 reports whether typ is a map type (bool-only variant for switch). +func isMapType2(typ types.Type) bool { + _, ok := typ.Underlying().(*types.Map) + return ok +} + +// pickTarget selects the contains or not-contains variant. +func pickTarget(kind containerKind, isNot bool) string { + targets := containerUpgradeTargets[kind] + if isNot { + return targets[1] + } + return targets[0] +} + +// checkStringContains validates a string Contains/NotContains upgrade. +func checkStringContains(argTypes []types.Type, isNot bool) containerCheckResult { + if !isText(argTypes[1]) { + return containerCheckResult{checkResult: checkSkip(skipNotText, argTypes[1].String())} + } + return containerCheckResult{checkResult: checkOK, target: pickTarget(containerString, isNot)} +} + +// checkSliceContains validates a slice Contains/NotContains upgrade. +func checkSliceContains(argTypes []types.Type, isNot bool) containerCheckResult { + elem, _ := isSliceType(argTypes[0]) + if !isDeepComparable(elem) { + return containerCheckResult{checkResult: checkSkip(skipSliceElemNotDeepComparable, elem.String())} + } + if isAnyOrInterface(argTypes[1]) { + return containerCheckResult{checkResult: checkSkip(skipInterfaceType, argTypes[1].String())} + } + if !sameType(elem, argTypes[1]) { + return containerCheckResult{checkResult: checkSkip(skipTypeMismatch, elem.String()+" vs "+argTypes[1].String())} + } + return containerCheckResult{checkResult: checkOK, target: pickTarget(containerSlice, isNot)} +} + +// checkMapContains validates a map Contains/NotContains upgrade. +func checkMapContains(argTypes []types.Type, isNot bool) containerCheckResult { + key, _, _ := isMapType(argTypes[0]) + if !isComparable(key) { + return containerCheckResult{checkResult: checkSkip(skipNotComparable, key.String())} + } + if isAnyOrInterface(argTypes[1]) { + return containerCheckResult{checkResult: checkSkip(skipInterfaceType, argTypes[1].String())} + } + if !sameType(key, argTypes[1]) { + return containerCheckResult{checkResult: checkSkip(skipTypeMismatch, key.String()+" vs "+argTypes[1].String())} + } + return containerCheckResult{checkResult: checkOK, target: pickTarget(containerMap, isNot)} +} + +// getArgTypesWithReason extracts the types of call arguments starting from the given offset +// (to skip the testing.T parameter). Returns nil and a skip reason if any type cannot be resolved. +func getArgTypesWithReason(call *ast.CallExpr, pkg *packages.Package, offset int) ([]types.Type, skipReason) { + if len(call.Args) <= offset { + return nil, skipUnresolvableType + } + + result := make([]types.Type, 0, len(call.Args)-offset) + for _, arg := range call.Args[offset:] { + tv, ok := pkg.TypesInfo.Types[arg] + if !ok { + return nil, skipUnresolvableType + } + if isAnyOrInterface(tv.Type) { + return nil, skipInterfaceType + } + result = append(result, tv.Type) + } + return result, "" +} diff --git a/hack/migrate-testify/generics_test.go b/hack/migrate-testify/generics_test.go new file mode 100644 index 000000000..7ff9a969a --- /dev/null +++ b/hack/migrate-testify/generics_test.go @@ -0,0 +1,437 @@ +// SPDX-FileCopyrightText: Copyright 2025 go-swagger maintainers +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "go/ast" + "go/importer" + "go/parser" + "go/printer" + "go/token" + "go/types" + "iter" + "slices" + "strings" + "testing" + + "golang.org/x/tools/go/packages" +) + +type genericsTestCase struct { + name string + input string + expected string +} + +func genericsTestCases() iter.Seq[genericsTestCase] { + return slices.Values([]genericsTestCase{ + { + name: "equal int upgrade", + input: `package p +import ( + "testing" + "github.com/go-openapi/testify/v2/assert" +) +func TestEqual(t *testing.T) { + assert.Equal(t, 42, 42) +}`, + expected: `assert.EqualT(t, 42, 42)`, + }, + { + name: "notequal string upgrade", + input: `package p +import ( + "testing" + "github.com/go-openapi/testify/v2/assert" +) +func TestNotEqual(t *testing.T) { + assert.NotEqual(t, "a", "b") +}`, + expected: `assert.NotEqualT(t, "a", "b")`, + }, + { + name: "greater upgrade", + input: `package p +import ( + "testing" + "github.com/go-openapi/testify/v2/assert" +) +func TestGreater(t *testing.T) { + assert.Greater(t, 2, 1) +}`, + expected: `assert.GreaterT(t, 2, 1)`, + }, + { + name: "positive upgrade", + input: `package p +import ( + "testing" + "github.com/go-openapi/testify/v2/assert" +) +func TestPositive(t *testing.T) { + assert.Positive(t, 42) +}`, + expected: `assert.PositiveT(t, 42)`, + }, + { + name: "skip any", + input: `package p +import ( + "testing" + "github.com/go-openapi/testify/v2/assert" +) +func TestSkipAny(t *testing.T) { + var x any = 42 + assert.Equal(t, x, x) +}`, + expected: `assert.Equal(t, x, x)`, + }, + { + name: "skip different types", + input: `package p +import ( + "testing" + "github.com/go-openapi/testify/v2/assert" +) +func TestSkipDifferent(t *testing.T) { + assert.Equal(t, int32(1), int64(1)) +}`, + expected: `assert.Equal(t, int32(1), int64(1))`, + }, + { + name: "contains string upgrade", + input: `package p +import ( + "testing" + "github.com/go-openapi/testify/v2/assert" +) +func TestContains(t *testing.T) { + assert.Contains(t, "hello world", "world") +}`, + expected: `assert.StringContainsT(t, "hello world", "world")`, + }, + { + name: "contains slice upgrade", + input: `package p +import ( + "testing" + "github.com/go-openapi/testify/v2/assert" +) +func TestContains(t *testing.T) { + assert.Contains(t, []int{1, 2, 3}, 2) +}`, + expected: `assert.SliceContainsT(t, []int{1, 2, 3}, 2)`, + }, + { + name: "contains map upgrade", + input: `package p +import ( + "testing" + "github.com/go-openapi/testify/v2/assert" +) +func TestContains(t *testing.T) { + assert.Contains(t, map[string]int{"a": 1}, "a") +}`, + expected: `assert.MapContainsT(t, map[string]int{"a": 1}, "a")`, + }, + { + name: "true/false bool upgrade", + input: `package p +import ( + "testing" + "github.com/go-openapi/testify/v2/assert" +) +func TestBool(t *testing.T) { + assert.True(t, true) + assert.False(t, false) +}`, + expected: "assert.TrueT(t, true)\n\tassert.FalseT(t, false)", + }, + }) +} + +func TestGenericUpgradeUnit(t *testing.T) { + t.Parallel() + + for c := range genericsTestCases() { + t.Run(c.name, func(t *testing.T) { + t.Parallel() + + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "test.go", c.input, parser.ParseComments) + if err != nil { + t.Fatalf("parse: %v", err) + } + + // Create a mock assert package so type-checking succeeds. + info := typeCheckWithMockAssert(t, fset, f) + if info == nil { + t.Fatal("type-check failed") + } + + // Build a fake *packages.Package to pass to upgradeFile. + pkg := &packages.Package{ + TypesInfo: info, + Syntax: []*ast.File{f}, + GoFiles: []string{"test.go"}, + } + + rpt := &report{} + upgradeFile(f, pkg, fset, rpt, "test.go", true) + + // Extract the assertion call(s) from the output. + var buf strings.Builder + pcfg := &printer.Config{Mode: printer.UseSpaces | printer.TabIndent, Tabwidth: 8} + if err := pcfg.Fprint(&buf, fset, f); err != nil { + t.Fatalf("print: %v", err) + } + + got := buf.String() + if !strings.Contains(got, c.expected) { + t.Errorf("expected output to contain:\n %s\ngot:\n%s", c.expected, got) + } + }) + } +} + +func TestGenericUpgradeTracksUpgrades(t *testing.T) { + t.Parallel() + + input := `package p +import ( + "testing" + "github.com/go-openapi/testify/v2/assert" +) +func TestEqual(t *testing.T) { + assert.Equal(t, 42, 42) + assert.True(t, true) +}` + + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "test.go", input, parser.ParseComments) + if err != nil { + t.Fatalf("parse: %v", err) + } + + info := typeCheckWithMockAssert(t, fset, f) + if info == nil { + t.Fatal("type-check failed") + } + + pkg := &packages.Package{ + TypesInfo: info, + Syntax: []*ast.File{f}, + GoFiles: []string{"test.go"}, + } + + rpt := &report{} + changes := upgradeFile(f, pkg, fset, rpt, "test.go", false) + + if changes != 2 { + t.Errorf("expected 2 changes, got %d", changes) + } + if len(rpt.upgraded) != 2 { + t.Errorf("expected 2 upgrade entries, got %d", len(rpt.upgraded)) + } + if rpt.upgraded["Equal → EqualT"] != 1 { + t.Errorf("expected Equal → EqualT upgrade, got %v", rpt.upgraded) + } + if rpt.upgraded["True → TrueT"] != 1 { + t.Errorf("expected True → TrueT upgrade, got %v", rpt.upgraded) + } +} + +func TestGenericUpgradeTracksSkips(t *testing.T) { + t.Parallel() + + // Test with any type — should be skipped with skipInterfaceType. + input := `package p +import ( + "testing" + "github.com/go-openapi/testify/v2/assert" +) +func TestSkipAny(t *testing.T) { + var x any = 42 + assert.Equal(t, x, x) +}` + + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "test.go", input, parser.ParseComments) + if err != nil { + t.Fatalf("parse: %v", err) + } + + info := typeCheckWithMockAssert(t, fset, f) + if info == nil { + t.Fatal("type-check failed") + } + + pkg := &packages.Package{ + TypesInfo: info, + Syntax: []*ast.File{f}, + GoFiles: []string{"test.go"}, + } + + rpt := &report{} + changes := upgradeFile(f, pkg, fset, rpt, "test.go", true) + + if changes != 0 { + t.Errorf("expected 0 changes for any type, got %d", changes) + } + if len(rpt.skipped) == 0 { + t.Error("expected skip to be tracked for any type") + } + if rpt.skipped[string(skipInterfaceType)] != 1 { + t.Errorf("expected skipInterfaceType, got %v", rpt.skipped) + } +} + +func TestGenericUpgradeTracksDifferentTypeSkip(t *testing.T) { + t.Parallel() + + input := `package p +import ( + "testing" + "github.com/go-openapi/testify/v2/assert" +) +func TestSkipDifferent(t *testing.T) { + assert.Equal(t, int32(1), int64(1)) +}` + + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "test.go", input, parser.ParseComments) + if err != nil { + t.Fatalf("parse: %v", err) + } + + info := typeCheckWithMockAssert(t, fset, f) + if info == nil { + t.Fatal("type-check failed") + } + + pkg := &packages.Package{ + TypesInfo: info, + Syntax: []*ast.File{f}, + GoFiles: []string{"test.go"}, + } + + rpt := &report{} + changes := upgradeFile(f, pkg, fset, rpt, "test.go", true) + + if changes != 0 { + t.Errorf("expected 0 changes for mismatched types, got %d", changes) + } + if len(rpt.skipped) == 0 { + t.Error("expected skip to be tracked for type mismatch") + } + if rpt.skipped[string(skipTypeMismatch)] != 1 { + t.Errorf("expected skipTypeMismatch, got %v", rpt.skipped) + } +} + +// typeCheckWithMockAssert creates a mock "assert" package that has a few functions +// taking (testing.T, any...) and type-checks the file against it. +func typeCheckWithMockAssert(t *testing.T, fset *token.FileSet, f *ast.File) *types.Info { + t.Helper() + + // We need to provide packages that the source imports. + // Use the default importer for stdlib. + stdImporter := importer.ForCompiler(fset, "source", nil) + + // Create a fake "assert" package with the functions we need. + assertPkg := types.NewPackage("github.com/go-openapi/testify/v2/assert", "assert") + + // T interface (just needs Errorf and Helper). + tParam := types.NewVar(token.NoPos, assertPkg, "t", types.NewInterfaceType(nil, nil)) + anyType := types.NewInterfaceType(nil, nil) + + // Create variadic msgAndArgs param. + msgAndArgsVar := types.NewVar(token.NoPos, assertPkg, "msgAndArgs", types.NewSlice(anyType)) + + // Helper to create assertion function signatures. + makeAssertFunc := func(name string, params ...*types.Var) { + allParams := make([]*types.Var, 0, 1+len(params)+1) + allParams = append(allParams, tParam) + allParams = append(allParams, params...) + allParams = append(allParams, msgAndArgsVar) + + sig := types.NewSignatureType(nil, nil, nil, + types.NewTuple(allParams...), + types.NewTuple(types.NewVar(token.NoPos, assertPkg, "", types.Typ[types.Bool])), + true, // variadic + ) + assertPkg.Scope().Insert(types.NewFunc(token.NoPos, assertPkg, name, sig)) + } + + // Add known assertion functions to the fake package. + anyParam := func(name string) *types.Var { + return types.NewVar(token.NoPos, assertPkg, name, anyType) + } + + // Equal(t, expected any, actual any, ...) + makeAssertFunc("Equal", anyParam("expected"), anyParam("actual")) + makeAssertFunc("EqualT", anyParam("expected"), anyParam("actual")) + makeAssertFunc("NotEqual", anyParam("expected"), anyParam("actual")) + makeAssertFunc("NotEqualT", anyParam("expected"), anyParam("actual")) + makeAssertFunc("Greater", anyParam("e1"), anyParam("e2")) + makeAssertFunc("GreaterT", anyParam("e1"), anyParam("e2")) + makeAssertFunc("Less", anyParam("e1"), anyParam("e2")) + makeAssertFunc("LessT", anyParam("e1"), anyParam("e2")) + makeAssertFunc("Positive", anyParam("e")) + makeAssertFunc("PositiveT", anyParam("e")) + makeAssertFunc("Negative", anyParam("e")) + makeAssertFunc("NegativeT", anyParam("e")) + makeAssertFunc("Contains", anyParam("s"), anyParam("contains")) + makeAssertFunc("StringContainsT", anyParam("s"), anyParam("contains")) + makeAssertFunc("SliceContainsT", anyParam("s"), anyParam("contains")) + makeAssertFunc("MapContainsT", anyParam("s"), anyParam("contains")) + makeAssertFunc("True", anyParam("value")) + makeAssertFunc("TrueT", anyParam("value")) + makeAssertFunc("False", anyParam("value")) + makeAssertFunc("FalseT", anyParam("value")) + + assertPkg.MarkComplete() + + // Custom importer that returns our mock assert package. + customImporter := &mockImporter{ + base: stdImporter, + extra: map[string]*types.Package{ + "github.com/go-openapi/testify/v2/assert": assertPkg, + }, + } + + conf := &types.Config{ + Importer: customImporter, + } + + info := &types.Info{ + Types: make(map[ast.Expr]types.TypeAndValue), + Defs: make(map[*ast.Ident]types.Object), + Uses: make(map[*ast.Ident]types.Object), + } + + _, err := conf.Check("test", fset, []*ast.File{f}, info) + if err != nil { + t.Logf("type-check error (may be expected): %v", err) + // Even with errors, the info map may have partial results. + // Check if we have any type info at all. + if len(info.Types) == 0 { + return nil + } + } + + return info +} + +type mockImporter struct { + base types.Importer + extra map[string]*types.Package +} + +func (m *mockImporter) Import(path string) (*types.Package, error) { + if pkg, ok := m.extra[path]; ok { + return pkg, nil + } + return m.base.Import(path) +} diff --git a/hack/migrate-testify/go.mod b/hack/migrate-testify/go.mod new file mode 100644 index 000000000..5f9e5e892 --- /dev/null +++ b/hack/migrate-testify/go.mod @@ -0,0 +1,10 @@ +module github.com/go-openapi/testify/hack/migrate-testify/v2 + +go 1.24.0 + +require ( + golang.org/x/mod v0.33.0 + golang.org/x/tools v0.41.0 +) + +require golang.org/x/sync v0.19.0 // indirect diff --git a/hack/migrate-testify/go.sum b/hack/migrate-testify/go.sum new file mode 100644 index 000000000..927c44439 --- /dev/null +++ b/hack/migrate-testify/go.sum @@ -0,0 +1,8 @@ +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= +golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= +golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= diff --git a/hack/migrate-testify/gomod.go b/hack/migrate-testify/gomod.go new file mode 100644 index 000000000..f03111564 --- /dev/null +++ b/hack/migrate-testify/gomod.go @@ -0,0 +1,62 @@ +// SPDX-FileCopyrightText: Copyright 2025 go-swagger maintainers +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "context" + "fmt" + "os" + "os/exec" + "strings" +) + +// updateGoMod updates the target project's go.mod to replace stretchr/testify +// with go-openapi/testify/v2. +func updateGoMod(dir, version string, dryRun, verbose bool) error { + // Check go.mod exists. + if _, err := os.Stat(dir + "/go.mod"); err != nil { + return fmt.Errorf("no go.mod found in %s", dir) + } + + if dryRun { + fmt.Println("would run: go mod edit -droprequire github.com/stretchr/testify") //nolint:forbidigo // CLI output + fmt.Printf("would run: go mod edit -require github.com/go-openapi/testify/v2@%s\n", version) //nolint:forbidigo // CLI output + fmt.Println("would run: go mod tidy") //nolint:forbidigo // CLI output + return nil + } + + commands := []struct { + name string + args []string + }{ + { + name: "drop stretchr/testify require", + args: []string{"go", "mod", "edit", "-droprequire", "github.com/stretchr/testify"}, + }, + { + name: "add go-openapi/testify/v2 require", + args: []string{"go", "mod", "edit", "-require", "github.com/go-openapi/testify/v2@" + version}, + }, + { + name: "go mod tidy", + args: []string{"go", "mod", "tidy"}, + }, + } + + for _, c := range commands { + if verbose { + fmt.Printf("running: %s\n", strings.Join(c.args, " ")) //nolint:forbidigo // CLI output + } + + cmd := exec.CommandContext(context.Background(), c.args[0], c.args[1:]...) //nolint:gosec // args are constructed from controlled constants + cmd.Dir = dir + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("%s: %w", c.name, err) + } + } + + return nil +} diff --git a/hack/migrate-testify/loader.go b/hack/migrate-testify/loader.go new file mode 100644 index 000000000..3af5df41d --- /dev/null +++ b/hack/migrate-testify/loader.go @@ -0,0 +1,171 @@ +// SPDX-FileCopyrightText: Copyright 2025 go-swagger maintainers +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "fmt" + "go/ast" + "go/token" + "os" + "path/filepath" + "strings" + + "golang.org/x/mod/modfile" + "golang.org/x/tools/go/packages" +) + +// loadPackages loads all Go packages under dir with full type information. +// If a go.work file is present, it loads packages from all workspace modules. +// Otherwise, it loads packages from the single module at dir. +func loadPackages(dir string) ([]*packages.Package, *token.FileSet, error) { + absDir, err := filepath.Abs(dir) + if err != nil { + return nil, nil, fmt.Errorf("resolving path: %w", err) + } + + patterns := workspacePatterns(absDir) + + fset := token.NewFileSet() + cfg := &packages.Config{ + Mode: packages.NeedName | + packages.NeedFiles | + packages.NeedSyntax | + packages.NeedTypes | + packages.NeedTypesInfo | + packages.NeedImports | + packages.NeedDeps, + Dir: absDir, + Fset: fset, + Tests: true, + } + + pkgs, err := packages.Load(cfg, patterns...) + if err != nil { + return nil, nil, fmt.Errorf("loading packages: %w", err) + } + + // Report loading errors but don't fail — partial results are useful. + var errs []string + for _, pkg := range pkgs { + for _, e := range pkg.Errors { + errs = append(errs, e.Error()) + } + } + if len(errs) > 0 { + // Log but continue — we can still transform files that loaded. + fmt.Printf("warning: %d package loading errors (some files may be skipped):\n", len(errs)) //nolint:forbidigo // CLI output + for _, e := range errs { + fmt.Printf(" %s\n", e) //nolint:forbidigo // CLI output + } + } + + return pkgs, fset, nil +} + +// workspacePatterns returns the load patterns for packages.Load. +// If go.work exists, it returns a pattern per workspace module (e.g. "./conv/...", "./mangling/..."). +// If go.work does not exist, it returns ["./..."] and warns if sub-modules are detected. +func workspacePatterns(absDir string) []string { + goworkPath := filepath.Join(absDir, "go.work") + data, err := os.ReadFile(goworkPath) + if err != nil { + // No go.work — check for sub-modules and warn. + warnIfSubModules(absDir) + return []string{"./..."} + } + + wf, err := modfile.ParseWork(goworkPath, data, nil) + if err != nil { + fmt.Fprintf(os.Stderr, "warning: failed to parse go.work: %v; falling back to ./...\n", err) + return []string{"./..."} + } + + patterns := make([]string, 0, len(wf.Use)) + for _, use := range wf.Use { + dir := use.Path + if dir == "." { + patterns = append(patterns, "./...") + } else { + // Normalize: strip leading "./" if present, ensure trailing "/...". + dir = strings.TrimPrefix(dir, "./") + patterns = append(patterns, "./"+dir+"/...") + } + } + + if len(patterns) == 0 { + return []string{"./..."} + } + + fmt.Printf("go.work: loading %d workspace modules\n", len(patterns)) //nolint:forbidigo // CLI output + return patterns +} + +// warnIfSubModules scans for nested go.mod files and warns if found. +func warnIfSubModules(absDir string) { + var subModules []string + + _ = filepath.Walk(absDir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return nil //nolint:nilerr // intentionally swallow walk errors for best-effort scan + } + if info.IsDir() { + base := filepath.Base(path) + if base == "vendor" || base == ".git" || base == "node_modules" { + return filepath.SkipDir + } + return nil + } + if info.Name() == "go.mod" && path != filepath.Join(absDir, "go.mod") { + rel, _ := filepath.Rel(absDir, filepath.Dir(path)) + subModules = append(subModules, rel) + } + return nil + }) + + if len(subModules) == 0 { + return + } + + fmt.Fprintf(os.Stderr, "warning: found %d sub-modules without go.work; pass 2 will only cover the root module\n", len(subModules)) + fmt.Fprintf(os.Stderr, " → Create go.work with: go work init . %s\n", strings.Join(prefixDot(subModules), " ")) +} + +// prefixDot adds "./" prefix to each path. +func prefixDot(paths []string) []string { + result := make([]string, len(paths)) + for i, p := range paths { + result[i] = "./" + p + } + return result +} + +// fileImportsPath reports whether the given file has an import with the specified path. +func fileImportsPath(f *ast.File, path string) bool { + for _, imp := range f.Imports { + if importPath(imp) == path { + return true + } + } + return false +} + +// importPath returns the unquoted import path from an ImportSpec. +func importPath(imp *ast.ImportSpec) string { + return strings.Trim(imp.Path.Value, `"`) +} + +// fileImportsAny reports whether the file imports any path with the given prefix. +func fileImportsAny(f *ast.File, prefix string) bool { + for _, imp := range f.Imports { + if strings.HasPrefix(importPath(imp), prefix) { + return true + } + } + return false +} + +// isVendorPath reports whether the file path is inside a vendor/ directory. +func isVendorPath(path string) bool { + return strings.Contains(path, "/vendor/") || strings.HasPrefix(path, "vendor/") +} diff --git a/hack/migrate-testify/main.go b/hack/migrate-testify/main.go new file mode 100644 index 000000000..245258412 --- /dev/null +++ b/hack/migrate-testify/main.go @@ -0,0 +1,113 @@ +// SPDX-FileCopyrightText: Copyright 2025 go-swagger maintainers +// SPDX-License-Identifier: Apache-2.0 + +// Package main implements a migration tool for converting stretchr/testify +// usage to go-openapi/testify/v2, and upgrading reflection-based assertions +// to generic variants where type information permits. +// +// Usage: +// +// go run ./hack/migrate-testify [flags] [directory] +package main + +import ( + "context" + "flag" + "fmt" + "os" + "os/exec" + "strings" +) + +type options struct { + migrate bool + upgradeGenerics bool + all bool + dryRun bool + verbose bool + skipGomod bool + skipVendor bool + version string +} + +func main() { + opts := &options{} + + flag.BoolVar(&opts.migrate, "migrate", false, "Run pass 1: stretchr/testify → go-openapi/testify/v2") + flag.BoolVar(&opts.upgradeGenerics, "upgrade-generics", false, "Run pass 2: reflection → generic assertions") + flag.BoolVar(&opts.all, "all", false, "Run both passes sequentially") + flag.BoolVar(&opts.dryRun, "dry-run", false, "Show diffs without modifying files") + flag.BoolVar(&opts.verbose, "verbose", false, "Print detailed transformation info") + flag.BoolVar(&opts.skipGomod, "skip-gomod", false, "Skip go.mod changes") + flag.BoolVar(&opts.skipVendor, "skip-vendor", true, "Skip vendor/ directory") + flag.StringVar(&opts.version, "version", "v2.3.0", "Target testify version") + + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "Usage: go run ./hack/migrate-testify [flags] [directory]\n\n") + fmt.Fprintf(os.Stderr, "Migrate stretchr/testify to go-openapi/testify/v2 and upgrade to generic assertions.\n\n") + fmt.Fprintf(os.Stderr, "Flags:\n") + flag.PrintDefaults() + fmt.Fprintf(os.Stderr, "\nAt least one of --migrate, --upgrade-generics, or --all is required.\n") + fmt.Fprintf(os.Stderr, "\nMono-repo support:\n") + fmt.Fprintf(os.Stderr, " Pass 1 walks the filesystem and works across module boundaries.\n") + fmt.Fprintf(os.Stderr, " Pass 2 requires type information and uses go/packages to load code.\n") + fmt.Fprintf(os.Stderr, " For multi-module repos, a go.work file must be present so that pass 2\n") + fmt.Fprintf(os.Stderr, " can load all workspace modules. Create one with:\n") + fmt.Fprintf(os.Stderr, " go work init . ./sub/module1 ./sub/module2 ...\n") + fmt.Fprintf(os.Stderr, "\nPost-migration checklist:\n") + fmt.Fprintf(os.Stderr, " - Run your linter: the migration may surface pre-existing unchecked linting issues.\n") + fmt.Fprintf(os.Stderr, " - Run your test suite to verify all tests still pass.\n") + } + + flag.Parse() + + if opts.all { + opts.migrate = true + opts.upgradeGenerics = true + } + + if !opts.migrate && !opts.upgradeGenerics { + flag.Usage() + os.Exit(2) //nolint:mnd // standard exit code for usage error + } + + dir := "." + if flag.NArg() > 0 { + dir = flag.Arg(0) + } + + // Pre-flight: warn if git is dirty. + checkGitDirty(dir) + + if opts.migrate { + fmt.Println("=== Pass 1: Migration (stretchr/testify → go-openapi/testify/v2) ===") //nolint:forbidigo // CLI output + if err := runMigration(dir, opts); err != nil { + fmt.Fprintf(os.Stderr, "error: migration: %v\n", err) + os.Exit(1) + } + } + + if opts.upgradeGenerics { + fmt.Println("=== Pass 2: Generic Upgrade (reflection → generic assertions) ===") //nolint:forbidigo // CLI output + if err := runGenericUpgrade(dir, opts); err != nil { + fmt.Fprintf(os.Stderr, "error: generic upgrade: %v\n", err) + os.Exit(1) + } + } + + fmt.Println("Done.") //nolint:forbidigo // CLI output +} + +// checkGitDirty warns if the working directory has uncommitted changes. +func checkGitDirty(dir string) { + cmd := exec.CommandContext(context.Background(), "git", "status", "--porcelain") + cmd.Dir = dir + out, err := cmd.Output() + if err != nil { + return // Not a git repo or git not available — skip check. + } + output := strings.TrimSpace(string(out)) + if output != "" { + fmt.Fprintf(os.Stderr, "warning: working directory has uncommitted changes\n") + } +} diff --git a/hack/migrate-testify/migrate.go b/hack/migrate-testify/migrate.go new file mode 100644 index 000000000..807ed90b2 --- /dev/null +++ b/hack/migrate-testify/migrate.go @@ -0,0 +1,311 @@ +// SPDX-FileCopyrightText: Copyright 2025 go-swagger maintainers +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "fmt" + "go/ast" + "go/parser" + "go/token" + "os" + "path/filepath" + "strings" + + "golang.org/x/tools/go/ast/astutil" +) + +// runMigration executes pass 1: stretchr/testify → go-openapi/testify/v2. +func runMigration(dir string, opts *options) error { + fset := token.NewFileSet() + rpt := &report{} + + // Walk all .go files in the directory tree. + err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() { + base := filepath.Base(path) + if base == "vendor" && opts.skipVendor { + return filepath.SkipDir + } + if base == ".git" || base == "node_modules" { + return filepath.SkipDir + } + return nil + } + if !strings.HasSuffix(path, ".go") { + return nil + } + + rpt.filesScanned++ + return migrateFile(fset, path, opts, rpt) + }) + if err != nil { + return fmt.Errorf("walking directory: %w", err) + } + + rpt.print(opts.verbose) + rpt.printPass1Summary() + + if !opts.dryRun && !opts.skipGomod { + if err := updateGoMod(dir, opts.version, false, opts.verbose); err != nil { + return fmt.Errorf("updating go.mod: %w", err) + } + } else if opts.dryRun && !opts.skipGomod { + if err := updateGoMod(dir, opts.version, true, opts.verbose); err != nil { + return fmt.Errorf("updating go.mod: %w", err) + } + } + + return nil +} + +// migrateFile processes a single Go file for pass 1 transformations. +func migrateFile(fset *token.FileSet, filename string, opts *options, rpt *report) error { + src, err := os.ReadFile(filename) + if err != nil { + return err + } + + f, err := parser.ParseFile(fset, filename, src, parser.ParseComments) + if err != nil { + return fmt.Errorf("parsing %s: %w", filename, err) + } + + if !fileImportsAny(f, "github.com/stretchr/testify") { + return nil + } + + changed := false + + // 1. Detect incompatible imports. + for imp, msg := range incompatibleImports { + if fileImportsPath(f, imp) { + rpt.warn(filename, 0, msg) + } + } + + // 2. Build alias map before rewriting imports. + aliases := buildAliasMap(f) + + // 2b. Count API usage for the summary report. + countAPIUsage(f, aliases, rpt) + + // 3. Rewrite imports. + for old, replacement := range importRewrites { + if astutil.RewriteImport(fset, f, old, replacement) { + changed = true + rpt.totalChanges++ + if opts.verbose { + rpt.info(filename, 0, fmt.Sprintf("rewrote import %q → %q", old, replacement)) + } + } + } + + // 4. Rename functions and replace PanicTestFunc. + changes := renameFunctions(f, aliases, fset, rpt, filename, opts.verbose) + if changes > 0 { + changed = true + rpt.totalChanges += changes + } + + // 5. Replace PanicTestFunc type references. + ptfChanges := replacePanicTestFunc(f, aliases, fset, rpt, filename, opts.verbose) + if ptfChanges > 0 { + changed = true + rpt.totalChanges += ptfChanges + } + + // 6. Detect YAML usage and inject enable import. + if needsYAMLEnable(f, aliases) { + if !fileImportsPath(f, goopenapiYAMLEnable) { + astutil.AddNamedImport(fset, f, "_", goopenapiYAMLEnable) + changed = true + rpt.totalChanges++ + if opts.verbose { + rpt.info(filename, 0, "injected enable/yaml import") + } + } + } + + if !changed { + return nil + } + + rpt.filesChanged++ + + if opts.dryRun { + return showDiff(fset, f, filename) + } + + return writeFile(fset, f, filename) +} + +// buildAliasMap builds a map from import alias (or default package name) to import path +// for stretchr/testify packages. +func buildAliasMap(f *ast.File) map[string]string { + aliases := make(map[string]string) + for _, imp := range f.Imports { + path := importPath(imp) + if !strings.HasPrefix(path, "github.com/stretchr/testify") { + continue + } + + var localName string + if imp.Name != nil { + localName = imp.Name.Name + } else { + // Default: last path element. + parts := strings.Split(path, "/") + localName = parts[len(parts)-1] + } + if localName != "_" && localName != "." { + aliases[localName] = path + } + } + return aliases +} + +// renameFunctions walks the AST and renames function calls that changed names +// between stretchr/testify and go-openapi/testify/v2. +func renameFunctions(f *ast.File, aliases map[string]string, fset *token.FileSet, rpt *report, filename string, verbose bool) int { + changes := 0 + + ast.Inspect(f, func(n ast.Node) bool { + call, ok := n.(*ast.CallExpr) + if !ok { + return true + } + + // Handle pkg.Func() calls. + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return true + } + + funcName := sel.Sel.Name + newName, exists := migrationRenames[funcName] + if !exists { + return true + } + + // Verify this is a call on a testify package. + if !isTestifySelector(sel, aliases) { + return true + } + + sel.Sel.Name = newName + changes++ + pos := fset.Position(call.Pos()) + if verbose { + rpt.info(filename, pos.Line, fmt.Sprintf("renamed %s → %s", funcName, newName)) + } + + return true + }) + + return changes +} + +// isTestifySelector checks if a selector expression refers to a stretchr/testify package. +func isTestifySelector(sel *ast.SelectorExpr, aliases map[string]string) bool { + ident, ok := sel.X.(*ast.Ident) + if !ok { + return false + } + _, exists := aliases[ident.Name] + return exists +} + +// replacePanicTestFunc replaces PanicTestFunc type references with func(). +func replacePanicTestFunc(f *ast.File, aliases map[string]string, fset *token.FileSet, rpt *report, filename string, verbose bool) int { + changes := 0 + + astutil.Apply(f, func(c *astutil.Cursor) bool { + sel, ok := c.Node().(*ast.SelectorExpr) + if !ok { + return true + } + + if sel.Sel.Name != "PanicTestFunc" { + return true + } + + if !isTestifySelector(sel, aliases) { + return true + } + + // Replace with func() — an *ast.FuncType with no params and no results. + c.Replace(&ast.FuncType{ + Params: &ast.FieldList{}, + }) + changes++ + pos := fset.Position(sel.Pos()) + if verbose { + rpt.info(filename, pos.Line, "replaced PanicTestFunc with func()") + } + + return true + }, nil) + + return changes +} + +// countAPIUsage walks the AST and counts every testify assertion call for reporting. +func countAPIUsage(f *ast.File, aliases map[string]string, rpt *report) { + ast.Inspect(f, func(n ast.Node) bool { + call, ok := n.(*ast.CallExpr) + if !ok { + return true + } + + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return true + } + + if !isTestifySelector(sel, aliases) { + return true + } + + ident, ok := sel.X.(*ast.Ident) + if !ok { + return true + } + rpt.trackAPIUsage(ident.Name + "." + sel.Sel.Name) + + return true + }) +} + +// needsYAMLEnable checks if any YAML assertion functions are called in the file. +func needsYAMLEnable(f *ast.File, aliases map[string]string) bool { + found := false + + ast.Inspect(f, func(n ast.Node) bool { + if found { + return false + } + + call, ok := n.(*ast.CallExpr) + if !ok { + return true + } + + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return true + } + + if yamlFunctions[sel.Sel.Name] && isTestifySelector(sel, aliases) { + found = true + return false + } + + return true + }) + + return found +} diff --git a/hack/migrate-testify/migrate_test.go b/hack/migrate-testify/migrate_test.go new file mode 100644 index 000000000..963676b6e --- /dev/null +++ b/hack/migrate-testify/migrate_test.go @@ -0,0 +1,142 @@ +// SPDX-FileCopyrightText: Copyright 2025 go-swagger maintainers +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "bytes" + "go/parser" + "go/printer" + "go/token" + "iter" + "os" + "slices" + "strings" + "testing" + + "golang.org/x/tools/go/ast/astutil" +) + +type migrateTestCase struct { + name string + input string + expected string + // warnContains, if non-empty, checks that a warning was emitted containing this string. + warnContains string +} + +func migrateTestCases() iter.Seq[migrateTestCase] { + return slices.Values([]migrateTestCase{ + { + name: "basic import rewrite", + input: readTestdata("migrate_basic/input.go.txt"), + expected: readTestdata("migrate_basic/expected.go.txt"), + }, + { + name: "yaml enable injection", + input: readTestdata("migrate_yaml/input.go.txt"), + expected: readTestdata("migrate_yaml/expected.go.txt"), + }, + { + name: "aliased import with rename", + input: readTestdata("migrate_alias/input.go.txt"), + expected: readTestdata("migrate_alias/expected.go.txt"), + }, + { + name: "incompatible imports warn", + input: readTestdata("migrate_incompatible/input.go.txt"), + warnContains: "mock package is not available", + }, + { + name: "PanicTestFunc replacement", + input: readTestdata("migrate_panic_func/input.go.txt"), + expected: readTestdata("migrate_panic_func/expected.go.txt"), + }, + }) +} + +func TestMigrateFile(t *testing.T) { + t.Parallel() + + for c := range migrateTestCases() { + t.Run(c.name, func(t *testing.T) { + t.Parallel() + runMigrateSubtest(t, c) + }) + } +} + +func runMigrateSubtest(t *testing.T, c migrateTestCase) { + t.Helper() + + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "test.go", c.input, parser.ParseComments) + if err != nil { + t.Fatalf("parse: %v", err) + } + + rpt := &report{} + + // Build aliases before rewriting imports. + aliases := buildAliasMap(f) + + // Detect incompatible imports. + for imp, msg := range incompatibleImports { + if fileImportsPath(f, imp) { + rpt.warn("test.go", 0, msg) + } + } + + // Rewrite imports. + for old, replacement := range importRewrites { + astutil.RewriteImport(fset, f, old, replacement) + } + + // Rename functions. + renameFunctions(f, aliases, fset, rpt, "test.go", true) + + // Replace PanicTestFunc. + replacePanicTestFunc(f, aliases, fset, rpt, "test.go", true) + + // YAML injection. + if needsYAMLEnable(f, aliases) { + astutil.AddNamedImport(fset, f, "_", goopenapiYAMLEnable) + } + + // Check warnings. + if c.warnContains != "" { + found := false + for _, d := range rpt.diagnostics { + if d.kind == "warning" && strings.Contains(d.message, c.warnContains) { + found = true + break + } + } + if !found { + t.Errorf("expected warning containing %q, got: %v", c.warnContains, rpt.diagnostics) + } + return + } + + // Compare output. + var buf strings.Builder + cfg := &printer.Config{Mode: printer.UseSpaces | printer.TabIndent, Tabwidth: 8} + if err := cfg.Fprint(&buf, fset, f); err != nil { + t.Fatalf("print: %v", err) + } + + got := strings.TrimSpace(buf.String()) + want := strings.TrimSpace(c.expected) + + if got != want { + t.Errorf("output mismatch:\n--- got ---\n%s\n--- want ---\n%s", got, want) + } +} + +func readTestdata(path string) string { + data, err := os.ReadFile("testdata/" + path) + if err != nil { + panic("reading testdata: " + err.Error()) + } + return string(bytes.ReplaceAll(data, []byte{'\r'}, []byte{})) // on windows, remove the \r from \n\r sequences +} diff --git a/hack/migrate-testify/rename_map.go b/hack/migrate-testify/rename_map.go new file mode 100644 index 000000000..f3d88afbf --- /dev/null +++ b/hack/migrate-testify/rename_map.go @@ -0,0 +1,282 @@ +// SPDX-FileCopyrightText: Copyright 2025 go-swagger maintainers +// SPDX-License-Identifier: Apache-2.0 + +package main + +// stretchrAssertPath and related constants define the import paths for +// stretchr/testify and their go-openapi/testify/v2 replacements. +const ( + stretchrAssertPath = "github.com/stretchr/testify/assert" + stretchrRequirePath = "github.com/stretchr/testify/require" + stretchrRootPath = "github.com/stretchr/testify" + stretchrMockPath = "github.com/stretchr/testify/mock" + stretchrSuitePath = "github.com/stretchr/testify/suite" + stretchrHTTPPath = "github.com/stretchr/testify/http" + + goopenapiAssertPath = "github.com/go-openapi/testify/v2/assert" + goopenapiRequirePath = "github.com/go-openapi/testify/v2/require" + goopenapiRootPath = "github.com/go-openapi/testify/v2" + goopenapiYAMLEnable = "github.com/go-openapi/testify/v2/enable/yaml" +) + +// importRewrites maps stretchr import paths to go-openapi replacements. +var importRewrites = map[string]string{ //nolint:gochecknoglobals // lookup table + stretchrAssertPath: goopenapiAssertPath, + stretchrRequirePath: goopenapiRequirePath, + stretchrRootPath: goopenapiRootPath, +} + +// incompatibleImports are stretchr import paths with no go-openapi equivalent. +var incompatibleImports = map[string]string{ //nolint:gochecknoglobals // lookup table + stretchrMockPath: "mock package is not available in go-openapi/testify/v2.\n" + + " → Use github.com/vektra/mockery or hand-written mocks.\n" + + " → See: https://go-openapi.github.io/testify/usage/migration/index.html#5-remove-suitemock-usage", + stretchrSuitePath: "suite package is not available in go-openapi/testify/v2.\n" + + " → Use standard Go subtests (t.Run) and TestMain.\n" + + " → See: https://go-openapi.github.io/testify/usage/migration/index.html#5-remove-suitemock-usage", + stretchrHTTPPath: "http package is not available in go-openapi/testify/v2.\n" + + " → Use net/http/httptest from the standard library.\n" + + " → See: https://go-openapi.github.io/testify/usage/migration/index.html#6-remove-use-of-the-testifyhttp-package", +} + +// migrationRenames maps stretchr function names to their go-openapi equivalents +// where the names differ. +var migrationRenames = map[string]string{ //nolint:gochecknoglobals // lookup table + "EventuallyWithT": "EventuallyWith", + "EventuallyWithTf": "EventuallyWithf", + "NoDirExists": "DirNotExists", + "NoDirExistsf": "DirNotExistsf", + "NoFileExists": "FileNotExists", + "NoFileExistsf": "FileNotExistsf", +} + +// yamlFunctions lists assertion function names that use YAML, triggering +// injection of the enable/yaml import. +var yamlFunctions = map[string]bool{ //nolint:gochecknoglobals // lookup table + "YAMLEq": true, + "YAMLEqf": true, + "YAMLEqT": true, + "YAMLEqTf": true, + "YAMLEqBytes": true, + "YAMLEqBytesf": true, +} + +// constraintKind classifies the type constraint needed for a generic upgrade. +type constraintKind int + +const ( + constraintComparable constraintKind = iota + constraintDeepComparable // comparable AND == has same semantics as reflect.DeepEqual + constraintOrdered + constraintText + constraintSignedNumeric + constraintMeasurable + constraintPointer + constraintBoolean + constraintRegExp +) + +// containerKind classifies the container type for Contains upgrades. +type containerKind int + +const ( + containerString containerKind = iota + containerSlice + containerMap +) + +// upgradeRule defines how to upgrade a reflection-based assertion to a generic variant. +type upgradeRule struct { + // target is the generic function name to upgrade to. + target string + // argConstraints defines the constraint required for each argument + // (excluding t and msgAndArgs). For single-constraint functions, only + // one entry is needed. + argConstraints []constraintKind + // sameType requires that the relevant arguments have identical types. + sameType bool + // containerUpgrade means the function dispatches to different generic + // variants depending on the container type (string, slice, map). + containerUpgrade bool + // manualReview flags the upgrade as requiring manual review due to + // signature changes (e.g., IsType → IsOfTypeT changes arg count). + manualReview bool +} + +// genericUpgrades maps reflection-based assertion names to their generic upgrade rules. +var genericUpgrades = map[string]upgradeRule{ //nolint:gochecknoglobals // lookup table + // Equality — must be deep-comparable (== same as reflect.DeepEqual) + // to preserve semantics. Pointer types are excluded because EqualT + // uses == (address comparison) while Equal uses reflect.DeepEqual. + "Equal": { + target: "EqualT", + argConstraints: []constraintKind{constraintDeepComparable}, + sameType: true, + }, + "NotEqual": { + target: "NotEqualT", + argConstraints: []constraintKind{constraintDeepComparable}, + sameType: true, + }, + + // Comparison / ordering + "Greater": { + target: "GreaterT", + argConstraints: []constraintKind{constraintOrdered}, + sameType: true, + }, + "GreaterOrEqual": { + target: "GreaterOrEqualT", + argConstraints: []constraintKind{constraintOrdered}, + sameType: true, + }, + "Less": { + target: "LessT", + argConstraints: []constraintKind{constraintOrdered}, + sameType: true, + }, + "LessOrEqual": { + target: "LessOrEqualT", + argConstraints: []constraintKind{constraintOrdered}, + sameType: true, + }, + + // Numeric + "Positive": { + target: "PositiveT", + argConstraints: []constraintKind{constraintSignedNumeric}, + }, + "Negative": { + target: "NegativeT", + argConstraints: []constraintKind{constraintSignedNumeric}, + }, + "InDelta": { + target: "InDeltaT", + argConstraints: []constraintKind{constraintMeasurable}, + sameType: true, + }, + "InEpsilon": { + target: "InEpsilonT", + argConstraints: []constraintKind{constraintMeasurable}, + sameType: true, + }, + + // Container (dispatches to different targets) + "Contains": { + containerUpgrade: true, + }, + "NotContains": { + containerUpgrade: true, + }, + + // Collection + "ElementsMatch": { + target: "ElementsMatchT", + argConstraints: []constraintKind{constraintComparable}, + sameType: true, + }, + "Subset": { + target: "SliceSubsetT", + argConstraints: []constraintKind{constraintComparable}, + sameType: true, + }, + + // Ordering (slice) + "IsIncreasing": { + target: "IsIncreasingT", + argConstraints: []constraintKind{constraintOrdered}, + }, + "IsDecreasing": { + target: "IsDecreasingT", + argConstraints: []constraintKind{constraintOrdered}, + }, + "IsNonIncreasing": { + target: "IsNonIncreasingT", + argConstraints: []constraintKind{constraintOrdered}, + }, + "IsNonDecreasing": { + target: "IsNonDecreasingT", + argConstraints: []constraintKind{constraintOrdered}, + }, + + // String / regex + "Regexp": { + target: "RegexpT", + argConstraints: []constraintKind{constraintRegExp, constraintText}, + }, + "NotRegexp": { + target: "NotRegexpT", + argConstraints: []constraintKind{constraintRegExp, constraintText}, + }, + + // JSON / YAML + "JSONEq": { + target: "JSONEqT", + argConstraints: []constraintKind{constraintText}, + sameType: true, + }, + "YAMLEq": { + target: "YAMLEqT", + argConstraints: []constraintKind{constraintText}, + sameType: true, + }, + + // Pointer + "Same": { + target: "SameT", + argConstraints: []constraintKind{constraintPointer}, + sameType: true, + }, + "NotSame": { + target: "NotSameT", + argConstraints: []constraintKind{constraintPointer}, + sameType: true, + }, + + // Type (manual review — arg count changes) + "IsType": { + target: "IsOfTypeT", + manualReview: true, + argConstraints: []constraintKind{}, + }, + + // Boolean + "True": { + target: "TrueT", + argConstraints: []constraintKind{constraintBoolean}, + }, + "False": { + target: "FalseT", + argConstraints: []constraintKind{constraintBoolean}, + }, +} + +// containerUpgradeTargets maps container kind to (funcName, notFuncName) pairs. +var containerUpgradeTargets = map[containerKind][2]string{ //nolint:gochecknoglobals // lookup table + containerString: {"StringContainsT", "StringNotContainsT"}, + containerSlice: {"SliceContainsT", "SliceNotContainsT"}, + containerMap: {"MapContainsT", "MapNotContainsT"}, +} + +// skipReason describes why a generic upgrade was skipped for an assertion call. +type skipReason string + +const ( + skipPointerSemantics skipReason = "pointer type (== compares addresses, not values)" + skipInterfaceField skipReason = "struct contains pointer/interface fields" + skipUnresolvableType skipReason = "type not statically resolvable" + skipInterfaceType skipReason = "argument is interface{}/any" + skipTypeMismatch skipReason = "arguments have different types" + skipNotOrdered skipReason = "type does not satisfy Ordered constraint" + skipNotText skipReason = "type does not satisfy Text constraint" + skipNotSignedNumeric skipReason = "type does not satisfy SignedNumeric constraint" + skipNotMeasurable skipReason = "type does not satisfy Measurable constraint" + skipNotComparable skipReason = "type does not satisfy comparable constraint" + skipNotBoolean skipReason = "type does not satisfy Boolean constraint" + skipNotPointer skipReason = "type is not a pointer" + skipNotSlice skipReason = "argument is not a slice type" + skipNotRegExp skipReason = "type does not satisfy RegExp constraint" + skipSliceElemNotDeepComparable skipReason = "slice element not deeply comparable" + skipSliceElemNotOrdered skipReason = "slice element not ordered" + skipContainerTypeUnknown skipReason = "container is not string, slice, or map" +) diff --git a/hack/migrate-testify/report.go b/hack/migrate-testify/report.go new file mode 100644 index 000000000..3879ffaf6 --- /dev/null +++ b/hack/migrate-testify/report.go @@ -0,0 +1,249 @@ +// SPDX-FileCopyrightText: Copyright 2025 go-swagger maintainers +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "cmp" + "fmt" + "go/ast" + "go/printer" + "go/token" + "os" + "slices" + "strings" +) + +// diagnostic represents a single warning or info message from the migration tool. +type diagnostic struct { + file string + line int + message string + kind string // "warning", "info", "error" +} + +func (d diagnostic) String() string { + if d.line > 0 { + return fmt.Sprintf("%s:%d: %s: %s", d.file, d.line, d.kind, d.message) + } + return fmt.Sprintf("%s: %s: %s", d.file, d.kind, d.message) +} + +// report collects diagnostics and file changes during a migration run. +type report struct { + diagnostics []diagnostic + filesChanged int + totalChanges int + filesScanned int + // Pass 1 stats: funcName → call count (e.g. "assert.Equal" → 42). + apiUsage map[string]int + // Pass 2 stats: "Equal→EqualT" → count. + upgraded map[string]int + // Pass 2 stats: skipReason → count. + skipped map[string]int +} + +func (r *report) warn(file string, line int, msg string) { + r.diagnostics = append(r.diagnostics, diagnostic{file: file, line: line, message: msg, kind: "warning"}) +} + +func (r *report) info(file string, line int, msg string) { + r.diagnostics = append(r.diagnostics, diagnostic{file: file, line: line, message: msg, kind: "info"}) +} + +func (r *report) errorf(file string, line int, msg string) { + r.diagnostics = append(r.diagnostics, diagnostic{file: file, line: line, message: msg, kind: "error"}) +} + +// trackAPIUsage increments the usage counter for a testify API call. +func (r *report) trackAPIUsage(qualifiedName string) { + if r.apiUsage == nil { + r.apiUsage = make(map[string]int) + } + r.apiUsage[qualifiedName]++ +} + +// trackUpgrade increments the upgrade counter for a successful generic upgrade. +func (r *report) trackUpgrade(from, to string) { + if r.upgraded == nil { + r.upgraded = make(map[string]int) + } + r.upgraded[from+" → "+to]++ +} + +// trackSkip records a skipped generic upgrade and emits an info diagnostic. +func (r *report) trackSkip(file string, line int, funcName string, reason skipReason, verbose bool, typeInfo string) { + if r.skipped == nil { + r.skipped = make(map[string]int) + } + r.skipped[string(reason)]++ + if verbose { + msg := fmt.Sprintf("skipped %s: %s", funcName, reason) + if typeInfo != "" { + msg += " [" + typeInfo + "]" + } + r.info(file, line, msg) + } +} + +// print outputs all collected diagnostics. +func (r *report) print(verbose bool) { + if verbose { + for _, d := range r.diagnostics { + fmt.Println(d) //nolint:forbidigo // CLI output + } + } else { + // Only print warnings and errors, not info diagnostics. + for _, d := range r.diagnostics { + if d.kind != "info" { + fmt.Println(d) //nolint:forbidigo // CLI output + } + } + } +} + +// printPass1Summary outputs the Pass 1 structured summary. +func (r *report) printPass1Summary() { + fmt.Printf("\n=== Pass 1 Summary ===\n") //nolint:forbidigo // CLI output + fmt.Printf("Files scanned: %d | Files changed: %d | Transformations: %d\n", //nolint:forbidigo // CLI output + r.filesScanned, r.filesChanged, r.totalChanges) + + if len(r.apiUsage) > 0 { + fmt.Printf("\nAPI usage across migrated scope:\n") //nolint:forbidigo // CLI output + printCountTable(r.apiUsage) + } + + warnings := 0 + for _, d := range r.diagnostics { + if d.kind == "warning" { + warnings++ + } + } + fmt.Printf("\nWarnings: %d\n", warnings) //nolint:forbidigo // CLI output +} + +// printPass2Summary outputs the Pass 2 structured summary. +func (r *report) printPass2Summary() { + fmt.Printf("\n=== Pass 2 Summary ===\n") //nolint:forbidigo // CLI output + fmt.Printf("Files scanned: %d | Files changed: %d | Upgrades: %d\n", //nolint:forbidigo // CLI output + r.filesScanned, r.filesChanged, r.totalChanges) + + if len(r.upgraded) > 0 { + fmt.Printf("\nUpgraded assertions:\n") //nolint:forbidigo // CLI output + printCountTable(r.upgraded) + } + + if len(r.skipped) > 0 { + total := 0 + for _, v := range r.skipped { + total += v + } + fmt.Printf("\nSkipped (generic alternative exists but cannot upgrade): %d\n", total) //nolint:forbidigo // CLI output + printCountTable(r.skipped) + } +} + +// printCountTable prints a map of label→count as a right-aligned table with two columns. +func printCountTable(m map[string]int) { + type entry struct { + label string + count int + } + entries := make([]entry, 0, len(m)) + for k, v := range m { + entries = append(entries, entry{k, v}) + } + slices.SortFunc(entries, func(a, b entry) int { + if c := cmp.Compare(b.count, a.count); c != 0 { + return c + } + return cmp.Compare(a.label, b.label) + }) + + // Find max label width for alignment. + maxLabel := 0 + for _, e := range entries { + if len(e.label) > maxLabel { + maxLabel = len(e.label) + } + } + + // Print in two columns if there are enough entries. + for i := 0; i < len(entries); i += 2 { + left := entries[i] + if i+1 < len(entries) { + right := entries[i+1] + fmt.Printf(" %-*s %5d %-*s %5d\n", maxLabel, left.label, left.count, maxLabel, right.label, right.count) //nolint:forbidigo // CLI output + } else { + fmt.Printf(" %-*s %5d\n", maxLabel, left.label, left.count) //nolint:forbidigo // CLI output + } + } +} + +// writeFile writes a modified AST back to disk via go/printer. +func writeFile(fset *token.FileSet, f *ast.File, filename string) error { + out, err := os.Create(filename) + if err != nil { + return fmt.Errorf("creating %s: %w", filename, err) + } + defer out.Close() + + cfg := &printer.Config{ + Mode: printer.UseSpaces | printer.TabIndent, + Tabwidth: 8, //nolint:mnd // standard Go tabwidth + } + if err := cfg.Fprint(out, fset, f); err != nil { + return fmt.Errorf("writing %s: %w", filename, err) + } + return nil +} + +// showDiff displays a simple diff-like output showing what would change. +func showDiff(fset *token.FileSet, f *ast.File, filename string) error { + original, err := os.ReadFile(filename) + if err != nil { + return err + } + + var buf strings.Builder + cfg := &printer.Config{ + Mode: printer.UseSpaces | printer.TabIndent, + Tabwidth: 8, //nolint:mnd // standard Go tabwidth + } + if err := cfg.Fprint(&buf, fset, f); err != nil { + return err + } + + modified := buf.String() + if string(original) == modified { + return nil + } + + fmt.Printf("--- %s\n+++ %s (modified)\n", filename, filename) //nolint:forbidigo // CLI output + + origLines := strings.Split(string(original), "\n") + modLines := strings.Split(modified, "\n") + + // Simple line-by-line diff — not a real unified diff, but helpful for dry-run. + maxLines := max(len(origLines), len(modLines)) + + for i := range maxLines { + var origLine, modLine string + if i < len(origLines) { + origLine = origLines[i] + } + if i < len(modLines) { + modLine = modLines[i] + } + if origLine != modLine { + if origLine != "" { + fmt.Printf("-%s\n", origLine) //nolint:forbidigo // CLI output + } + if modLine != "" { + fmt.Printf("+%s\n", modLine) //nolint:forbidigo // CLI output + } + } + } + fmt.Println() //nolint:forbidigo // CLI output + return nil +} diff --git a/hack/migrate-testify/report_test.go b/hack/migrate-testify/report_test.go new file mode 100644 index 000000000..1dc3d9a5a --- /dev/null +++ b/hack/migrate-testify/report_test.go @@ -0,0 +1,227 @@ +// SPDX-FileCopyrightText: Copyright 2025 go-swagger maintainers +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "iter" + "slices" + "testing" +) + +type trackAPIUsageCase struct { + name string + calls []string + expected map[string]int +} + +func trackAPIUsageCases() iter.Seq[trackAPIUsageCase] { + return slices.Values([]trackAPIUsageCase{ + { + name: "single call", + calls: []string{"assert.Equal"}, + expected: map[string]int{ + "assert.Equal": 1, + }, + }, + { + name: "multiple calls same function", + calls: []string{"assert.Equal", "assert.Equal", "assert.Equal"}, + expected: map[string]int{ + "assert.Equal": 3, + }, + }, + { + name: "multiple different functions", + calls: []string{"assert.Equal", "require.NoError", "assert.True"}, + expected: map[string]int{ + "assert.Equal": 1, + "require.NoError": 1, + "assert.True": 1, + }, + }, + }) +} + +func TestTrackAPIUsage(t *testing.T) { + t.Parallel() + + for c := range trackAPIUsageCases() { + t.Run(c.name, func(t *testing.T) { + t.Parallel() + + rpt := &report{} + for _, call := range c.calls { + rpt.trackAPIUsage(call) + } + + if len(rpt.apiUsage) != len(c.expected) { + t.Errorf("expected %d entries, got %d", len(c.expected), len(rpt.apiUsage)) + } + for k, v := range c.expected { + if rpt.apiUsage[k] != v { + t.Errorf("expected %s=%d, got %d", k, v, rpt.apiUsage[k]) + } + } + }) + } +} + +type trackUpgradeCase struct { + name string + upgrades [][2]string // from, to pairs + expected map[string]int +} + +func trackUpgradeCases() iter.Seq[trackUpgradeCase] { + return slices.Values([]trackUpgradeCase{ + { + name: "single upgrade", + upgrades: [][2]string{{"Equal", "EqualT"}}, + expected: map[string]int{ + "Equal → EqualT": 1, + }, + }, + { + name: "multiple upgrades", + upgrades: [][2]string{ + {"Equal", "EqualT"}, + {"Equal", "EqualT"}, + {"True", "TrueT"}, + }, + expected: map[string]int{ + "Equal → EqualT": 2, + "True → TrueT": 1, + }, + }, + }) +} + +func TestTrackUpgrade(t *testing.T) { + t.Parallel() + + for c := range trackUpgradeCases() { + t.Run(c.name, func(t *testing.T) { + t.Parallel() + + rpt := &report{} + for _, u := range c.upgrades { + rpt.trackUpgrade(u[0], u[1]) + } + + if len(rpt.upgraded) != len(c.expected) { + t.Errorf("expected %d entries, got %d", len(c.expected), len(rpt.upgraded)) + } + for k, v := range c.expected { + if rpt.upgraded[k] != v { + t.Errorf("expected %s=%d, got %d", k, v, rpt.upgraded[k]) + } + } + }) + } +} + +type trackSkipCase struct { + name string + skips []skipReason + expectedSkipped map[string]int + verbose bool + expectDiag bool +} + +func trackSkipCases() iter.Seq[trackSkipCase] { + return slices.Values([]trackSkipCase{ + { + name: "single skip", + skips: []skipReason{skipPointerSemantics}, + expectedSkipped: map[string]int{ + string(skipPointerSemantics): 1, + }, + verbose: false, + expectDiag: false, + }, + { + name: "verbose emits diagnostic", + skips: []skipReason{skipInterfaceType}, + expectedSkipped: map[string]int{ + string(skipInterfaceType): 1, + }, + verbose: true, + expectDiag: true, + }, + { + name: "multiple skips aggregated", + skips: []skipReason{skipPointerSemantics, skipPointerSemantics, skipTypeMismatch}, + expectedSkipped: map[string]int{ + string(skipPointerSemantics): 2, + string(skipTypeMismatch): 1, + }, + verbose: false, + expectDiag: false, + }, + }) +} + +func TestTrackSkip(t *testing.T) { + t.Parallel() + + for c := range trackSkipCases() { + t.Run(c.name, func(t *testing.T) { + t.Parallel() + + rpt := &report{} + for _, reason := range c.skips { + rpt.trackSkip("test.go", 42, "Equal", reason, c.verbose, "*float64") + } + + if len(rpt.skipped) != len(c.expectedSkipped) { + t.Errorf("expected %d skip entries, got %d", len(c.expectedSkipped), len(rpt.skipped)) + } + for k, v := range c.expectedSkipped { + if rpt.skipped[k] != v { + t.Errorf("expected %s=%d, got %d", k, v, rpt.skipped[k]) + } + } + + hasDiag := len(rpt.diagnostics) > 0 + if c.expectDiag && !hasDiag { + t.Error("expected diagnostic to be emitted in verbose mode") + } + if !c.expectDiag && hasDiag { + t.Errorf("expected no diagnostic in non-verbose mode, got %d", len(rpt.diagnostics)) + } + }) + } +} + +func TestReportInitializesLazily(t *testing.T) { + t.Parallel() + + rpt := &report{} + + // All maps should be nil initially. + if rpt.apiUsage != nil { + t.Error("apiUsage should be nil initially") + } + if rpt.upgraded != nil { + t.Error("upgraded should be nil initially") + } + if rpt.skipped != nil { + t.Error("skipped should be nil initially") + } + + // After tracking, maps should be initialized. + rpt.trackAPIUsage("assert.Equal") + rpt.trackUpgrade("Equal", "EqualT") + rpt.trackSkip("test.go", 1, "Equal", skipPointerSemantics, false, "") + + if rpt.apiUsage == nil { + t.Error("apiUsage should be initialized after tracking") + } + if rpt.upgraded == nil { + t.Error("upgraded should be initialized after tracking") + } + if rpt.skipped == nil { + t.Error("skipped should be initialized after tracking") + } +} diff --git a/hack/migrate-testify/testdata/migrate_alias/expected.go.txt b/hack/migrate-testify/testdata/migrate_alias/expected.go.txt new file mode 100644 index 000000000..5dcca169b --- /dev/null +++ b/hack/migrate-testify/testdata/migrate_alias/expected.go.txt @@ -0,0 +1,12 @@ +package example + +import ( + "testing" + + ta "github.com/go-openapi/testify/v2/assert" +) + +func TestAlias(t *testing.T) { + ta.Equal(t, 1, 1) + ta.DirNotExists(t, "/nonexistent") +} diff --git a/hack/migrate-testify/testdata/migrate_alias/input.go.txt b/hack/migrate-testify/testdata/migrate_alias/input.go.txt new file mode 100644 index 000000000..0d8d2975c --- /dev/null +++ b/hack/migrate-testify/testdata/migrate_alias/input.go.txt @@ -0,0 +1,12 @@ +package example + +import ( + "testing" + + ta "github.com/stretchr/testify/assert" +) + +func TestAlias(t *testing.T) { + ta.Equal(t, 1, 1) + ta.NoDirExists(t, "/nonexistent") +} diff --git a/hack/migrate-testify/testdata/migrate_basic/expected.go.txt b/hack/migrate-testify/testdata/migrate_basic/expected.go.txt new file mode 100644 index 000000000..a76d2d6ee --- /dev/null +++ b/hack/migrate-testify/testdata/migrate_basic/expected.go.txt @@ -0,0 +1,13 @@ +package example + +import ( + "testing" + + "github.com/go-openapi/testify/v2/assert" + "github.com/go-openapi/testify/v2/require" +) + +func TestBasic(t *testing.T) { + assert.Equal(t, 1, 1) + require.NoError(t, nil) +} diff --git a/hack/migrate-testify/testdata/migrate_basic/input.go.txt b/hack/migrate-testify/testdata/migrate_basic/input.go.txt new file mode 100644 index 000000000..633d96a92 --- /dev/null +++ b/hack/migrate-testify/testdata/migrate_basic/input.go.txt @@ -0,0 +1,13 @@ +package example + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBasic(t *testing.T) { + assert.Equal(t, 1, 1) + require.NoError(t, nil) +} diff --git a/hack/migrate-testify/testdata/migrate_incompatible/input.go.txt b/hack/migrate-testify/testdata/migrate_incompatible/input.go.txt new file mode 100644 index 000000000..3ebc594f1 --- /dev/null +++ b/hack/migrate-testify/testdata/migrate_incompatible/input.go.txt @@ -0,0 +1,21 @@ +package example + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" +) + +type MySuite struct { + suite.Suite +} + +type MyMock struct { + mock.Mock +} + +func TestIncompat(t *testing.T) { + assert.True(t, true) +} diff --git a/hack/migrate-testify/testdata/migrate_panic_func/expected.go.txt b/hack/migrate-testify/testdata/migrate_panic_func/expected.go.txt new file mode 100644 index 000000000..3b9b7fb96 --- /dev/null +++ b/hack/migrate-testify/testdata/migrate_panic_func/expected.go.txt @@ -0,0 +1,12 @@ +package example + +import ( + "testing" + + "github.com/go-openapi/testify/v2/assert" +) + +func TestPanicFunc(t *testing.T) { + var f func() = func() { panic("boom") } + assert.Panics(t, f) +} diff --git a/hack/migrate-testify/testdata/migrate_panic_func/input.go.txt b/hack/migrate-testify/testdata/migrate_panic_func/input.go.txt new file mode 100644 index 000000000..bdf8c8168 --- /dev/null +++ b/hack/migrate-testify/testdata/migrate_panic_func/input.go.txt @@ -0,0 +1,12 @@ +package example + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPanicFunc(t *testing.T) { + var f assert.PanicTestFunc = func() { panic("boom") } + assert.Panics(t, f) +} diff --git a/hack/migrate-testify/testdata/migrate_yaml/expected.go.txt b/hack/migrate-testify/testdata/migrate_yaml/expected.go.txt new file mode 100644 index 000000000..f1fae7d56 --- /dev/null +++ b/hack/migrate-testify/testdata/migrate_yaml/expected.go.txt @@ -0,0 +1,12 @@ +package example + +import ( + "testing" + + "github.com/go-openapi/testify/v2/assert" + _ "github.com/go-openapi/testify/v2/enable/yaml" +) + +func TestYAML(t *testing.T) { + assert.YAMLEq(t, `name: foo`, `name: foo`) +} diff --git a/hack/migrate-testify/testdata/migrate_yaml/input.go.txt b/hack/migrate-testify/testdata/migrate_yaml/input.go.txt new file mode 100644 index 000000000..75dd683f3 --- /dev/null +++ b/hack/migrate-testify/testdata/migrate_yaml/input.go.txt @@ -0,0 +1,11 @@ +package example + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestYAML(t *testing.T) { + assert.YAMLEq(t, `name: foo`, `name: foo`) +}