From f0eb2e1cfc85f0d6038a42ff04ec7b33e95c8e06 Mon Sep 17 00:00:00 2001 From: Alec Thomas Date: Sat, 12 Apr 2025 12:34:47 +1000 Subject: [PATCH] chore: go2proto transform functions --- .vscode/settings.json | 1 + cmd/go2proto/main.go | 26 ++- cmd/go2proto/transform/transform.go | 267 +++++++++++++++++++++++ cmd/go2proto/transform/transform_test.go | 64 ++++++ cmd/go2proto/transform/type.go | 61 ++++++ cmd/go2proto/transform/type_test.go | 24 ++ cmd/go2proto/transform/typesgo.go | 180 +++++++++++++++ 7 files changed, 615 insertions(+), 8 deletions(-) create mode 100644 cmd/go2proto/transform/transform.go create mode 100644 cmd/go2proto/transform/transform_test.go create mode 100644 cmd/go2proto/transform/type.go create mode 100644 cmd/go2proto/transform/type_test.go create mode 100644 cmd/go2proto/transform/typesgo.go diff --git a/.vscode/settings.json b/.vscode/settings.json index a546020961..fc0014cd83 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -20,5 +20,6 @@ "go.testFlags": [ "-race" ], + "ftl.automaticallyStartServer": "never", // "go.buildTags": "infrastructure" } \ No newline at end of file diff --git a/cmd/go2proto/main.go b/cmd/go2proto/main.go index 8d4a7d6866..68d4609732 100644 --- a/cmd/go2proto/main.go +++ b/cmd/go2proto/main.go @@ -1223,24 +1223,34 @@ func parsePBTag(tag string) (pbTag, bool, error) { return out, true, nil } +var fset = token.NewFileSet() + func loadInterface(pkg, symbol string) *types.Interface { + name := loadObject(pkg, symbol) + if t, ok := name.(*types.TypeName); ok { + if t.Name() == symbol { + return t.Type().Underlying().(*types.Interface) //nolint:forcetypeassert + } + } + panic("could not find " + pkg + "." + symbol) +} + +func loadObject(pkgName, symbol string) types.Object { pkgs, err := packages.Load(&packages.Config{ + Fset: fset, Mode: packages.NeedTypes | packages.NeedTypesInfo | packages.NeedDeps | packages.NeedImports | packages.NeedSyntax | packages.NeedFiles | packages.NeedName, - }, pkg) + }, pkgName) if err != nil { panic(err) } for _, pkg := range pkgs { - for _, name := range pkg.TypesInfo.Defs { - if t, ok := name.(*types.TypeName); ok { - if t.Name() == symbol { - return t.Type().Underlying().(*types.Interface) //nolint:forcetypeassert - } - } + obj := pkg.Types.Scope().Lookup(symbol) + if obj != nil { + return obj } } - panic("could not find " + pkg + "." + symbol) + panic("could not find " + pkgName + "." + symbol) } // PackageDirectives captures the directives in the protobuf:XYZ directives extracted from package comments. diff --git a/cmd/go2proto/transform/transform.go b/cmd/go2proto/transform/transform.go new file mode 100644 index 0000000000..8b98b7fe80 --- /dev/null +++ b/cmd/go2proto/transform/transform.go @@ -0,0 +1,267 @@ +package transform + +import ( + "fmt" + "go/token" + "go/types" + + "github.com/alecthomas/types/optional" + "golang.org/x/tools/go/packages" +) + +var ( + fset = token.NewFileSet() + textMarshaler = loadInterface("encoding", "TextMarshaler") + binaryMarshaler = loadInterface("encoding", "BinaryMarshaler") +) + +// An Expr represents an expression of the given type. +type Expr interface { + // Type of the Expr wrapped in result.Result[T] + Type() types.Type + Expr() string +} + +type Variable struct { + Name string + Typ types.Type +} + +var _ Expr = Variable{} + +func Var(name string, typ types.Type) Variable { return Variable{name, typ} } +func (v Variable) Type() types.Type { return v.Typ } +func (v Variable) Expr() string { return fmt.Sprintf("result.Ok(%s)", v.Name) } + +type Call struct { + Package optional.Option[string] + Name string + Arg Expr + Out types.Type +} + +var _ Expr = Call{} + +func (c Call) Type() types.Type { return c.Out } +func (c Call) Expr() string { + name := "" + if pkg, ok := c.Package.Get(); ok { + name += pkg + "." + } + name += c.Name + return fmt.Sprintf("result.Map(%s, %s)", c.Arg.Expr(), name) +} + +// basicType looks up a basic type by name. +func basicType(name string) types.Type { + return types.Universe.Lookup(name).(*types.TypeName).Type() +} + +func Transform(from Expr, to types.Type) (Expr, bool) { + transforms := findTransform(from.Type(), to) + if len(transforms) == 0 { + return nil, false + } + out := from + for _, transformation := range transforms { + out = transformation.Transform(out) + } + return out, true +} + +// Depth-first search for full match. +func findTransform(from types.Type, to types.Type) []*Transformation { + for _, probe := range probes { + if tf := probe(from); tf != nil && types.Identical(tf.To, to) { + return []*Transformation{tf} + } else if tf != nil { + children := findTransform(tf.To, to) + if len(children) != 0 { + return append([]*Transformation{tf}, children...) + } + } + } + return nil +} + +type Priority int + +const ( + LowPriority Priority = iota - 1 + MediumPriority + HighPriority +) + +// Probe function to determine if the transformation can be applied to the given type. +type Probe func(from types.Type) *Transformation + +type Transformation struct { + To types.Type + Priority Priority + Transform func(from Expr) Expr + Imports []string + Helper string +} + +var probes = []Probe{ + // T -> P using the method "T.ToProto() (P, error)" + // func(from types.Type) *Transformation { + // named, ok := from.(*types.Named) + // if !ok { + // return nil + // } + // var toProto *types.Func + // for method := range named.Methods() { + // if method.Name() == "ToProto" && method.Signature().Results().Len() == 1 { + // toProto = method + // break + // } + // } + // if toProto == nil { + // return nil + // } + // result := toProto.Signature().Results().At(0) + // return &Transformation{ + // To: result.Type(), + // Priority: LowPriority, + // Transform: func(from Expr) Expr { + // return Call{ + // Name: "toProto", + // Arg: from, + // Out: result.Type(), + // } + // }, + // Helper: ` + // func toProto[P, T interface { ToProto() P }](v T) (P, error) { + // return v.ToProto() + // } + // `, + // } + // }, + // []byte -> string + func(from types.Type) *Transformation { + if !types.Identical(from, types.NewSlice(basicType("byte"))) { + return nil + } + return &Transformation{ + To: basicType("string"), + Transform: func(from Expr) Expr { + return Call{ + Name: "bytesToString", + Arg: from, + Out: types.NewSlice(basicType("byte")), + } + }, + Helper: ` + func bytesToString(v []byte) (string, error) { + return string(v), nil + } + `, + } + }, + // string -> []byte + func(from types.Type) *Transformation { + if !types.Identical(from, basicType("string")) { + return nil + } + return &Transformation{ + To: types.NewSlice(basicType("byte")), + Transform: func(from Expr) Expr { + return Call{ + Name: "stringToBytes", + Arg: from, + Out: basicType("string"), + } + }, + Helper: ` + func stringToBytes(v string) ([]byte, error) { + return []byte(v), nil + } + `, + } + }, + // encoding.BinaryMarshaler -> []byte + func(from types.Type) *Transformation { + if !implements(from, binaryMarshaler) { + return nil + } + return &Transformation{ + To: types.NewSlice(basicType("byte")), + Transform: func(from Expr) Expr { + return Call{ + Name: "marshalBinary", + Arg: from, + Out: types.NewSlice(basicType("byte")), + } + }, + Imports: []string{ + "encoding", + "github.com/alecthomas/types/result", + }, + Helper: ` + func marshalBinary(v encoding.BinaryMarshaler) ([]byte, error) { + return return v.MarshalBinary() + } + `, + } + }, + // optional.Option[T] -> T + func(from types.Type) *Transformation { + named, ok := from.(*types.Named) + if !ok { + return nil + } + obj := named.Obj() + if obj.Pkg().Path() != "github.com/alecthomas/types/optional" && obj.Name() == "Optional" { + return nil + } + return &Transformation{ + To: named.TypeArgs().At(0), + Transform: func(from Expr) Expr { + return Call{ + Name: "unwrapOptional", + Arg: from, + Out: from.Type().(*types.Named).TypeParams().At(0), + } + }, + Helper: ` + func unwrapOptional[T any](v optional.Option[T]) (T, error) { + out, _ := v.Get() + return out, nil + } + `, + } + }, +} + +func loadInterface(pkg, symbol string) *types.Interface { + name := loadObject(pkg, symbol) + if t, ok := name.(*types.TypeName); ok { + if t.Name() == symbol { + return t.Type().Underlying().(*types.Interface) //nolint:forcetypeassert + } + } + panic("could not find " + pkg + "." + symbol) +} + +func loadObject(pkgName, symbol string) types.Object { + pkgs, err := packages.Load(&packages.Config{ + Fset: fset, + Mode: packages.NeedTypes | packages.NeedTypesInfo | packages.NeedDeps | packages.NeedImports | packages.NeedSyntax | + packages.NeedFiles | packages.NeedName, + }, pkgName) + if err != nil { + panic(err) + } + for _, pkg := range pkgs { + obj := pkg.Types.Scope().Lookup(symbol) + if obj != nil { + return obj + } + } + panic("could not find " + pkgName + "." + symbol) +} + +func implements(t types.Type, i *types.Interface) bool { + return types.Implements(t, i) || types.Implements(types.NewPointer(t), i) +} diff --git a/cmd/go2proto/transform/transform_test.go b/cmd/go2proto/transform/transform_test.go new file mode 100644 index 0000000000..620bd50814 --- /dev/null +++ b/cmd/go2proto/transform/transform_test.go @@ -0,0 +1,64 @@ +package transform + +import ( + "go/types" + "testing" + + "github.com/alecthomas/assert/v2" + "github.com/alecthomas/types/must" +) + +func TestTransform(t *testing.T) { + var ( + tctx = types.NewContext() + stringType = basicType("string") + bytesType = types.NewSlice(basicType("byte")) + urlType = loadObject("net/url", "URL").(*types.TypeName).Type() + optionalType = loadObject("github.com/alecthomas/types/optional", "Option").(*types.TypeName).Type() + optionalStringType = must.Get(types.Instantiate(tctx, optionalType, []types.Type{stringType}, false)) + optionalBytesType = must.Get(types.Instantiate(tctx, optionalType, []types.Type{bytesType}, false)) + ) + tests := []struct { + name string + input types.Type + output types.Type + expected string + ok bool + }{ + {"StringToBytes", + stringType, + bytesType, + `result.Map(result.Ok(input), stringToBytes)`, + true}, + {"MarshalBinary", + urlType, + bytesType, + `result.Map(result.Ok(input), marshalBinary)`, + true}, + {"ToOptional", + optionalStringType, + stringType, + "result.Map(result.Ok(input), unwrapOptional)", + true}, + {"OptionalStringToBytes", + optionalStringType, + bytesType, + "result.Map(result.Map(result.Ok(input), unwrapOptional), stringToBytes)", + true}, + {"OptionalBytesToString", + optionalBytesType, + stringType, + "result.Map(result.Map(result.Ok(input), unwrapOptional), bytesToString)", + true}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + input := Var("input", test.input) + out, ok := Transform(input, test.output) + assert.Equal(t, test.ok, ok, "could not find transform from %s to %s", test.input, test.output) + if ok { + assert.Equal(t, test.expected, out.Expr()) + } + }) + } +} diff --git a/cmd/go2proto/transform/type.go b/cmd/go2proto/transform/type.go new file mode 100644 index 0000000000..5795c7ae05 --- /dev/null +++ b/cmd/go2proto/transform/type.go @@ -0,0 +1,61 @@ +package transform + +import ( + "reflect" + "slices" +) + +// Type is similar to reflect.Type but handles both go/types.Type and "synthetic" types that are created for the +// purposes of matching transformations. +type Type interface { + Kind() reflect.Kind + // Only relevant for map types. + Key() Type + // For types this is the element type. + Elem() Type + // Size of array. + Size() int + Methods() []Function + Fields() []Field +} + +type Function struct { + Name string + In []Type + Out []Type +} + +type Field struct { + Name string + Type Type +} + +// Identical returns true if a and b are identical types. +func Identical(a, b Type) bool { + if a.Kind() != b.Kind() { + return false + } + switch a.Kind() { + case reflect.Array: + return a.Size() == b.Size() && Identical(a.Elem(), b.Elem()) + + case reflect.Slice, reflect.Chan, reflect.Pointer: + return Identical(a.Elem(), b.Elem()) + + case reflect.Struct: + return slices.EqualFunc(a.Fields(), b.Fields(), func(af, bf Field) bool { + return af.Name == bf.Name && Identical(af.Type, bf.Type) + }) + + case reflect.Interface: + return slices.EqualFunc(a.Methods(), b.Methods(), func(am, bm Function) bool { + return am.Name == bm.Name && slices.EqualFunc(am.In, bm.In, Identical) && slices.EqualFunc(am.Out, bm.Out, Identical) + }) + + case reflect.Map: + return Identical(a.Key(), b.Key()) && Identical(a.Elem(), b.Elem()) + + default: + return true + } +} diff --git a/cmd/go2proto/transform/type_test.go b/cmd/go2proto/transform/type_test.go new file mode 100644 index 0000000000..b863316fe1 --- /dev/null +++ b/cmd/go2proto/transform/type_test.go @@ -0,0 +1,24 @@ +package transform + +import ( + "fmt" + "go/types" + "testing" +) + +func TestType(t *testing.T) { + t.Run("GoTypes", func(t *testing.T) { + obj := loadObject("github.com/block/ftl/cmd/go2proto/testdata", "Message").(*types.TypeName).Type() + fmt.Printf("%T\n", obj) + testType(t, FromGoTypes(obj)) + }) +} + +func testType(t *testing.T, typ Type) { + for _, method := range typ.Methods() { + fmt.Println(method.Name) + } + for _, field := range typ.Fields() { + fmt.Println("field", field.Name, field.Type) + } +} diff --git a/cmd/go2proto/transform/typesgo.go b/cmd/go2proto/transform/typesgo.go new file mode 100644 index 0000000000..680a64107f --- /dev/null +++ b/cmd/go2proto/transform/typesgo.go @@ -0,0 +1,180 @@ +package transform + +import ( + "fmt" + "go/types" + "reflect" +) + +// FromGoTypes creates an adapter from [go/types.Type] to [Type]. +func FromGoTypes(t types.Type) Type { + return &goType{t} +} + +var _ Type = (*goType)(nil) + +type goType struct { + t types.Type +} + +func (g *goType) Elem() Type { + switch t := g.t.(type) { + case *types.Pointer: + return &goType{t.Elem()} + case *types.Array: + return &goType{t.Elem()} + case *types.Slice: + return &goType{t.Elem()} + case *types.Map: + return &goType{t.Elem()} + case *types.Chan: + return &goType{t.Elem()} + default: + panic(fmt.Sprintf("type %T does not have an element type", g.t)) + } +} + +func (g *goType) Fields() []Field { + switch t := g.t.(type) { + case *types.Struct: + fields := make([]Field, t.NumFields()) + for i := range t.NumFields() { + field := t.Field(i) + fields[i] = Field{ + Name: field.Name(), + Type: &goType{field.Type()}, + } + } + return fields + case *types.Named: + if s, ok := t.Underlying().(*types.Struct); ok { + fields := make([]Field, s.NumFields()) + for i := range s.NumFields() { + field := s.Field(i) + fields[i] = Field{ + Name: field.Name(), + Type: &goType{field.Type()}, + } + } + return fields + } + return nil + default: + panic("type does not have fields") + } +} + +func (g *goType) Key() Type { + switch t := g.t.(type) { + case *types.Map: + return &goType{t.Key()} + default: + panic("type does not have a key type") + } +} + +func (g *goType) Kind() reflect.Kind { + return typesKind(g.t) +} + +func (g *goType) Methods() []Function { + return append(goTypeMethod(g.t), goTypeMethod(types.NewPointer(g.t))...) +} + +func (g *goType) String() string { return g.Kind().String() } + +func (g *goType) Size() int { + a, ok := g.t.(*types.Array) + if !ok { + panic("not an array") + } + return int(a.Len()) +} + +func typesKind(t types.Type) reflect.Kind { + switch t := t.(type) { + case *types.Struct: + return reflect.Struct + + case *types.Interface: + return reflect.Interface + + case *types.Array: + return reflect.Array + + case *types.Slice: + return reflect.Slice + + case *types.Chan: + return reflect.Chan + + case *types.Pointer: + return reflect.Pointer + + case *types.Map: + return reflect.Map + + case *types.Signature: + return reflect.Func + + case *types.Basic: + return typesToReflectKind[t.Kind()] + + case *types.Named: + return typesKind(t.Underlying()) + + default: + panic("unimplemented") + } +} + +func goTypeMethod(t types.Type) []Function { + methods := types.NewMethodSet(types.NewPointer(t)) + out := make([]Function, methods.Len()) + for i := range methods.Len() { + method := methods.At(i) + funcObj := method.Obj().(*types.Func) + sig := funcObj.Type().(*types.Signature) + + // Convert parameters + params := sig.Params() + inParams := make([]Type, params.Len()) + for j := range params.Len() { + inParams[j] = &goType{params.At(j).Type()} + } + + // Convert results + results := sig.Results() + outParams := make([]Type, results.Len()) + for j := range results.Len() { + outParams[j] = &goType{results.At(j).Type()} + } + + out[i] = Function{ + Name: funcObj.Name(), + In: inParams, + Out: outParams, + } + } + return out +} + +var typesToReflectKind = []reflect.Kind{ + types.Bool: reflect.Bool, + types.Int: reflect.Int, + types.Int8: reflect.Int8, + types.Int16: reflect.Int16, + types.Int32: reflect.Int32, + types.Int64: reflect.Int64, + types.Uint: reflect.Uint, + types.Uint8: reflect.Uint8, + types.Uint16: reflect.Uint16, + types.Uint32: reflect.Uint32, + types.Uint64: reflect.Uint64, + types.Float32: reflect.Float32, + types.Float64: reflect.Float64, + types.Complex64: reflect.Complex64, + types.Complex128: reflect.Complex128, + types.String: reflect.String, + types.UnsafePointer: reflect.UnsafePointer, +}