diff --git a/NewType.Generator/AliasCodeGenerator.cs b/NewType.Generator/AliasCodeGenerator.cs index ce4fefb..73a04f9 100644 --- a/NewType.Generator/AliasCodeGenerator.cs +++ b/NewType.Generator/AliasCodeGenerator.cs @@ -1,6 +1,4 @@ using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; -using Microsoft.CodeAnalysis.CSharp.Syntax; using System.Text; namespace newtype.generator; @@ -10,34 +8,12 @@ namespace newtype.generator; /// internal class AliasCodeGenerator { - // ReSharper disable once NotAccessedField.Local - private readonly Compilation _compilation; - private readonly AliasInfo _alias; + private readonly AliasModel _model; private readonly StringBuilder _sb = new(); - private readonly string _typeName; - private readonly string _aliasedTypeName; - private readonly string _aliasedTypeFullName; - private readonly string _namespace; - private readonly bool _isReadonly; - private readonly bool _isClass; - private readonly bool _isRecord; - private readonly bool _isRecordStruct; - - public AliasCodeGenerator(Compilation compilation, AliasInfo alias) + + public AliasCodeGenerator(AliasModel model) { - _compilation = compilation; - _alias = alias; - _typeName = alias.TypeSymbol.Name; - _aliasedTypeName = alias.AliasedType.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat); - _aliasedTypeFullName = alias.AliasedType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - var ns = alias.TypeSymbol.ContainingNamespace; - _namespace = ns is {IsGlobalNamespace: false} ? ns.ToDisplayString() : ""; - _isReadonly = alias.TypeDeclaration.Modifiers.Any(SyntaxKind.ReadOnlyKeyword); - _isClass = alias.TypeDeclaration is ClassDeclarationSyntax - || (alias.TypeDeclaration is RecordDeclarationSyntax rds - && !rds.ClassOrStructKeyword.IsKind(SyntaxKind.StructKeyword)); - _isRecord = alias.TypeDeclaration is RecordDeclarationSyntax; - _isRecordStruct = _isRecord && !_isClass; + _model = model; } public string Generate() @@ -59,7 +35,7 @@ public string Generate() AppendStaticMembers(); AppendInstanceMembers(); AppendToString(); - if (!_isRecord) + if (!_model.IsRecord) AppendGetHashCode(); AppendTypeClose(); @@ -80,16 +56,16 @@ private void AppendHeader() private void AppendNamespaceOpen() { - if (!string.IsNullOrEmpty(_namespace)) + if (!string.IsNullOrEmpty(_model.Namespace)) { - _sb.AppendLine($"namespace {_namespace}"); + _sb.AppendLine($"namespace {_model.Namespace}"); _sb.AppendLine("{"); } } private void AppendNamespaceClose() { - if (!string.IsNullOrEmpty(_namespace)) + if (!string.IsNullOrEmpty(_model.Namespace)) { _sb.AppendLine("}"); } @@ -97,9 +73,9 @@ private void AppendNamespaceClose() private void AppendTypeOpen() { - var indent = string.IsNullOrEmpty(_namespace) ? "" : " "; - var readonlyMod = (_isReadonly && !_isClass) ? "readonly " : ""; - var accessMod = _alias.TypeSymbol.DeclaredAccessibility switch + var indent = string.IsNullOrEmpty(_model.Namespace) ? "" : " "; + var readonlyMod = (_model.IsReadonly && !_model.IsClass) ? "readonly " : ""; + var accessMod = _model.DeclaredAccessibility switch { Accessibility.Public => "public ", Accessibility.Internal => "internal ", @@ -113,44 +89,42 @@ private void AppendTypeOpen() // Build interface list var interfaces = new List { - $"global::System.IEquatable<{_typeName}>" + $"global::System.IEquatable<{_model.TypeName}>" }; - // Check if aliased type implements IComparable - if (ImplementsInterface(_alias.AliasedType, "System.IComparable`1")) + if (_model.ImplementsIComparable) { - interfaces.Add($"global::System.IComparable<{_typeName}>"); + interfaces.Add($"global::System.IComparable<{_model.TypeName}>"); } - // Check for IFormattable - if (ImplementsInterface(_alias.AliasedType, "System.IFormattable")) + if (_model.ImplementsIFormattable) { interfaces.Add("global::System.IFormattable"); } var interfaceList = string.Join(", ", interfaces); - var typeKeyword = (_isClass, _isRecord) switch + var typeKeyword = (_model.IsClass, _model.IsRecord) switch { (false, false) => "struct", (false, true) => "record struct", (true, false) => "class", (true, true) => "record class", }; - _sb.AppendLine($"{indent}{accessMod}{readonlyMod}partial {typeKeyword} {_typeName} : {interfaceList}"); + _sb.AppendLine($"{indent}{accessMod}{readonlyMod}partial {typeKeyword} {_model.TypeName} : {interfaceList}"); _sb.AppendLine($"{indent}{{"); } private void AppendTypeClose() { - var indent = string.IsNullOrEmpty(_namespace) ? "" : " "; + var indent = string.IsNullOrEmpty(_model.Namespace) ? "" : " "; _sb.AppendLine($"{indent}}}"); } private void AppendField() { var indent = GetMemberIndent(); - _sb.AppendLine($"{indent}private readonly {_aliasedTypeFullName} _value;"); + _sb.AppendLine($"{indent}private readonly {_model.AliasedTypeFullName} _value;"); _sb.AppendLine(); } @@ -159,14 +133,13 @@ private void AppendConstructors() var indent = GetMemberIndent(); // Constructor from aliased type - _sb.AppendLine($"{indent}/// Creates a new {_typeName} from a {_aliasedTypeName}."); + _sb.AppendLine($"{indent}/// Creates a new {_model.TypeName} from a {_model.AliasedTypeMinimalName}."); _sb.AppendLine($"{indent}[MethodImpl(MethodImplOptions.AggressiveInlining)]"); - _sb.AppendLine($"{indent}public {_typeName}({_aliasedTypeFullName} value) => _value = value;"); + _sb.AppendLine($"{indent}public {_model.TypeName}({_model.AliasedTypeFullName} value) => _value = value;"); _sb.AppendLine(); // Forward constructors from the aliased type - var forwardable = GetForwardableConstructors(); - foreach (var ctor in forwardable) + foreach (var ctor in _model.ForwardedConstructors) { AppendForwardedConstructor(indent, ctor); } @@ -176,7 +149,7 @@ private void AppendValueProperty() { var indent = GetMemberIndent(); _sb.AppendLine($"{indent}/// Gets the underlying value."); - _sb.AppendLine($"{indent}public {_aliasedTypeFullName} Value"); + _sb.AppendLine($"{indent}public {_model.AliasedTypeFullName} Value"); _sb.AppendLine($"{indent}{{"); _sb.AppendLine($"{indent} [MethodImpl(MethodImplOptions.AggressiveInlining)]"); _sb.AppendLine($"{indent} get => _value;"); @@ -189,25 +162,24 @@ private void AppendImplicitOperators() var indent = GetMemberIndent(); // Implicit from aliased type to alias - _sb.AppendLine($"{indent}/// Implicitly converts from {_aliasedTypeName} to {_typeName}."); + _sb.AppendLine($"{indent}/// Implicitly converts from {_model.AliasedTypeMinimalName} to {_model.TypeName}."); _sb.AppendLine($"{indent}[MethodImpl(MethodImplOptions.AggressiveInlining)]"); - _sb.AppendLine($"{indent}public static implicit operator {_typeName}({_aliasedTypeFullName} value) => new {_typeName}(value);"); + _sb.AppendLine($"{indent}public static implicit operator {_model.TypeName}({_model.AliasedTypeFullName} value) => new {_model.TypeName}(value);"); _sb.AppendLine(); // Implicit from alias to aliased type - _sb.AppendLine($"{indent}/// Implicitly converts from {_typeName} to {_aliasedTypeName}."); + _sb.AppendLine($"{indent}/// Implicitly converts from {_model.TypeName} to {_model.AliasedTypeMinimalName}."); _sb.AppendLine($"{indent}[MethodImpl(MethodImplOptions.AggressiveInlining)]"); - _sb.AppendLine($"{indent}public static implicit operator {_aliasedTypeFullName}({_typeName} value) => value._value;"); + _sb.AppendLine($"{indent}public static implicit operator {_model.AliasedTypeFullName}({_model.TypeName} value) => value._value;"); _sb.AppendLine(); } private void AppendBinaryOperators() { var indent = GetMemberIndent(); - var operators = GetBinaryOperators(_alias.AliasedType); var emittedOps = new HashSet(); - foreach (var op in operators) + foreach (var op in _model.BinaryOperators) { var opSymbol = GetOperatorSymbol(op.Name); if (opSymbol == null) continue; @@ -217,59 +189,52 @@ private void AppendBinaryOperators() emittedOps.Add(op.Name); - // Get parameter types - var leftType = op.Parameters[0].Type; - var rightType = op.Parameters[1].Type; - var returnType = op.ReturnType; - // Generate operator with both sides as our type (if both params are the aliased type) - if (SymbolEqualityComparer.Default.Equals(leftType, _alias.AliasedType) && - SymbolEqualityComparer.Default.Equals(rightType, _alias.AliasedType)) + if (op.LeftIsAliasedType && op.RightIsAliasedType) { - var returnTypeStr = SymbolEqualityComparer.Default.Equals(returnType, _alias.AliasedType) - ? _typeName - : returnType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var returnTypeStr = op.ReturnIsAliasedType + ? _model.TypeName + : op.ReturnTypeFullName; _sb.AppendLine($"{indent}[MethodImpl(MethodImplOptions.AggressiveInlining)]"); - _sb.AppendLine($"{indent}public static {returnTypeStr} operator {opSymbol}({_typeName} left, {_typeName} right) => left._value {opSymbol} right._value;"); + _sb.AppendLine($"{indent}public static {returnTypeStr} operator {opSymbol}({_model.TypeName} left, {_model.TypeName} right) => left._value {opSymbol} right._value;"); _sb.AppendLine(); // Also generate alias op T for cross-type interop - // e.g. Position + Velocity works via Position.+(Position, Vector3) with Velocity→Vector3 + // Note: T op alias is intentionally omitted here — when multiple aliases share + // the same underlying type T, emitting both directions creates ambiguous overloads + // (e.g. Position + Velocity would match both Vector3+Position and Velocity+Vector3). + // The implicit conversion to T already covers the T op alias case. _sb.AppendLine($"{indent}[MethodImpl(MethodImplOptions.AggressiveInlining)]"); - _sb.AppendLine($"{indent}public static {returnTypeStr} operator {opSymbol}({_typeName} left, {_aliasedTypeFullName} right) => left._value {opSymbol} right;"); + _sb.AppendLine($"{indent}public static {returnTypeStr} operator {opSymbol}({_model.TypeName} left, {_model.AliasedTypeFullName} right) => left._value {opSymbol} right;"); _sb.AppendLine(); } - // Operator with aliased type on right - else if (SymbolEqualityComparer.Default.Equals(leftType, _alias.AliasedType) && - !SymbolEqualityComparer.Default.Equals(rightType, _alias.AliasedType)) + // Operator with aliased type on left only — also emit T op Alias + else if (op.LeftIsAliasedType && !op.RightIsAliasedType) { - var rightTypeStr = rightType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - var returnTypeStr = SymbolEqualityComparer.Default.Equals(returnType, _alias.AliasedType) - ? _typeName - : returnType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var returnTypeStr = op.ReturnIsAliasedType + ? _model.TypeName + : op.ReturnTypeFullName; _sb.AppendLine($"{indent}[MethodImpl(MethodImplOptions.AggressiveInlining)]"); - _sb.AppendLine($"{indent}public static {returnTypeStr} operator {opSymbol}({_typeName} left, {rightTypeStr} right) => left._value {opSymbol} right;"); + _sb.AppendLine($"{indent}public static {returnTypeStr} operator {opSymbol}({_model.TypeName} left, {op.RightTypeFullName} right) => left._value {opSymbol} right;"); _sb.AppendLine(); } - // Operator with aliased type on left - else if (!SymbolEqualityComparer.Default.Equals(leftType, _alias.AliasedType) && - SymbolEqualityComparer.Default.Equals(rightType, _alias.AliasedType)) + // Operator with aliased type on right only — also emit Alias op T + else if (!op.LeftIsAliasedType && op.RightIsAliasedType) { - var leftTypeStr = leftType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - var returnTypeStr = SymbolEqualityComparer.Default.Equals(returnType, _alias.AliasedType) - ? _typeName - : returnType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var returnTypeStr = op.ReturnIsAliasedType + ? _model.TypeName + : op.ReturnTypeFullName; _sb.AppendLine($"{indent}[MethodImpl(MethodImplOptions.AggressiveInlining)]"); - _sb.AppendLine($"{indent}public static {returnTypeStr} operator {opSymbol}({leftTypeStr} left, {_typeName} right) => left {opSymbol} right._value;"); + _sb.AppendLine($"{indent}public static {returnTypeStr} operator {opSymbol}({op.LeftTypeFullName} left, {_model.TypeName} right) => left {opSymbol} right._value;"); _sb.AppendLine(); } } // Emit built-in operators for primitive types not discovered via UserDefinedOperator - var builtInOps = GetBuiltInBinaryOperatorNames(_alias.AliasedType.SpecialType); + var builtInOps = GetBuiltInBinaryOperatorNames(_model.AliasedTypeSpecialType); foreach (var opName in builtInOps) { if (emittedOps.Contains(opName)) continue; @@ -280,26 +245,25 @@ private void AppendBinaryOperators() if (IsShiftOperator(opName)) { - // Shift operators: left is the type, right is always int _sb.AppendLine($"{indent}[MethodImpl(MethodImplOptions.AggressiveInlining)]"); - _sb.AppendLine($"{indent}public static {_typeName} operator {opSymbol}({_typeName} left, int right) => left._value {opSymbol} right;"); + _sb.AppendLine($"{indent}public static {_model.TypeName} operator {opSymbol}({_model.TypeName} left, int right) => left._value {opSymbol} right;"); _sb.AppendLine(); } else { // Alias op Alias _sb.AppendLine($"{indent}[MethodImpl(MethodImplOptions.AggressiveInlining)]"); - _sb.AppendLine($"{indent}public static {_typeName} operator {opSymbol}({_typeName} left, {_typeName} right) => left._value {opSymbol} right._value;"); + _sb.AppendLine($"{indent}public static {_model.TypeName} operator {opSymbol}({_model.TypeName} left, {_model.TypeName} right) => left._value {opSymbol} right._value;"); _sb.AppendLine(); // Alias op T _sb.AppendLine($"{indent}[MethodImpl(MethodImplOptions.AggressiveInlining)]"); - _sb.AppendLine($"{indent}public static {_typeName} operator {opSymbol}({_typeName} left, {_aliasedTypeFullName} right) => left._value {opSymbol} right;"); + _sb.AppendLine($"{indent}public static {_model.TypeName} operator {opSymbol}({_model.TypeName} left, {_model.AliasedTypeFullName} right) => left._value {opSymbol} right;"); _sb.AppendLine(); // T op Alias _sb.AppendLine($"{indent}[MethodImpl(MethodImplOptions.AggressiveInlining)]"); - _sb.AppendLine($"{indent}public static {_typeName} operator {opSymbol}({_aliasedTypeFullName} left, {_typeName} right) => left {opSymbol} right._value;"); + _sb.AppendLine($"{indent}public static {_model.TypeName} operator {opSymbol}({_model.AliasedTypeFullName} left, {_model.TypeName} right) => left {opSymbol} right._value;"); _sb.AppendLine(); } } @@ -308,28 +272,26 @@ private void AppendBinaryOperators() private void AppendUnaryOperators() { var indent = GetMemberIndent(); - var operators = GetUnaryOperators(_alias.AliasedType); var emittedOps = new HashSet(); - foreach (var op in operators) + foreach (var op in _model.UnaryOperators) { var opSymbol = GetOperatorSymbol(op.Name); if (opSymbol == null) continue; emittedOps.Add(op.Name); - var returnType = op.ReturnType; - var returnTypeStr = SymbolEqualityComparer.Default.Equals(returnType, _alias.AliasedType) - ? _typeName - : returnType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var returnTypeStr = op.ReturnIsAliasedType + ? _model.TypeName + : op.ReturnTypeFullName; _sb.AppendLine($"{indent}[MethodImpl(MethodImplOptions.AggressiveInlining)]"); - _sb.AppendLine($"{indent}public static {returnTypeStr} operator {opSymbol}({_typeName} value) => {opSymbol}value._value;"); + _sb.AppendLine($"{indent}public static {returnTypeStr} operator {opSymbol}({_model.TypeName} value) => {opSymbol}value._value;"); _sb.AppendLine(); } // Emit built-in unary operators for primitive types - var builtInOps = GetBuiltInUnaryOperatorNames(_alias.AliasedType.SpecialType); + var builtInOps = GetBuiltInUnaryOperatorNames(_model.AliasedTypeSpecialType); foreach (var opName in builtInOps) { if (emittedOps.Contains(opName)) continue; @@ -338,7 +300,7 @@ private void AppendUnaryOperators() if (opSymbol == null) continue; _sb.AppendLine($"{indent}[MethodImpl(MethodImplOptions.AggressiveInlining)]"); - _sb.AppendLine($"{indent}public static {_typeName} operator {opSymbol}({_typeName} value) => {opSymbol}value._value;"); + _sb.AppendLine($"{indent}public static {_model.TypeName} operator {opSymbol}({_model.TypeName} value) => {opSymbol}value._value;"); _sb.AppendLine(); } } @@ -348,7 +310,15 @@ private void AppendComparisonOperators() var indent = GetMemberIndent(); // Check if the type already has comparison operators (user-defined) - var hasLessThan = GetBinaryOperators(_alias.AliasedType).Any(o => o.Name == "op_LessThan"); + var hasLessThan = false; + foreach (var op in _model.BinaryOperators) + { + if (op.Name == "op_LessThan") + { + hasLessThan = true; + break; + } + } if (hasLessThan) { @@ -356,24 +326,24 @@ private void AppendComparisonOperators() return; } - // For primitives with built-in comparison, use direct operators (avoids CompareTo overhead) - if (HasBuiltInComparisonOperators(_alias.AliasedType.SpecialType)) + // For primitives with built-in comparison, use direct operators + if (HasBuiltInComparisonOperators(_model.AliasedTypeSpecialType)) { string[] ops = ["<", ">", "<=", ">="]; foreach (var op in ops) { _sb.AppendLine($"{indent}[MethodImpl(MethodImplOptions.AggressiveInlining)]"); - _sb.AppendLine($"{indent}public static bool operator {op}({_typeName} left, {_typeName} right) => left._value {op} right._value;"); + _sb.AppendLine($"{indent}public static bool operator {op}({_model.TypeName} left, {_model.TypeName} right) => left._value {op} right._value;"); _sb.AppendLine(); } return; } - // Fallback: use IComparable.CompareTo for types without native comparison - if (ImplementsInterface(_alias.AliasedType, "System.IComparable`1")) + // Fallback: use IComparable.CompareTo + if (_model.ImplementsIComparable) { - var isRefType = !_alias.AliasedType.IsValueType; + var isRefType = !_model.AliasedTypeIsValueType; string CompareExpr(string op) => isRefType ? $"(left._value is null ? (right._value is null ? 0 : -1) : left._value.CompareTo(right._value)) {op} 0" @@ -383,7 +353,7 @@ string CompareExpr(string op) => isRefType foreach (var op in ops) { _sb.AppendLine($"{indent}[MethodImpl(MethodImplOptions.AggressiveInlining)]"); - _sb.AppendLine($"{indent}public static bool operator {op}({_typeName} left, {_typeName} right) => {CompareExpr(op)};"); + _sb.AppendLine($"{indent}public static bool operator {op}({_model.TypeName} left, {_model.TypeName} right) => {CompareExpr(op)};"); _sb.AppendLine(); } } @@ -392,15 +362,15 @@ string CompareExpr(string op) => isRefType private void AppendEqualityMembers() { var indent = GetMemberIndent(); - var isRefType = !_alias.AliasedType.IsValueType; - var nullableParam = _isClass ? "?" : ""; + var isRefType = !_model.AliasedTypeIsValueType; + var nullableParam = _model.IsClass ? "?" : ""; // records synthesize Equals, ==, != — skip to avoid CS0111 - if (!_isRecord) + if (!_model.IsRecord) { // IEquatable.Equals string equalsExpr; - if (_isClass) + if (_model.IsClass) equalsExpr = isRefType ? "other is not null && object.Equals(_value, other._value)" : "other is not null && _value.Equals(other._value)"; @@ -410,58 +380,53 @@ private void AppendEqualityMembers() : "_value.Equals(other._value)"; _sb.AppendLine($"{indent}/// "); _sb.AppendLine($"{indent}[MethodImpl(MethodImplOptions.AggressiveInlining)]"); - _sb.AppendLine($"{indent}public bool Equals({_typeName}{nullableParam} other) => {equalsExpr};"); + _sb.AppendLine($"{indent}public bool Equals({_model.TypeName}{nullableParam} other) => {equalsExpr};"); _sb.AppendLine(); // Object.Equals override _sb.AppendLine($"{indent}/// "); - _sb.AppendLine($"{indent}public override bool Equals(object? obj) => obj is {_typeName} other && Equals(other);"); + _sb.AppendLine($"{indent}public override bool Equals(object? obj) => obj is {_model.TypeName} other && Equals(other);"); _sb.AppendLine(); - // == and != operators: delegate directly to the underlying type's operators - // when available, to preserve exact semantics (e.g. NaN handling for floats) - // and allow the JIT to generate identical codegen. - var hasNativeEquality = HasNativeEqualityOperator(_alias.AliasedType); - - if (_isClass) + if (_model.IsClass) { // Class types need null-safe equality operators _sb.AppendLine($"{indent}[MethodImpl(MethodImplOptions.AggressiveInlining)]"); - _sb.AppendLine($"{indent}public static bool operator ==({_typeName}? left, {_typeName}? right) => ReferenceEquals(left, right) || (left is not null && left.Equals(right));"); + _sb.AppendLine($"{indent}public static bool operator ==({_model.TypeName}? left, {_model.TypeName}? right) => ReferenceEquals(left, right) || (left is not null && left.Equals(right));"); _sb.AppendLine(); _sb.AppendLine($"{indent}[MethodImpl(MethodImplOptions.AggressiveInlining)]"); - _sb.AppendLine($"{indent}public static bool operator !=({_typeName}? left, {_typeName}? right) => !(left == right);"); + _sb.AppendLine($"{indent}public static bool operator !=({_model.TypeName}? left, {_model.TypeName}? right) => !(left == right);"); _sb.AppendLine(); } - else if (hasNativeEquality) + else if (_model.HasNativeEqualityOperator) { _sb.AppendLine($"{indent}[MethodImpl(MethodImplOptions.AggressiveInlining)]"); - _sb.AppendLine($"{indent}public static bool operator ==({_typeName} left, {_typeName} right) => left._value == right._value;"); + _sb.AppendLine($"{indent}public static bool operator ==({_model.TypeName} left, {_model.TypeName} right) => left._value == right._value;"); _sb.AppendLine(); _sb.AppendLine($"{indent}[MethodImpl(MethodImplOptions.AggressiveInlining)]"); - _sb.AppendLine($"{indent}public static bool operator !=({_typeName} left, {_typeName} right) => left._value != right._value;"); + _sb.AppendLine($"{indent}public static bool operator !=({_model.TypeName} left, {_model.TypeName} right) => left._value != right._value;"); _sb.AppendLine(); } else { - // Fallback: route through Equals (for types without native == operator) + // Fallback: route through Equals _sb.AppendLine($"{indent}[MethodImpl(MethodImplOptions.AggressiveInlining)]"); - _sb.AppendLine($"{indent}public static bool operator ==({_typeName} left, {_typeName} right) => left.Equals(right);"); + _sb.AppendLine($"{indent}public static bool operator ==({_model.TypeName} left, {_model.TypeName} right) => left.Equals(right);"); _sb.AppendLine(); _sb.AppendLine($"{indent}[MethodImpl(MethodImplOptions.AggressiveInlining)]"); - _sb.AppendLine($"{indent}public static bool operator !=({_typeName} left, {_typeName} right) => !left.Equals(right);"); + _sb.AppendLine($"{indent}public static bool operator !=({_model.TypeName} left, {_model.TypeName} right) => !left.Equals(right);"); _sb.AppendLine(); } } // IComparable.CompareTo if applicable - if (ImplementsInterface(_alias.AliasedType, "System.IComparable`1")) + if (_model.ImplementsIComparable) { string compareExpr; - if (_isClass) + if (_model.IsClass) compareExpr = isRefType ? "other is null ? 1 : (_value is null ? (other._value is null ? 0 : -1) : _value.CompareTo(other._value))" : "other is null ? 1 : _value.CompareTo(other._value)"; @@ -471,7 +436,7 @@ private void AppendEqualityMembers() : "_value.CompareTo(other._value)"; _sb.AppendLine($"{indent}/// "); _sb.AppendLine($"{indent}[MethodImpl(MethodImplOptions.AggressiveInlining)]"); - _sb.AppendLine($"{indent}public int CompareTo({_typeName}{nullableParam} other) => {compareExpr};"); + _sb.AppendLine($"{indent}public int CompareTo({_model.TypeName}{nullableParam} other) => {compareExpr};"); _sb.AppendLine(); } } @@ -479,55 +444,36 @@ private void AppendEqualityMembers() private void AppendStaticMembers() { var indent = GetMemberIndent(); - var aliasedType = _alias.AliasedType; - - // Get all public static properties and fields - var staticMembers = aliasedType.GetMembers() - .Where(m => m.IsStatic && m.DeclaredAccessibility == Accessibility.Public) - .Where(m => m is IPropertySymbol or IFieldSymbol) - .Where(m => !m.Name.StartsWith("op_")) // Skip operators - .ToList(); - if (staticMembers.Count == 0) return; + if (!_model.HasStaticMemberCandidates) return; _sb.AppendLine($"{indent}#region Static Members"); _sb.AppendLine(); - foreach (var member in staticMembers) + foreach (var member in _model.StaticMembers) { - if (member is IPropertySymbol prop) + var returnTypeStr = member.TypeIsAliasedType + ? _model.TypeName + : member.TypeFullName; + + var valueExpr = member.TypeIsAliasedType + ? $"new {_model.TypeName}({_model.AliasedTypeFullName}.{member.Name})" + : $"{_model.AliasedTypeFullName}.{member.Name}"; + + if (member.IsProperty) { - // Check if the property returns the aliased type - var returnTypeStr = SymbolEqualityComparer.Default.Equals(prop.Type, _alias.AliasedType) - ? _typeName - : prop.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - - var needsConversion = SymbolEqualityComparer.Default.Equals(prop.Type, _alias.AliasedType); - var valueExpr = needsConversion - ? $"new {_typeName}({_aliasedTypeFullName}.{prop.Name})" - : $"{_aliasedTypeFullName}.{prop.Name}"; - - _sb.AppendLine($"{indent}/// Forwards {_aliasedTypeName}.{prop.Name}."); - _sb.AppendLine($"{indent}public static {returnTypeStr} {prop.Name}"); + _sb.AppendLine($"{indent}/// Forwards {_model.AliasedTypeMinimalName}.{member.Name}."); + _sb.AppendLine($"{indent}public static {returnTypeStr} {member.Name}"); _sb.AppendLine($"{indent}{{"); _sb.AppendLine($"{indent} [MethodImpl(MethodImplOptions.AggressiveInlining)]"); _sb.AppendLine($"{indent} get => {valueExpr};"); _sb.AppendLine($"{indent}}}"); _sb.AppendLine(); } - else if (member is IFieldSymbol {IsReadOnly: true} field) + else if (member.IsReadonlyField) { - var returnTypeStr = SymbolEqualityComparer.Default.Equals(field.Type, _alias.AliasedType) - ? _typeName - : field.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - - var needsConversion = SymbolEqualityComparer.Default.Equals(field.Type, _alias.AliasedType); - var valueExpr = needsConversion - ? $"new {_typeName}({_aliasedTypeFullName}.{field.Name})" - : $"{_aliasedTypeFullName}.{field.Name}"; - - _sb.AppendLine($"{indent}/// Forwards {_aliasedTypeName}.{field.Name}."); - _sb.AppendLine($"{indent}public static {returnTypeStr} {field.Name} => {valueExpr};"); + _sb.AppendLine($"{indent}/// Forwards {_model.AliasedTypeMinimalName}.{member.Name}."); + _sb.AppendLine($"{indent}public static {returnTypeStr} {member.Name} => {valueExpr};"); _sb.AppendLine(); } } @@ -539,47 +485,25 @@ private void AppendStaticMembers() private void AppendInstanceMembers() { var indent = GetMemberIndent(); - var aliasedType = _alias.AliasedType; - - // Get instance fields (e.g. Vector3.X, Y, Z are fields, not properties) - var instanceFields = aliasedType.GetMembers() - .OfType() - .Where(f => !f.IsStatic && f is {IsImplicitlyDeclared: false, DeclaredAccessibility: Accessibility.Public}) - .ToList(); - - // Get instance properties - var instanceProps = aliasedType.GetMembers() - .OfType() - .Where(p => !p.IsStatic && p.DeclaredAccessibility == Accessibility.Public) - .Where(p => !p.IsIndexer) // Skip indexers - .ToList(); - - // Get instance methods (excluding special ones) - var instanceMethods = aliasedType.GetMembers() - .OfType() - .Where(m => !m.IsStatic && m.DeclaredAccessibility == Accessibility.Public) - .Where(m => m.MethodKind == MethodKind.Ordinary) - .Where(m => m.CanBeReferencedByName) // Skip compiler-synthesized names like $ - .Where(m => !m.Name.StartsWith("get_") && !m.Name.StartsWith("set_")) - .Where(m => m.Name != "GetHashCode" && m.Name != "Equals" && m.Name != "ToString" && m.Name != "CompareTo") - .ToList(); - - if (instanceFields.Count == 0 && instanceProps.Count == 0 && instanceMethods.Count == 0) return; + + if (_model.InstanceFields.Length == 0 && _model.InstanceProperties.Length == 0 && _model.InstanceMethods.Length == 0) + return; _sb.AppendLine($"{indent}#region Instance Members"); _sb.AppendLine(); // Generate field forwarders (exposed as readonly properties) - foreach (var field in instanceFields) + foreach (var field in _model.InstanceFields) { - var returnTypeStr = SymbolEqualityComparer.Default.Equals(field.Type, _alias.AliasedType) - ? _typeName - : field.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var returnTypeStr = field.TypeIsAliasedType + ? _model.TypeName + : field.TypeFullName; - var needsConversion = SymbolEqualityComparer.Default.Equals(field.Type, _alias.AliasedType); - var valueExpr = needsConversion ? $"new {_typeName}(_value.{field.Name})" : $"_value.{field.Name}"; + var valueExpr = field.TypeIsAliasedType + ? $"new {_model.TypeName}(_value.{field.Name})" + : $"_value.{field.Name}"; - _sb.AppendLine($"{indent}/// Forwards {_aliasedTypeName}.{field.Name}."); + _sb.AppendLine($"{indent}/// Forwards {_model.AliasedTypeMinimalName}.{field.Name}."); _sb.AppendLine($"{indent}public {returnTypeStr} {field.Name}"); _sb.AppendLine($"{indent}{{"); _sb.AppendLine($"{indent} [MethodImpl(MethodImplOptions.AggressiveInlining)]"); @@ -589,20 +513,21 @@ private void AppendInstanceMembers() } // Generate property forwarders - foreach (var prop in instanceProps) + foreach (var prop in _model.InstanceProperties) { - var returnTypeStr = SymbolEqualityComparer.Default.Equals(prop.Type, _alias.AliasedType) - ? _typeName - : prop.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var returnTypeStr = prop.TypeIsAliasedType + ? _model.TypeName + : prop.TypeFullName; - _sb.AppendLine($"{indent}/// Forwards {_aliasedTypeName}.{prop.Name}."); + _sb.AppendLine($"{indent}/// Forwards {_model.AliasedTypeMinimalName}.{prop.Name}."); _sb.AppendLine($"{indent}public {returnTypeStr} {prop.Name}"); _sb.AppendLine($"{indent}{{"); - if (prop.GetMethod != null) + if (prop.HasGetter) { - var needsConversion = SymbolEqualityComparer.Default.Equals(prop.Type, _alias.AliasedType); - var valueExpr = needsConversion ? $"new {_typeName}(_value.{prop.Name})" : $"_value.{prop.Name}"; + var valueExpr = prop.TypeIsAliasedType + ? $"new {_model.TypeName}(_value.{prop.Name})" + : $"_value.{prop.Name}"; _sb.AppendLine($"{indent} [MethodImpl(MethodImplOptions.AggressiveInlining)]"); _sb.AppendLine($"{indent} get => {valueExpr};"); } @@ -612,24 +537,20 @@ private void AppendInstanceMembers() } // Generate method forwarders - foreach (var method in instanceMethods) + foreach (var method in _model.InstanceMethods) { - var skipReturnWrapping = _alias.AliasedType.IsValueType && - _alias.AliasedType.SpecialType != SpecialType.None; var returnTypeStr = method.ReturnsVoid ? "void" - : (SymbolEqualityComparer.Default.Equals(method.ReturnType, _alias.AliasedType) && !skipReturnWrapping) - ? _typeName - : method.ReturnType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + : (method.ReturnIsAliasedType && !method.SkipReturnWrapping) + ? _model.TypeName + : method.ReturnTypeFullName; - var parameters = string.Join(", ", method.Parameters.Select(p => + var parameters = string.Join(", ", method.Parameters.Array.Select(p => { - var isAliasedType = SymbolEqualityComparer.Default.Equals(p.Type, _alias.AliasedType); // out parameters of the aliased type must keep the underlying type - // to avoid CS0192 (cannot pass readonly field as ref/out) - var paramType = (isAliasedType && p.RefKind != RefKind.Out) - ? _typeName - : p.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var paramType = (p.IsAliasedType && p.RefKind != RefKind.Out) + ? _model.TypeName + : p.TypeFullName; var refKind = p.RefKind switch { RefKind.Ref => "ref ", @@ -640,7 +561,7 @@ private void AppendInstanceMembers() return $"{refKind}{paramType} {p.Name}"; })); - var arguments = string.Join(", ", method.Parameters.Select(p => + var arguments = string.Join(", ", method.Parameters.Array.Select(p => { var refKind = p.RefKind switch { @@ -649,21 +570,20 @@ private void AppendInstanceMembers() RefKind.In => "in ", _ => "" }; - var isAliasedType = SymbolEqualityComparer.Default.Equals(p.Type, _alias.AliasedType); // out parameters of the aliased type are passed directly (no ._value conversion) - var arg = (isAliasedType && p.RefKind != RefKind.Out) ? $"{p.Name}._value" : p.Name; + var arg = (p.IsAliasedType && p.RefKind != RefKind.Out) ? $"{p.Name}._value" : p.Name; return $"{refKind}{arg}"; })); - var needsReturnConversion = !method.ReturnsVoid && !skipReturnWrapping && - SymbolEqualityComparer.Default.Equals(method.ReturnType, _alias.AliasedType); + var needsReturnConversion = !method.ReturnsVoid && !method.SkipReturnWrapping && + method.ReturnIsAliasedType; var returnExpr = method.ReturnsVoid ? $"_value.{method.Name}({arguments})" : needsReturnConversion - ? $"new {_typeName}(_value.{method.Name}({arguments}))" + ? $"new {_model.TypeName}(_value.{method.Name}({arguments}))" : $"_value.{method.Name}({arguments})"; - _sb.AppendLine($"{indent}/// Forwards {_aliasedTypeName}.{method.Name}."); + _sb.AppendLine($"{indent}/// Forwards {_model.AliasedTypeMinimalName}.{method.Name}."); _sb.AppendLine($"{indent}[MethodImpl(MethodImplOptions.AggressiveInlining)]"); _sb.AppendLine( @@ -682,15 +602,14 @@ private void AppendInstanceMembers() private void AppendToString() { var indent = GetMemberIndent(); - var isRefType = !_alias.AliasedType.IsValueType; + var isRefType = !_model.AliasedTypeIsValueType; var toStringExpr = isRefType ? "_value?.ToString() ?? \"\"" : "_value.ToString()"; _sb.AppendLine($"{indent}/// "); _sb.AppendLine($"{indent}public override string ToString() => {toStringExpr};"); _sb.AppendLine(); - // Add IFormattable.ToString if the aliased type implements it - if (ImplementsInterface(_alias.AliasedType, "System.IFormattable")) + if (_model.ImplementsIFormattable) { _sb.AppendLine($"{indent}/// "); _sb.AppendLine($"{indent}public string ToString(string? format, IFormatProvider? formatProvider) => _value.ToString(format, formatProvider);"); @@ -701,7 +620,7 @@ private void AppendToString() private void AppendGetHashCode() { var indent = GetMemberIndent(); - var isRefType = !_alias.AliasedType.IsValueType; + var isRefType = !_model.AliasedTypeIsValueType; var hashExpr = isRefType ? "_value?.GetHashCode() ?? 0" : "_value.GetHashCode()"; _sb.AppendLine($"{indent}/// "); @@ -710,57 +629,21 @@ private void AppendGetHashCode() _sb.AppendLine(); } - /// - /// Discovers public instance constructors on the aliased type that should be forwarded. - /// Filters out: parameterless, single-param-of-aliased-type (already generated), - /// implicitly declared, and any that conflict with user-defined constructors on the partial struct. - /// - private IReadOnlyList GetForwardableConstructors() - { - var userSignatures = GetUserDefinedConstructorSignatures(); - - return _alias.AliasedType.GetMembers() - .OfType() - .Where(m => m.MethodKind == MethodKind.Constructor - && !m.IsStatic - && m.DeclaredAccessibility == Accessibility.Public - && !m.IsImplicitlyDeclared - && m.Parameters.Length > 0) - .Where(m => - { - // Skip copy constructor (single param of the aliased type itself) - if (m.Parameters.Length == 1 && - SymbolEqualityComparer.Default.Equals(m.Parameters[0].Type, _alias.AliasedType)) - return false; - - // Skip constructors with pointer parameters (require unsafe context) - if (m.Parameters.Any(p => p.Type.TypeKind == TypeKind.Pointer)) - return false; - - // Skip if user already defined a constructor with the same signature - var sig = GetConstructorSignature(m); - return !userSignatures.Contains(sig); - }) - .ToList(); - } - - private void AppendForwardedConstructor(string indent, IMethodSymbol ctor) + private void AppendForwardedConstructor(string indent, ConstructorInfo ctor) { var parameters = FormatConstructorParameters(ctor); var arguments = FormatConstructorArguments(ctor); - _sb.AppendLine($"{indent}/// Forwards {_aliasedTypeName} constructor."); + _sb.AppendLine($"{indent}/// Forwards {_model.AliasedTypeMinimalName} constructor."); _sb.AppendLine($"{indent}[MethodImpl(MethodImplOptions.AggressiveInlining)]"); - _sb.AppendLine($"{indent}public {_typeName}({parameters}) => _value = new {_aliasedTypeFullName}({arguments});"); + _sb.AppendLine($"{indent}public {_model.TypeName}({parameters}) => _value = new {_model.AliasedTypeFullName}({arguments});"); _sb.AppendLine(); } - private string FormatConstructorParameters(IMethodSymbol ctor) + private static string FormatConstructorParameters(ConstructorInfo ctor) { - return string.Join(", ", ctor.Parameters.Select(p => + return string.Join(", ", ctor.Parameters.Array.Select(p => { - var typeStr = p.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - var refKind = p.RefKind switch { RefKind.Ref => "ref ", @@ -771,17 +654,17 @@ private string FormatConstructorParameters(IMethodSymbol ctor) var paramsKeyword = p.IsParams ? "params " : ""; - var defaultValue = p.HasExplicitDefaultValue - ? $" = {FormatDefaultValue(p)}" + var defaultValue = p.DefaultValueLiteral != null + ? $" = {p.DefaultValueLiteral}" : ""; - return $"{paramsKeyword}{refKind}{typeStr} {p.Name}{defaultValue}"; + return $"{paramsKeyword}{refKind}{p.TypeFullName} {p.Name}{defaultValue}"; })); } - private static string FormatConstructorArguments(IMethodSymbol ctor) + private static string FormatConstructorArguments(ConstructorInfo ctor) { - return string.Join(", ", ctor.Parameters.Select(p => + return string.Join(", ", ctor.Parameters.Array.Select(p => { var refKind = p.RefKind switch { @@ -794,77 +677,7 @@ private static string FormatConstructorArguments(IMethodSymbol ctor) })); } - private static string FormatDefaultValue(IParameterSymbol param) - { - var value = param.ExplicitDefaultValue; - - if (value is null) - return "default"; - - if (value is string s) - return $"\"{s.Replace("\\", "\\\\").Replace("\"", "\\\"")}\""; - - if (value is char c) - return $"'{c}'"; - - if (value is bool b) - return b ? "true" : "false"; - - if (value is float f) - return f.ToString("R") + "f"; - - if (value is double d) - return d.ToString("R") + "d"; - - if (value is decimal m) - return m.ToString() + "m"; - - return value.ToString(); - } - - /// - /// Gets signatures of constructors the user has explicitly defined on the partial struct, - /// so we can avoid generating conflicting constructors. - /// - private HashSet GetUserDefinedConstructorSignatures() - { - var signatures = new HashSet(); - foreach (var member in _alias.TypeSymbol.GetMembers()) - { - if (member is IMethodSymbol {MethodKind: MethodKind.Constructor, IsImplicitlyDeclared: false} ctor) - { - signatures.Add(GetConstructorSignature(ctor)); - } - } - return signatures; - } - - /// - /// Creates a normalized signature string from a constructor's parameter types for comparison. - /// - private static string GetConstructorSignature(IMethodSymbol ctor) - { - return string.Join(",", ctor.Parameters.Select(p => - p.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))); - } - - private string GetMemberIndent() => string.IsNullOrEmpty(_namespace) ? " " : " "; - - private IEnumerable GetBinaryOperators(ITypeSymbol type) - { - return type.GetMembers() - .OfType() - .Where(m => m.MethodKind == MethodKind.UserDefinedOperator) - .Where(m => m.Parameters.Length == 2); - } - - private IEnumerable GetUnaryOperators(ITypeSymbol type) - { - return type.GetMembers() - .OfType() - .Where(m => m.MethodKind == MethodKind.UserDefinedOperator) - .Where(m => m.Parameters.Length == 1); - } + private string GetMemberIndent() => string.IsNullOrEmpty(_model.Namespace) ? " " : " "; private static string? GetOperatorSymbol(string operatorName) { @@ -898,7 +711,6 @@ private IEnumerable GetUnaryOperators(ITypeSymbol type) /// /// Returns the built-in binary operator metadata names for a primitive SpecialType. - /// Primitives use compiler-intrinsic operators not visible as UserDefinedOperator members. /// private static IReadOnlyList GetBuiltInBinaryOperatorNames(SpecialType specialType) { @@ -958,40 +770,11 @@ private static IReadOnlyList GetBuiltInUnaryOperatorNames(SpecialType sp } } - /// - /// Returns true if the given operator name is a shift operator (where the right operand is always int). - /// private static bool IsShiftOperator(string opName) { return opName == "op_LeftShift" || opName == "op_RightShift"; } - /// - /// Returns true if the type has a native == operator (user-defined or built-in primitive). - /// When true, operator == should delegate directly to _value == other._value - /// to preserve exact semantics (e.g. NaN handling) and generate optimal codegen. - /// - private bool HasNativeEqualityOperator(ITypeSymbol type) - { - // User-defined operator == - if (GetBinaryOperators(type).Any(o => o.Name == "op_Equality")) - return true; - - // Built-in == for primitives - return type.SpecialType is - SpecialType.System_Boolean or - SpecialType.System_Byte or SpecialType.System_SByte or - SpecialType.System_Int16 or SpecialType.System_UInt16 or - SpecialType.System_Int32 or SpecialType.System_UInt32 or - SpecialType.System_Int64 or SpecialType.System_UInt64 or - SpecialType.System_Single or SpecialType.System_Double or - SpecialType.System_Decimal or SpecialType.System_Char or - SpecialType.System_String; - } - - /// - /// Returns true if the primitive type has built-in comparison operators (<, >, etc.). - /// private static bool HasBuiltInComparisonOperators(SpecialType specialType) { return specialType is @@ -1002,17 +785,4 @@ SpecialType.System_Int64 or SpecialType.System_UInt64 or SpecialType.System_Single or SpecialType.System_Double or SpecialType.System_Decimal or SpecialType.System_Char; } - - private bool ImplementsInterface(ITypeSymbol type, string interfaceFullName) - { - // Handle generic interfaces like IComparable`1 - if (interfaceFullName.Contains("`")) - { - var baseName = interfaceFullName.Split('`')[0]; - return type.AllInterfaces.Any(i => - i.OriginalDefinition.ToDisplayString().StartsWith(baseName)); - } - - return type.AllInterfaces.Any(i => i.ToDisplayString() == interfaceFullName); - } -} \ No newline at end of file +} diff --git a/NewType.Generator/AliasGenerator.cs b/NewType.Generator/AliasGenerator.cs index c398056..8bed475 100644 --- a/NewType.Generator/AliasGenerator.cs +++ b/NewType.Generator/AliasGenerator.cs @@ -1,5 +1,4 @@ using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Text; using System.Text; @@ -8,6 +7,7 @@ namespace newtype.generator; /// /// Incremental source generator that creates type alias implementations. +/// Uses ForAttributeWithMetadataName for efficient attribute-based incremental generation. /// [Generator(LanguageNames.CSharp)] public class AliasGenerator : IIncrementalGenerator @@ -17,111 +17,65 @@ public void Initialize(IncrementalGeneratorInitializationContext context) // Register the attribute source context.RegisterPostInitializationOutput(ctx => { ctx.AddSource("newtypeAttribute.g.cs", SourceText.From(NewtypeAttributeSource.Source, Encoding.UTF8)); }); - // Find all type declarations with our attribute - var aliasDeclarations = context.SyntaxProvider - .CreateSyntaxProvider( - predicate: static (node, _) => IsCandidateType(node), - transform: static (ctx, _) => GetAliasInfo(ctx)) - .Where(static info => info is not null) - .Select(static (info, _) => info!.Value); - - // Combine with compilation - var compilationAndAliases = context.CompilationProvider.Combine(aliasDeclarations.Collect()); + // Pipeline for generic [newtype] attribute + var genericPipeline = context.SyntaxProvider + .ForAttributeWithMetadataName( + "newtype.newtypeAttribute`1", + predicate: static (node, _) => node is TypeDeclarationSyntax, + transform: static (ctx, _) => ExtractGenericModel(ctx)) + .Where(static model => model is not null) + .Select(static (model, _) => model!.Value); + + context.RegisterSourceOutput(genericPipeline, static (spc, model) => GenerateAliasCode(spc, model)); + + // Pipeline for non-generic [newtype(typeof(T))] attribute + var nonGenericPipeline = context.SyntaxProvider + .ForAttributeWithMetadataName( + "newtype.newtypeAttribute", + predicate: static (node, _) => node is TypeDeclarationSyntax, + transform: static (ctx, _) => ExtractNonGenericModel(ctx)) + .Where(static model => model is not null) + .Select(static (model, _) => model!.Value); + + context.RegisterSourceOutput(nonGenericPipeline, static (spc, model) => GenerateAliasCode(spc, model)); + } - // Generate the source - context.RegisterSourceOutput(compilationAndAliases, static (spc, source) => + private static AliasModel? ExtractGenericModel(GeneratorAttributeSyntaxContext context) + { + foreach (var attributeData in context.Attributes) { - var (compilation, aliases) = source; - foreach (var alias in aliases) + var attributeClass = attributeData.AttributeClass; + if (attributeClass is {IsGenericType: true} && + attributeClass.TypeArguments.Length == 1) { - GenerateAliasCode(spc, compilation, alias); + var aliasedType = attributeClass.TypeArguments[0]; + return AliasModelExtractor.Extract(context, aliasedType); } - }); - } - - private static bool IsCandidateType(SyntaxNode node) - { - if (node is StructDeclarationSyntax {AttributeLists.Count: > 0} structDecl) - return structDecl.Modifiers.Any(SyntaxKind.PartialKeyword); - - if (node is ClassDeclarationSyntax {AttributeLists.Count: > 0} classDecl) - return classDecl.Modifiers.Any(SyntaxKind.PartialKeyword); - - if (node is RecordDeclarationSyntax {AttributeLists.Count: > 0} recordDecl) - return recordDecl.Modifiers.Any(SyntaxKind.PartialKeyword); - - return false; + } + return null; } - private static AliasInfo? GetAliasInfo(GeneratorSyntaxContext context) + private static AliasModel? ExtractNonGenericModel(GeneratorAttributeSyntaxContext context) { - var typeDecl = (TypeDeclarationSyntax) context.Node; - var semanticModel = context.SemanticModel; - - foreach (var attributeList in typeDecl.AttributeLists) + foreach (var attributeData in context.Attributes) { - foreach (var attribute in attributeList.Attributes) + if (attributeData.ConstructorArguments.Length > 0 && + attributeData.ConstructorArguments[0].Value is ITypeSymbol aliasedType) { - var symbolInfo = semanticModel.GetSymbolInfo(attribute); - if (symbolInfo.Symbol is not IMethodSymbol attributeConstructor) - continue; - - var attributeType = attributeConstructor.ContainingType; - var fullName = attributeType.ToDisplayString(); - - // Check for generic Alias - if (attributeType.IsGenericType && - attributeType.OriginalDefinition.ToDisplayString() == "newtype.newtypeAttribute") - { - var aliasedType = attributeType.TypeArguments[0]; - var typeSymbol = semanticModel.GetDeclaredSymbol(typeDecl); - if (typeSymbol is null) continue; - - return new AliasInfo( - typeDecl, - typeSymbol, - aliasedType); - } - - // Check for non-generic Alias(typeof(T)) - if (fullName == "newtype.newtypeAttribute") - { - var attributeData = semanticModel.GetDeclaredSymbol(typeDecl)? - .GetAttributes() - .FirstOrDefault(ad => ad.AttributeClass?.ToDisplayString() == "newtype.newtypeAttribute"); - - if (attributeData?.ConstructorArguments.Length > 0 && - attributeData.ConstructorArguments[0].Value is ITypeSymbol aliasedType) - { - var typeSymbol = semanticModel.GetDeclaredSymbol(typeDecl); - if (typeSymbol is null) continue; - - return new AliasInfo( - typeDecl, - typeSymbol, - aliasedType); - } - } + return AliasModelExtractor.Extract(context, aliasedType); } } - return null; } private static void GenerateAliasCode( SourceProductionContext context, - Compilation compilation, - AliasInfo alias) + AliasModel model) { - var generator = new AliasCodeGenerator(compilation, alias); + var generator = new AliasCodeGenerator(model); var source = generator.Generate(); - var fileName = $"{alias.TypeSymbol.ToDisplayString().Replace(".", "_").Replace("<", "_").Replace(">", "_")}.g.cs"; + var fileName = $"{model.TypeDisplayString.Replace(".", "_").Replace("<", "_").Replace(">", "_")}.g.cs"; context.AddSource(fileName, SourceText.From(source, Encoding.UTF8)); } } - -internal readonly record struct AliasInfo( - TypeDeclarationSyntax TypeDeclaration, - INamedTypeSymbol TypeSymbol, - ITypeSymbol AliasedType); \ No newline at end of file diff --git a/NewType.Generator/AliasModel.cs b/NewType.Generator/AliasModel.cs new file mode 100644 index 0000000..3c63bcf --- /dev/null +++ b/NewType.Generator/AliasModel.cs @@ -0,0 +1,112 @@ +using System; +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; + +namespace newtype.generator; + +/// +/// Fully-extracted, equatable model representing a newtype alias. +/// Contains only strings, bools, plain enums, and EquatableArrays — no Roslyn symbols. +/// +internal readonly record struct AliasModel( + // Type being declared + string TypeName, + string Namespace, + Accessibility DeclaredAccessibility, + bool IsReadonly, + bool IsClass, + bool IsRecord, + bool IsRecordStruct, + + // Aliased type + string AliasedTypeFullName, + string AliasedTypeMinimalName, + SpecialType AliasedTypeSpecialType, + bool AliasedTypeIsValueType, + + // Interface flags + bool ImplementsIComparable, + bool ImplementsIFormattable, + bool HasNativeEqualityOperator, + + // Pre-computed file name + string TypeDisplayString, + + // Whether the aliased type has any public static non-operator members + // (used to emit the #region even when no property/readonly-field members survive filtering) + bool HasStaticMemberCandidates, + + // Members + EquatableArray BinaryOperators, + EquatableArray UnaryOperators, + EquatableArray StaticMembers, + EquatableArray InstanceFields, + EquatableArray InstanceProperties, + EquatableArray InstanceMethods, + EquatableArray ForwardedConstructors +); + +internal readonly record struct BinaryOperatorInfo( + string Name, + string LeftTypeFullName, + string RightTypeFullName, + string ReturnTypeFullName, + bool LeftIsAliasedType, + bool RightIsAliasedType, + bool ReturnIsAliasedType +) : IEquatable; + +internal readonly record struct UnaryOperatorInfo( + string Name, + string ReturnTypeFullName, + bool ReturnIsAliasedType +) : IEquatable; + +internal readonly record struct StaticMemberInfo( + string Name, + string TypeFullName, + bool TypeIsAliasedType, + bool IsProperty, + bool IsReadonlyField +) : IEquatable; + +internal readonly record struct InstanceFieldInfo( + string Name, + string TypeFullName, + bool TypeIsAliasedType +) : IEquatable; + +internal readonly record struct InstancePropertyInfo( + string Name, + string TypeFullName, + bool TypeIsAliasedType, + bool HasGetter +) : IEquatable; + +internal readonly record struct InstanceMethodInfo( + string Name, + string ReturnTypeFullName, + bool ReturnsVoid, + bool ReturnIsAliasedType, + bool SkipReturnWrapping, + EquatableArray Parameters +) : IEquatable; + +internal readonly record struct MethodParameterInfo( + string Name, + string TypeFullName, + RefKind RefKind, + bool IsAliasedType +) : IEquatable; + +internal readonly record struct ConstructorInfo( + EquatableArray Parameters +) : IEquatable; + +internal readonly record struct ConstructorParameterInfo( + string Name, + string TypeFullName, + RefKind RefKind, + bool IsParams, + string? DefaultValueLiteral +) : IEquatable; diff --git a/NewType.Generator/AliasModelExtractor.cs b/NewType.Generator/AliasModelExtractor.cs new file mode 100644 index 0000000..ee4e584 --- /dev/null +++ b/NewType.Generator/AliasModelExtractor.cs @@ -0,0 +1,406 @@ +using System.Collections.Immutable; +using System.Globalization; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace newtype.generator; + +/// +/// Extracts all symbol data into an during the transform step, +/// before the incremental pipeline caches the result. After extraction, no Roslyn symbols are retained. +/// +internal static class AliasModelExtractor +{ + public static AliasModel? Extract(GeneratorAttributeSyntaxContext context, ITypeSymbol aliasedType) + { + var typeDecl = (TypeDeclarationSyntax)context.TargetNode; + var typeSymbol = (INamedTypeSymbol)context.TargetSymbol; + + var typeName = typeSymbol.Name; + var ns = typeSymbol.ContainingNamespace; + var namespaceName = ns is {IsGlobalNamespace: false} ? ns.ToDisplayString() : ""; + + var isReadonly = typeDecl.Modifiers.Any(SyntaxKind.ReadOnlyKeyword); + var isClass = typeDecl is ClassDeclarationSyntax + || (typeDecl is RecordDeclarationSyntax rds + && !rds.ClassOrStructKeyword.IsKind(SyntaxKind.StructKeyword)); + var isRecord = typeDecl is RecordDeclarationSyntax; + var isRecordStruct = isRecord && !isClass; + + var aliasedTypeFullName = aliasedType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var aliasedTypeMinimalName = aliasedType.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat); + + var implementsIComparable = ImplementsInterface(aliasedType, "System.IComparable`1"); + var implementsIFormattable = ImplementsInterface(aliasedType, "System.IFormattable"); + + var binaryOperators = ExtractBinaryOperators(aliasedType); + var unaryOperators = ExtractUnaryOperators(aliasedType); + var hasNativeEquality = HasNativeEqualityOperator(aliasedType, binaryOperators); + + var staticMembers = ExtractStaticMembers(aliasedType); + var hasStaticMemberCandidates = HasStaticMemberCandidates(aliasedType); + var instanceFields = ExtractInstanceFields(aliasedType); + var instanceProperties = ExtractInstanceProperties(aliasedType); + var instanceMethods = ExtractInstanceMethods(aliasedType); + var constructors = ExtractForwardableConstructors(typeSymbol, aliasedType); + + var typeDisplayString = typeSymbol.ToDisplayString(); + + return new AliasModel( + TypeName: typeName, + Namespace: namespaceName, + DeclaredAccessibility: typeSymbol.DeclaredAccessibility, + IsReadonly: isReadonly, + IsClass: isClass, + IsRecord: isRecord, + IsRecordStruct: isRecordStruct, + AliasedTypeFullName: aliasedTypeFullName, + AliasedTypeMinimalName: aliasedTypeMinimalName, + AliasedTypeSpecialType: aliasedType.SpecialType, + AliasedTypeIsValueType: aliasedType.IsValueType, + ImplementsIComparable: implementsIComparable, + ImplementsIFormattable: implementsIFormattable, + HasNativeEqualityOperator: hasNativeEquality, + TypeDisplayString: typeDisplayString, + HasStaticMemberCandidates: hasStaticMemberCandidates, + BinaryOperators: binaryOperators, + UnaryOperators: unaryOperators, + StaticMembers: staticMembers, + InstanceFields: instanceFields, + InstanceProperties: instanceProperties, + InstanceMethods: instanceMethods, + ForwardedConstructors: constructors + ); + } + + private static EquatableArray ExtractBinaryOperators(ITypeSymbol aliasedType) + { + var builder = ImmutableArray.CreateBuilder(); + + foreach (var member in aliasedType.GetMembers()) + { + if (member is IMethodSymbol {MethodKind: MethodKind.UserDefinedOperator, Parameters.Length: 2} method) + { + builder.Add(new BinaryOperatorInfo( + Name: method.Name, + LeftTypeFullName: method.Parameters[0].Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), + RightTypeFullName: method.Parameters[1].Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), + ReturnTypeFullName: method.ReturnType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), + LeftIsAliasedType: SymbolEqualityComparer.Default.Equals(method.Parameters[0].Type, aliasedType), + RightIsAliasedType: SymbolEqualityComparer.Default.Equals(method.Parameters[1].Type, aliasedType), + ReturnIsAliasedType: SymbolEqualityComparer.Default.Equals(method.ReturnType, aliasedType) + )); + } + } + + return new EquatableArray(builder.ToImmutable()); + } + + private static EquatableArray ExtractUnaryOperators(ITypeSymbol aliasedType) + { + var builder = ImmutableArray.CreateBuilder(); + + foreach (var member in aliasedType.GetMembers()) + { + if (member is IMethodSymbol {MethodKind: MethodKind.UserDefinedOperator, Parameters.Length: 1} method) + { + builder.Add(new UnaryOperatorInfo( + Name: method.Name, + ReturnTypeFullName: method.ReturnType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), + ReturnIsAliasedType: SymbolEqualityComparer.Default.Equals(method.ReturnType, aliasedType) + )); + } + } + + return new EquatableArray(builder.ToImmutable()); + } + + /// + /// Matches the old code's filter: any public static non-operator member that is IPropertySymbol or IFieldSymbol. + /// This includes const fields which don't survive the readonly filter but cause the empty region to be emitted. + /// + private static bool HasStaticMemberCandidates(ITypeSymbol aliasedType) + { + foreach (var member in aliasedType.GetMembers()) + { + if (!member.IsStatic || member.DeclaredAccessibility != Accessibility.Public) + continue; + if (member.Name.StartsWith("op_")) + continue; + if (member is IPropertySymbol or IFieldSymbol) + return true; + } + return false; + } + + private static EquatableArray ExtractStaticMembers(ITypeSymbol aliasedType) + { + var builder = ImmutableArray.CreateBuilder(); + + foreach (var member in aliasedType.GetMembers()) + { + if (!member.IsStatic || member.DeclaredAccessibility != Accessibility.Public) + continue; + if (member.Name.StartsWith("op_")) + continue; + + if (member is IPropertySymbol {GetMethod: not null} prop) + { + builder.Add(new StaticMemberInfo( + Name: prop.Name, + TypeFullName: prop.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), + TypeIsAliasedType: SymbolEqualityComparer.Default.Equals(prop.Type, aliasedType), + IsProperty: true, + IsReadonlyField: false + )); + } + else if (member is IFieldSymbol {IsReadOnly: true} field) + { + builder.Add(new StaticMemberInfo( + Name: field.Name, + TypeFullName: field.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), + TypeIsAliasedType: SymbolEqualityComparer.Default.Equals(field.Type, aliasedType), + IsProperty: false, + IsReadonlyField: true + )); + } + } + + return new EquatableArray(builder.ToImmutable()); + } + + private static EquatableArray ExtractInstanceFields(ITypeSymbol aliasedType) + { + var builder = ImmutableArray.CreateBuilder(); + + foreach (var member in aliasedType.GetMembers()) + { + if (member is IFieldSymbol {IsStatic: false, IsImplicitlyDeclared: false, DeclaredAccessibility: Accessibility.Public} field) + { + builder.Add(new InstanceFieldInfo( + Name: field.Name, + TypeFullName: field.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), + TypeIsAliasedType: SymbolEqualityComparer.Default.Equals(field.Type, aliasedType) + )); + } + } + + return new EquatableArray(builder.ToImmutable()); + } + + private static EquatableArray ExtractInstanceProperties(ITypeSymbol aliasedType) + { + var builder = ImmutableArray.CreateBuilder(); + + foreach (var member in aliasedType.GetMembers()) + { + if (member is IPropertySymbol {IsStatic: false, DeclaredAccessibility: Accessibility.Public, IsIndexer: false, GetMethod: not null} prop) + { + builder.Add(new InstancePropertyInfo( + Name: prop.Name, + TypeFullName: prop.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), + TypeIsAliasedType: SymbolEqualityComparer.Default.Equals(prop.Type, aliasedType), + HasGetter: prop.GetMethod != null + )); + } + } + + return new EquatableArray(builder.ToImmutable()); + } + + private static EquatableArray ExtractInstanceMethods(ITypeSymbol aliasedType) + { + var builder = ImmutableArray.CreateBuilder(); + + foreach (var member in aliasedType.GetMembers()) + { + if (member is not IMethodSymbol method) + continue; + if (method.IsStatic || method.DeclaredAccessibility != Accessibility.Public) + continue; + if (method.MethodKind != MethodKind.Ordinary) + continue; + if (!method.CanBeReferencedByName) + continue; + if (method.Name.StartsWith("get_") || method.Name.StartsWith("set_")) + continue; + if (method.Name is "GetHashCode" or "Equals" or "ToString" or "CompareTo") + continue; + + var skipReturnWrapping = aliasedType.IsValueType && + aliasedType.SpecialType != SpecialType.None; + + var paramBuilder = ImmutableArray.CreateBuilder(); + foreach (var p in method.Parameters) + { + paramBuilder.Add(new MethodParameterInfo( + Name: p.Name, + TypeFullName: p.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), + RefKind: p.RefKind, + IsAliasedType: SymbolEqualityComparer.Default.Equals(p.Type, aliasedType) + )); + } + + builder.Add(new InstanceMethodInfo( + Name: method.Name, + ReturnTypeFullName: method.ReturnType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), + ReturnsVoid: method.ReturnsVoid, + ReturnIsAliasedType: SymbolEqualityComparer.Default.Equals(method.ReturnType, aliasedType), + SkipReturnWrapping: skipReturnWrapping, + Parameters: new EquatableArray(paramBuilder.ToImmutable()) + )); + } + + return new EquatableArray(builder.ToImmutable()); + } + + private static EquatableArray ExtractForwardableConstructors( + INamedTypeSymbol typeSymbol, ITypeSymbol aliasedType) + { + var userSignatures = GetUserDefinedConstructorSignatures(typeSymbol); + var builder = ImmutableArray.CreateBuilder(); + + foreach (var member in aliasedType.GetMembers()) + { + if (member is not IMethodSymbol m) + continue; + if (m.MethodKind != MethodKind.Constructor || m.IsStatic) + continue; + if (m.DeclaredAccessibility != Accessibility.Public || m.IsImplicitlyDeclared) + continue; + if (m.Parameters.Length == 0) + continue; + + // Skip copy constructor (single param of the aliased type itself) + if (m.Parameters.Length == 1 && + SymbolEqualityComparer.Default.Equals(m.Parameters[0].Type, aliasedType)) + continue; + + // Skip constructors with pointer parameters + var hasPointer = false; + foreach (var p in m.Parameters) + { + if (p.Type.TypeKind == TypeKind.Pointer) + { + hasPointer = true; + break; + } + } + if (hasPointer) continue; + + // Skip if user already defined a constructor with the same signature + var sig = GetConstructorSignature(m); + if (userSignatures.Contains(sig)) + continue; + + var paramBuilder = ImmutableArray.CreateBuilder(); + foreach (var p in m.Parameters) + { + paramBuilder.Add(new ConstructorParameterInfo( + Name: p.Name, + TypeFullName: p.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), + RefKind: p.RefKind, + IsParams: p.IsParams, + DefaultValueLiteral: p.HasExplicitDefaultValue ? FormatDefaultValue(p) : null + )); + } + + builder.Add(new ConstructorInfo( + Parameters: new EquatableArray(paramBuilder.ToImmutable()) + )); + } + + return new EquatableArray(builder.ToImmutable()); + } + + private static HashSet GetUserDefinedConstructorSignatures(INamedTypeSymbol typeSymbol) + { + var signatures = new HashSet(); + foreach (var member in typeSymbol.GetMembers()) + { + if (member is IMethodSymbol {MethodKind: MethodKind.Constructor, IsImplicitlyDeclared: false} ctor) + { + signatures.Add(GetConstructorSignature(ctor)); + } + } + return signatures; + } + + private static string GetConstructorSignature(IMethodSymbol ctor) + { + return string.Join(",", ctor.Parameters.Select(p => + { + var refModifier = p.RefKind switch + { + RefKind.Ref => "ref ", + RefKind.Out => "out ", + RefKind.In => "in ", + _ => "" + }; + return refModifier + p.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + })); + } + + private static string FormatDefaultValue(IParameterSymbol param) + { + var value = param.ExplicitDefaultValue; + + if (value is null) + return "default"; + + if (value is string s) + return SymbolDisplay.FormatLiteral(s, true); + + if (value is char c) + return SymbolDisplay.FormatLiteral(c, true); + + if (value is bool b) + return b ? "true" : "false"; + + if (value is float f) + return f.ToString("R", CultureInfo.InvariantCulture) + "f"; + + if (value is double d) + return d.ToString("R", CultureInfo.InvariantCulture) + "d"; + + if (value is decimal m) + return m.ToString(CultureInfo.InvariantCulture) + "m"; + + return string.Format(CultureInfo.InvariantCulture, "{0}", value); + } + + private static bool HasNativeEqualityOperator(ITypeSymbol type, + EquatableArray binaryOperators) + { + // User-defined operator == + foreach (var op in binaryOperators) + { + if (op.Name == "op_Equality") + return true; + } + + // Built-in == for primitives + return type.SpecialType is + SpecialType.System_Boolean or + SpecialType.System_Byte or SpecialType.System_SByte or + SpecialType.System_Int16 or SpecialType.System_UInt16 or + SpecialType.System_Int32 or SpecialType.System_UInt32 or + SpecialType.System_Int64 or SpecialType.System_UInt64 or + SpecialType.System_Single or SpecialType.System_Double or + SpecialType.System_Decimal or SpecialType.System_Char or + SpecialType.System_String; + } + + private static bool ImplementsInterface(ITypeSymbol type, string interfaceFullName) + { + if (interfaceFullName.Contains("`")) + { + var baseName = interfaceFullName.Split('`')[0]; + return type.AllInterfaces.Any(i => + i.OriginalDefinition.ToDisplayString().StartsWith(baseName)); + } + + return type.AllInterfaces.Any(i => i.ToDisplayString() == interfaceFullName); + } +} diff --git a/NewType.Generator/EquatableArray.cs b/NewType.Generator/EquatableArray.cs new file mode 100644 index 0000000..c419ca7 --- /dev/null +++ b/NewType.Generator/EquatableArray.cs @@ -0,0 +1,67 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Collections.Immutable; + +namespace newtype.generator; + +/// +/// An immutable array wrapper with value-based equality (SequenceEqual). +/// Required because uses reference equality. +/// +internal readonly struct EquatableArray : IEquatable>, IEnumerable + where T : IEquatable +{ + private readonly ImmutableArray _array; + + public EquatableArray(ImmutableArray array) + { + _array = array; + } + + public ImmutableArray Array => _array.IsDefault ? ImmutableArray.Empty : _array; + + public int Length => Array.Length; + + public T this[int index] => Array[index]; + + public bool Equals(EquatableArray other) + { + var a = Array; + var b = other.Array; + + if (a.Length != b.Length) + return false; + + for (var i = 0; i < a.Length; i++) + { + if (!a[i].Equals(b[i])) + return false; + } + + return true; + } + + public override bool Equals(object? obj) => obj is EquatableArray other && Equals(other); + + public override int GetHashCode() + { + var a = Array; + var hash = 17; + for (var i = 0; i < a.Length; i++) + { + unchecked + { + hash = hash * 31 + a[i].GetHashCode(); + } + } + return hash; + } + + public static bool operator ==(EquatableArray left, EquatableArray right) => left.Equals(right); + public static bool operator !=(EquatableArray left, EquatableArray right) => !left.Equals(right); + + public ImmutableArray.Enumerator GetEnumerator() => Array.GetEnumerator(); + IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)Array).GetEnumerator(); + IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)Array).GetEnumerator(); +} diff --git a/NewType.Generator/NUGET.md b/NewType.Generator/NUGET.md index e3482c8..0840dfb 100644 --- a/NewType.Generator/NUGET.md +++ b/NewType.Generator/NUGET.md @@ -1,9 +1,8 @@ -# `newtype` (Distinct Type Aliases for C#) +# `newtype` *(Distinct Type Aliases for C#)* -![logo, a stylized N with a red and Blue half](https://raw.githubusercontent.com/outfox/newtype/main/logo.svg) +![logo, a stylized N with a red and blue half](https://raw.githubusercontent.com/outfox/newtype/main/logo.svg) -A source generator that creates distinct type aliases with full operator forwarding. Inspired by Haskell's `newtype` and -F#'s type abbreviations. `newtype` works for a healthy number of types - many primitives, structs, records, classes work out of the box. +This package is a source generator that creates distinct type aliases with full operator and constructor forwarding. Inspired by Haskell's `newtype` and F#'s type abbreviations. `newtype` works for a healthy number of types - many primitives, structs, records, classes work out of the box. ## Installation @@ -13,18 +12,37 @@ dotnet add package newtype ## Usage -#### Basic: strongly typed `string`-names, `int`-IDs, and `int`-counts +### Basic: typed IDs and counts ```csharp using newtype; +[newtype] +public readonly partial struct TableId; + [newtype] -public readonly partial struct Pizzas; +public readonly partial struct PizzasEaten; [newtype] public readonly partial struct Fullness; + +class Guest +{ + TableId table = "Table 1"; + PizzasEaten pizzasEaten; + Fullness fullness; + + public void fillEmUp(Fullness threshold) + { + while (fullness < threshold) + { + pizzasEaten++; + fullness += 0.1; + } + } +} ``` -#### Typical: quantities backed by the same data type. +### Typical: quantities backed by the same data type but distinct domain semantics *For example, forces, velocities, positions, etc. all lose their semantics when expressed as `Vector3`* ```csharp using System.Numerics; @@ -51,7 +69,7 @@ Console.WriteLine(p.Length()); // 3.74... Position updated = p + v * deltaTime; // Implicit conversion both ways -Vector3 vec = p; // Position → Vector3 +Vector3 vec = p; // Position → Vector3 Position pos = new Vector3(); // Vector3 → Position ``` @@ -143,7 +161,7 @@ wrapper entirely in release builds. The generated code has the same performance dotnet add package newtype ``` -## Viewing Generated Code +## Viewing Generated Code< If you only knew now what you didn't know then. > Enable generated file output in your project: diff --git a/NewType.Tests/GeneratorTests/ConstructorSignatureTests.cs b/NewType.Tests/GeneratorTests/ConstructorSignatureTests.cs new file mode 100644 index 0000000..4e453be --- /dev/null +++ b/NewType.Tests/GeneratorTests/ConstructorSignatureTests.cs @@ -0,0 +1,33 @@ +using Xunit; + +namespace newtype.tests; + +public class ConstructorSignatureTests +{ + [Fact] + public void Ref_And_Value_Constructors_Are_Not_Treated_As_Duplicates() + { + const string source = """ + using newtype; + + public class Parser + { + public string Result { get; } + public Parser(string input) => Result = input; + public Parser(ref string input) { Result = input; input = ""; } + } + + [newtype] + public readonly partial struct ParserAlias; + """; + + // Both constructors should be forwarded + var result = GeneratorTestHelper.RunGenerator(source); + var text = result.Results[0].GeneratedSources + .Single(s => s.HintName.EndsWith("ParserAlias.g.cs")) + .SourceText.ToString(); + + Assert.Contains("public ParserAlias(string input)", text); + Assert.Contains("public ParserAlias(ref string input)", text); + } +} diff --git a/NewType.Tests/GeneratorTests/DefaultValueFormattingTests.cs b/NewType.Tests/GeneratorTests/DefaultValueFormattingTests.cs new file mode 100644 index 0000000..256c17c --- /dev/null +++ b/NewType.Tests/GeneratorTests/DefaultValueFormattingTests.cs @@ -0,0 +1,92 @@ +using System.Globalization; +using Xunit; + +namespace newtype.tests; + +public class DefaultValueFormattingTests +{ + [Fact] + public void Decimal_Default_Value_Uses_Invariant_Culture() + { + const string source = """ + using newtype; + + public class Currency + { + public decimal Amount { get; } + public Currency(decimal amount = 1.5m) => Amount = amount; + } + + [newtype] + public readonly partial struct Price; + """; + + // Run under de-DE where the decimal separator is a comma + var previous = CultureInfo.CurrentCulture; + try + { + CultureInfo.CurrentCulture = CultureInfo.GetCultureInfo("de-DE"); + + var result = GeneratorTestHelper.RunGenerator(source); + var text = result.Results[0].GeneratedSources + .Single(s => s.HintName.EndsWith("Price.g.cs")) + .SourceText.ToString(); + + // Must use period even under de-DE culture + Assert.Contains("1.5m", text); + Assert.DoesNotContain("1,5m", text); + } + finally + { + CultureInfo.CurrentCulture = previous; + } + } + + [Fact] + public void Float_Default_Value_Uses_Invariant_Culture() + { + const string source = """ + using newtype; + + public class Sensor + { + public float Reading { get; } + public Sensor(float reading = 2.5f) => Reading = reading; + } + + [newtype] + public readonly partial struct SensorAlias; + """; + + var result = GeneratorTestHelper.RunGenerator(source); + var text = result.Results[0].GeneratedSources + .Single(s => s.HintName.EndsWith("SensorAlias.g.cs")) + .SourceText.ToString(); + + Assert.Contains("2.5f", text); + } + + [Fact] + public void Double_Default_Value_Uses_Invariant_Culture() + { + const string source = """ + using newtype; + + public class Measurement + { + public double Result { get; } + public Measurement(double result = 3.14d) => Result = result; + } + + [newtype] + public readonly partial struct MeasurementAlias; + """; + + var result = GeneratorTestHelper.RunGenerator(source); + var text = result.Results[0].GeneratedSources + .Single(s => s.HintName.EndsWith("MeasurementAlias.g.cs")) + .SourceText.ToString(); + + Assert.Contains("3.14d", text); + } +} diff --git a/NewType.Tests/GeneratorTests/GeneratorOutputTests.cs b/NewType.Tests/GeneratorTests/GeneratorOutputTests.cs new file mode 100644 index 0000000..3fc8b14 --- /dev/null +++ b/NewType.Tests/GeneratorTests/GeneratorOutputTests.cs @@ -0,0 +1,185 @@ +using Basic.Reference.Assemblies; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using newtype.generator; +using Xunit; + +namespace newtype.tests; + +public class GeneratorOutputTests +{ + [Fact] + public void Generates_Output_For_Int_Alias() + { + const string source = """ + using newtype; + + [newtype] + public readonly partial struct TestId; + """; + + var result = GeneratorTestHelper.RunGenerator(source); + + // Attribute source + alias source + Assert.Equal(2, result.GeneratedTrees.Length); + + var generatedSources = result.Results[0].GeneratedSources; + + var aliasSource = generatedSources.Single(s => s.HintName.EndsWith("TestId.g.cs")); + var text = aliasSource.SourceText.ToString(); + + Assert.Contains("private readonly int _value;", text); + Assert.Contains("public TestId(int value)", text); + Assert.Contains("operator +", text); + Assert.Contains("operator ==", text); + Assert.Contains("IComparable", text); + Assert.Contains("IEquatable", text); + } + + [Fact] + public void Generates_Output_For_String_Alias() + { + const string source = """ + using newtype; + + [newtype] + public readonly partial struct Label; + """; + + var result = GeneratorTestHelper.RunGenerator(source); + + Assert.Equal(2, result.GeneratedTrees.Length); + + var generatedSources = result.Results[0].GeneratedSources; + + var aliasSource = generatedSources.Single(s => s.HintName.EndsWith("Label.g.cs")); + var text = aliasSource.SourceText.ToString(); + + Assert.Contains("private readonly string _value;", text); + Assert.Contains(@"_value?.ToString() ?? """"", text); + Assert.Contains("IComparable