44using System . IO ;
55using System . Linq ;
66using Microsoft . CodeAnalysis ;
7+ using Semmle . Util ;
78using Semmle . Extraction . CSharp . Entities ;
89
910namespace Semmle . Extraction . CSharp
@@ -164,6 +165,7 @@ public static void BuildTypeId(this ITypeSymbol type, Context cx, EscapingTextWr
164165 case TypeKind . Enum :
165166 case TypeKind . Delegate :
166167 case TypeKind . Error :
168+ case TypeKind . Extension :
167169 var named = ( INamedTypeSymbol ) type ;
168170 named . BuildNamedTypeId ( cx , trapFile , symbolBeingDefined , constructUnderlyingTupleType ) ;
169171 return ;
@@ -275,6 +277,20 @@ private static void BuildFunctionPointerTypeId(this IFunctionPointerTypeSymbol f
275277 public static IEnumerable < IFieldSymbol ? > GetTupleElementsMaybeNull ( this INamedTypeSymbol type ) =>
276278 type . TupleElements ;
277279
280+ private static void BuildExtensionTypeId ( this INamedTypeSymbol named , Context cx , EscapingTextWriter trapFile )
281+ {
282+ trapFile . Write ( "extension(" ) ;
283+ if ( named . ExtensionMarkerName is not null )
284+ {
285+ trapFile . Write ( named . ExtensionMarkerName ) ;
286+ }
287+ else
288+ {
289+ trapFile . Write ( "unknown" ) ;
290+ }
291+ trapFile . Write ( ")" ) ;
292+ }
293+
278294 private static void BuildQualifierAndName ( INamedTypeSymbol named , Context cx , EscapingTextWriter trapFile , ISymbol symbolBeingDefined )
279295 {
280296 if ( named . ContainingType is not null )
@@ -289,8 +305,18 @@ private static void BuildQualifierAndName(INamedTypeSymbol named, Context cx, Es
289305 named . ContainingNamespace . BuildNamespace ( cx , trapFile ) ;
290306 }
291307
292- var name = named . IsFileLocal ? named . MetadataName : named . Name ;
293- trapFile . Write ( name ) ;
308+ if ( named . IsFileLocal )
309+ {
310+ trapFile . Write ( named . MetadataName ) ;
311+ }
312+ else if ( named . IsExtension )
313+ {
314+ named . BuildExtensionTypeId ( cx , trapFile ) ;
315+ }
316+ else
317+ {
318+ trapFile . Write ( named . Name ) ;
319+ }
294320 }
295321
296322 private static void BuildTupleId ( INamedTypeSymbol named , Context cx , EscapingTextWriter trapFile , ISymbol symbolBeingDefined )
@@ -391,6 +417,7 @@ public static void BuildDisplayName(this ITypeSymbol type, Context cx, TextWrite
391417 case TypeKind . Enum :
392418 case TypeKind . Delegate :
393419 case TypeKind . Error :
420+ case TypeKind . Extension :
394421 var named = ( INamedTypeSymbol ) type ;
395422 named . BuildNamedTypeDisplayName ( cx , trapFile , constructUnderlyingTupleType ) ;
396423 return ;
@@ -465,6 +492,20 @@ public static void BuildFunctionPointerSignature(IFunctionPointerTypeSymbol funp
465492 private static void BuildFunctionPointerTypeDisplayName ( this IFunctionPointerTypeSymbol funptr , Context cx , TextWriter trapFile ) =>
466493 BuildFunctionPointerSignature ( funptr , trapFile , s => s . BuildDisplayName ( cx , trapFile ) ) ;
467494
495+ private static void BuildExtensionTypeDisplayName ( this INamedTypeSymbol named , Context cx , TextWriter trapFile )
496+ {
497+ trapFile . Write ( "extension(" ) ;
498+ if ( named . ExtensionParameter ? . Type is ITypeSymbol type )
499+ {
500+ type . BuildDisplayName ( cx , trapFile ) ;
501+ }
502+ else
503+ {
504+ trapFile . Write ( "unknown" ) ;
505+ }
506+ trapFile . Write ( ")" ) ;
507+ }
508+
468509 private static void BuildNamedTypeDisplayName ( this INamedTypeSymbol namedType , Context cx , TextWriter trapFile , bool constructUnderlyingTupleType )
469510 {
470511 if ( ! constructUnderlyingTupleType && namedType . IsTupleType )
@@ -484,6 +525,12 @@ private static void BuildNamedTypeDisplayName(this INamedTypeSymbol namedType, C
484525 return ;
485526 }
486527
528+ if ( namedType . IsExtension )
529+ {
530+ namedType . BuildExtensionTypeDisplayName ( cx , trapFile ) ;
531+ return ;
532+ }
533+
487534 if ( namedType . IsAnonymousType )
488535 {
489536 namedType . BuildAnonymousName ( cx , trapFile ) ;
@@ -596,6 +643,84 @@ public static bool IsSourceDeclaration(this IParameterSymbol parameter)
596643 return true ;
597644 }
598645
646+ /// <summary>
647+ /// Return true if this method is a compiler-generated extension method.
648+ /// </summary>
649+ public static bool IsCompilerGeneratedExtensionMethod ( this IMethodSymbol method ) =>
650+ method . TryGetExtensionMethod ( ) is not null ;
651+
652+ /// <summary>
653+ /// Returns the extension method corresponding to this compiler-generated extension method, if it exists.
654+ /// </summary>
655+ public static IMethodSymbol ? TryGetExtensionMethod ( this IMethodSymbol method )
656+ {
657+ if ( method . IsImplicitlyDeclared && method . ContainingSymbol is INamedTypeSymbol containingType )
658+ {
659+ // Extension types are declared within the same type as the generated
660+ // extension method implementation.
661+ var extensions = containingType . GetMembers ( )
662+ . OfType < INamedTypeSymbol > ( )
663+ . Where ( t => t . IsExtension ) ;
664+ // Find the (possibly unbound) original extension method that maps to this implementation (if any).
665+ var unboundDeclaration = extensions . SelectMany ( e => e . GetMembers ( ) )
666+ . OfType < IMethodSymbol > ( )
667+ . FirstOrDefault ( m => SymbolEqualityComparer . Default . Equals ( m . AssociatedExtensionImplementation , method . ConstructedFrom ) ) ;
668+
669+ var isFullyConstructed = method . IsBoundGenericMethod ( ) ;
670+ if ( isFullyConstructed && unboundDeclaration ? . ContainingType is INamedTypeSymbol extensionType )
671+ {
672+ try
673+ {
674+ // Use the type arguments from the constructed extension method to construct the extension type.
675+ var arguments = method . TypeArguments . ToArray ( ) ;
676+ var ( extensionTypeArguments , extensionMethodArguments ) = arguments . SplitAt ( extensionType . TypeParameters . Length ) ;
677+
678+ // Construct the extension type.
679+ var boundExtensionType = extensionType . IsUnboundGenericType ( )
680+ ? extensionType . Construct ( extensionTypeArguments . ToArray ( ) )
681+ : extensionType ;
682+
683+ // Find the extension method declaration within the constructed extension type.
684+ var extensionDeclaration = boundExtensionType . GetMembers ( )
685+ . OfType < IMethodSymbol > ( )
686+ . First ( c => SymbolEqualityComparer . Default . Equals ( c . OriginalDefinition , unboundDeclaration ) ) ;
687+
688+ // If the extension declaration is unbound apply the remaning type arguments and construct it.
689+ return extensionDeclaration . IsUnboundGenericMethod ( )
690+ ? extensionDeclaration . Construct ( extensionMethodArguments . ToArray ( ) )
691+ : extensionDeclaration ;
692+ }
693+ catch
694+ {
695+ // If anything goes wrong, fall back to the unbound declaration.
696+ return unboundDeclaration ;
697+ }
698+ }
699+ else
700+ {
701+ return unboundDeclaration ;
702+ }
703+ }
704+ return null ;
705+ }
706+
707+ /// <summary>
708+ /// Returns true if this method is an unbound generic method.
709+ /// </summary>
710+ public static bool IsUnboundGenericMethod ( this IMethodSymbol method ) =>
711+ method . IsGenericMethod && SymbolEqualityComparer . Default . Equals ( method . ConstructedFrom , method ) ;
712+
713+ /// <summary>
714+ /// Returns true if this method is a bound generic method.
715+ /// </summary>
716+ public static bool IsBoundGenericMethod ( this IMethodSymbol method ) => method . IsGenericMethod && ! method . IsUnboundGenericMethod ( ) ;
717+
718+ /// <summary>
719+ /// Returns true if this type is an unbound generic type.
720+ /// </summary>
721+ public static bool IsUnboundGenericType ( this INamedTypeSymbol type ) =>
722+ type . IsGenericType && SymbolEqualityComparer . Default . Equals ( type . ConstructedFrom , type ) ;
723+
599724 /// <summary>
600725 /// Gets the base type of `symbol`. Unlike `symbol.BaseType`, this excludes effective base
601726 /// types of type parameters as well as `object` base types.
@@ -692,5 +817,35 @@ public static bool ShouldExtractSymbol(this ISymbol symbol)
692817 /// </summary>
693818 public static IEnumerable < T > ExtractionCandidates < T > ( this IEnumerable < T > symbols ) where T : ISymbol =>
694819 symbols . Where ( symbol => symbol . ShouldExtractSymbol ( ) ) ;
820+
821+ /// <summary>
822+ /// Returns the parameter kind for this parameter symbol, e.g. `ref`, `out`, `params`, etc.
823+ /// </summary>
824+ public static Parameter . Kind GetParameterKind ( this IParameterSymbol parameter )
825+ {
826+ switch ( parameter . RefKind )
827+ {
828+ case RefKind . Out :
829+ return Parameter . Kind . Out ;
830+ case RefKind . Ref :
831+ return Parameter . Kind . Ref ;
832+ case RefKind . In :
833+ return Parameter . Kind . In ;
834+ case RefKind . RefReadOnlyParameter :
835+ return Parameter . Kind . RefReadOnly ;
836+ default :
837+ if ( parameter . IsParams )
838+ return Parameter . Kind . Params ;
839+
840+ if ( parameter . Ordinal == 0 )
841+ {
842+ if ( parameter . ContainingSymbol is IMethodSymbol method && method . IsExtensionMethod )
843+ {
844+ return Parameter . Kind . This ;
845+ }
846+ }
847+ return Parameter . Kind . None ;
848+ }
849+ }
695850 }
696851}
0 commit comments