summaryrefslogtreecommitdiff
path: root/ICSharpCode.Decompiler/Ast/Transforms/DelegateConstruction.cs
blob: 04b2293dde876d7c78811023c1a401a12c7b881c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
// Copyright (c) 2011 AlphaSierraPapa for the SharpDevelop Team
// 
// Permission is hereby granted, free of charge, to any person obtaining a copy of this
// software and associated documentation files (the "Software"), to deal in the Software
// without restriction, including without limitation the rights to use, copy, modify, merge,
// publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons
// to whom the Software is furnished to do so, subject to the following conditions:
// 
// The above copyright notice and this permission notice shall be included in all copies or
// substantial portions of the Software.
// 
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
// PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
// FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
// OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using ICSharpCode.Decompiler;
using ICSharpCode.Decompiler.ILAst;
using ICSharpCode.NRefactory.CSharp;
using ICSharpCode.NRefactory.PatternMatching;
using Mono.Cecil;

namespace ICSharpCode.Decompiler.Ast.Transforms
{
	/// <summary>
	/// Converts "new Action(obj, ldftn(func))" into "new Action(obj.func)".
	/// For anonymous methods, creates an AnonymousMethodExpression.
	/// Also gets rid of any "Display Classes" left over after inlining an anonymous method.
	/// </summary>
	public class DelegateConstruction : ContextTrackingVisitor<object>
	{
		/// <summary>
		/// Annotation found on the function-pointer argument of a delegate construction
		/// (see VisitObjectCreateExpression); it remembers which IL instruction loaded the pointer.
		/// </summary>
		internal sealed class Annotation
		{
			/// <summary>
			/// <c>true</c> when the pointer came from ldvirtftn, <c>false</c> when it came from ldftn.
			/// </summary>
			public readonly bool IsVirtual;
			
			/// <summary>Creates the annotation.</summary>
			/// <param name="isVirtual">Whether the pointer was loaded with ldvirtftn.</param>
			public Annotation(bool isVirtual)
			{
				this.IsVirtual = isVirtual;
			}
		}
		
		/// <summary>
		/// Marker annotation attached to the variable declarations that this transform creates
		/// for the fields of an inlined display class. It carries no data.
		/// </summary>
		internal sealed class CapturedVariableAnnotation { }
		
		/// <summary>
		/// Names of the parameters and local variables in scope at the node currently being visited.
		/// Consulted when anonymous-method parameters or captured display class fields are turned into
		/// variables, so that conflicting existing variables can be renamed. Only the contents of the
		/// list change; the list instance itself is never replaced.
		/// </summary>
		readonly List<string> currentlyUsedVariableNames = new List<string>();
		
		/// <summary>Creates the transform for the given decompiler context.</summary>
		public DelegateConstruction(DecompilerContext context) : base(context)
		{
		}
		
		/// <summary>
		/// Rewrites delegate constructions of the form <c>new D(target, ldftn(method))</c>.
		/// Anonymous methods are inlined by HandleAnonymousMethod. Otherwise the two arguments
		/// are collapsed into a single method group <c>target.method</c>. For a non-virtual pointer
		/// loaded from <c>this</c> whose method is declared in another type, the target becomes
		/// <c>base</c>. For a static method with a null target, the target becomes the declaring
		/// type, unless the delegate signature suggests an extension method.
		/// </summary>
		public override object VisitObjectCreateExpression(ObjectCreateExpression objectCreateExpression, object data)
		{
			// Only "new D(target, ldftn(method))" is of interest; anything else is visited normally.
			if (objectCreateExpression.Arguments.Count != 2)
				return base.VisitObjectCreateExpression(objectCreateExpression, data);
			
			Expression target = objectCreateExpression.Arguments.First();
			Expression functionPointer = objectCreateExpression.Arguments.Last();
			Annotation pointerAnnotation = functionPointer.Annotation<Annotation>();
			if (pointerAnnotation == null)
				return base.VisitObjectCreateExpression(objectCreateExpression, data);
			
			IdentifierExpression methodIdent = (IdentifierExpression)((InvocationExpression)functionPointer).Arguments.Single();
			MethodReference method = methodIdent.Annotation<MethodReference>();
			if (method == null)
				return base.VisitObjectCreateExpression(objectCreateExpression, data);
			
			if (HandleAnonymousMethod(objectCreateExpression, target, method))
				return null;
			
			// Perform the transformation to "new D(target.method)".
			target.Remove();
			methodIdent.Remove();
			if (!pointerAnnotation.IsVirtual) {
				if (target is ThisReferenceExpression) {
					// A non-virtual pointer loaded from 'this' to a method of another type refers to a base method.
					if (method.DeclaringType.GetElementType() != context.CurrentType)
						target = new BaseReferenceExpression();
				} else if (target is NullReferenceExpression && !method.HasThis && !IsExtensionMethodDelegate(objectCreateExpression, method)) {
					// Static method: qualify the method group with its declaring type.
					target = new TypeReferenceExpression { Type = AstBuilder.ConvertType(method.DeclaringType) };
				}
			}
			
			// Replace both constructor arguments by a single method group "target.method".
			MemberReferenceExpression methodGroup = new MemberReferenceExpression {
				Target = target,
				MemberName = methodIdent.Identifier
			};
			methodIdent.TypeArguments.MoveTo(methodGroup.TypeArguments);
			methodGroup.AddAnnotation(method);
			objectCreateExpression.Arguments.Clear();
			objectCreateExpression.Arguments.Add(methodGroup);
			return null;
		}
		
		/// <summary>
		/// Heuristic for static methods loaded with a null target. Returns true when the method takes
		/// exactly one parameter more than the delegate's Invoke method, i.e. it may be an extension
		/// method bound to an instance. Returns false when the delegate type or its Invoke method
		/// cannot be resolved.
		/// </summary>
		static bool IsExtensionMethodDelegate(ObjectCreateExpression objectCreateExpression, MethodReference method)
		{
			TypeReference delegateType = objectCreateExpression.Type.Annotation<TypeReference>();
			TypeDefinition delegateTypeDef = delegateType != null ? delegateType.Resolve() : null;
			if (delegateTypeDef == null)
				return false;
			MethodDefinition invokeMethod = delegateTypeDef.Methods.FirstOrDefault(m => m.Name == "Invoke");
			return invokeMethod != null && invokeMethod.Parameters.Count + 1 == method.Parameters.Count;
		}
		
		/// <summary>
		/// Decides whether <paramref name="method"/> should be decompiled inline as an anonymous method.
		/// Two conditions must hold. First, the method has a compiler-generated name or a name containing
		/// '$'. Second, it is marked compiler-generated or is declared in a potential closure of the
		/// current type. A null method is never anonymous.
		/// </summary>
		internal static bool IsAnonymousMethod(DecompilerContext context, MethodDefinition method)
		{
			if (method == null)
				return false;
			bool looksGenerated = method.HasGeneratedName() || method.Name.Contains("$");
			return looksGenerated && (method.IsCompilerGenerated() || IsPotentialClosure(context, method.DeclaringType));
		}
		
		/// <summary>
		/// Tries to replace the delegate construction <c>new D(target, ldftn(method))</c>, where
		/// <c>method</c> is a compiler-generated anonymous method, with an inline
		/// <see cref="AnonymousMethodExpression"/> or <see cref="LambdaExpression"/> that contains the
		/// decompiled body of that method.
		/// </summary>
		/// <param name="objectCreateExpression">The delegate construction; replaced in the AST on success.</param>
		/// <param name="target">The delegate target (first constructor argument). Every 'this' in the
		/// anonymous method body is replaced by a clone of it.</param>
		/// <param name="methodRef">The method the function pointer refers to.</param>
		/// <returns>
		/// true if the expression was replaced. false if anonymous method decompilation is disabled,
		/// the target is not a simple expression, or the method is not an anonymous method; in those
		/// cases the AST is left untouched.
		/// </returns>
		bool HandleAnonymousMethod(ObjectCreateExpression objectCreateExpression, Expression target, MethodReference methodRef)
		{
			if (!context.Settings.AnonymousMethods)
				return false; // anonymous method decompilation is disabled
			if (target != null && !(target is IdentifierExpression || target is ThisReferenceExpression || target is NullReferenceExpression))
				return false; // don't copy arbitrary expressions, deal with identifiers only
			
			// Anonymous methods are defined in the same assembly
			MethodDefinition method = methodRef.ResolveWithinSameModule();
			if (!IsAnonymousMethod(context, method))
				return false;
			
			// Create AnonymousMethodExpression and prepare parameters
			AnonymousMethodExpression ame = new AnonymousMethodExpression();
			ame.CopyAnnotationsFrom(objectCreateExpression); // copy ILRanges etc.
			ame.RemoveAnnotations<MethodReference>(); // remove reference to delegate ctor
			ame.AddAnnotation(method); // add reference to anonymous method
			ame.Parameters.AddRange(AstBuilder.MakeParameters(method, isLambda: true));
			ame.HasParameterList = true;
			
			// rename variables so that they don't conflict with the parameters:
			foreach (ParameterDeclaration pd in ame.Parameters) {
				EnsureVariableNameIsAvailable(objectCreateExpression, pd.Name);
			}
			
			// Decompile the anonymous method in its own context. The names currently in use are
			// reserved so that the nested body does not introduce clashing variable names.
			DecompilerContext subContext = context.Clone();
			subContext.CurrentMethod = method;
			subContext.CurrentMethodIsAsync = false;
			subContext.ReservedVariableNames.AddRange(currentlyUsedVariableNames);
			BlockStatement body = AstMethodBodyBuilder.CreateMethodBody(method, subContext, ame.Parameters);
			// Run the pipeline up to (not including) this transform, then apply this instance to the body
			// so that nested anonymous methods and display classes are handled too.
			TransformationPipeline.RunTransformationsUntil(body, v => v is DelegateConstruction, subContext);
			body.AcceptVisitor(this, null);
			
			// Use an expression lambda only when no parameter has a modifier and the body
			// consists of a single return statement.
			bool isLambda = false;
			if (ame.Parameters.All(p => p.ParameterModifier == ParameterModifier.None)) {
				isLambda = (body.Statements.Count == 1 && body.Statements.Single() is ReturnStatement);
			}
			// Remove the parameter list from an AnonymousMethodExpression if the original method had no names,
			// and the parameters are not used in the method body
			if (!isLambda && method.Parameters.All(p => string.IsNullOrEmpty(p.Name))) {
				var parameterReferencingIdentifiers =
					from ident in body.Descendants.OfType<IdentifierExpression>()
					let v = ident.Annotation<ILVariable>()
					where v != null && v.IsParameter && method.Parameters.Contains(v.OriginalParameter)
					select ident;
				if (!parameterReferencingIdentifiers.Any()) {
					ame.Parameters.Clear();
					ame.HasParameterList = false;
				}
			}
			
			// Replace all occurrences of 'this' in the method body with the delegate's target:
			foreach (AstNode node in body.Descendants) {
				if (node is ThisReferenceExpression)
					node.ReplaceWith(target.Clone());
			}
			Expression replacement;
			if (isLambda) {
				LambdaExpression lambda = new LambdaExpression();
				lambda.CopyAnnotationsFrom(ame);
				ame.Parameters.MoveTo(lambda.Parameters);
				Expression returnExpr = ((ReturnStatement)body.Statements.Single()).Expression;
				returnExpr.Remove();
				lambda.Body = returnExpr;
				replacement = lambda;
			} else {
				ame.Body = body;
				replacement = ame;
			}
			// A lambda/anonymous method has no type of its own. If the surrounding code expects a
			// non-delegate type (e.g. object), keep an explicit "new D(...)" around it. The
			// TypeInformation annotation or its ExpectedType may be missing; treat that as
			// "unknown expected type" instead of throwing a NullReferenceException.
			TypeInformation typeInfo = objectCreateExpression.Annotation<TypeInformation>();
			TypeDefinition expectedType = (typeInfo != null && typeInfo.ExpectedType != null) ? typeInfo.ExpectedType.Resolve() : null;
			if (expectedType != null && !expectedType.IsDelegate()) {
				var simplifiedDelegateCreation = (ObjectCreateExpression)objectCreateExpression.Clone();
				simplifiedDelegateCreation.Arguments.Clear();
				simplifiedDelegateCreation.Arguments.Add(replacement);
				replacement = simplifiedDelegateCreation;
			}
			objectCreateExpression.ReplaceWith(replacement);
			return true;
		}
		
		/// <summary>
		/// Returns true if <paramref name="potentialDisplayClass"/> may be a closure ("display class")
		/// of the type being decompiled. Two conditions must hold. First, the type is non-null and passes
		/// IsCompilerGeneratedOrIsInCompilerGeneratedClass(). Second, it is
		/// <c>context.CurrentType</c> itself or is nested (at any depth) inside it.
		/// </summary>
		internal static bool IsPotentialClosure(DecompilerContext context, TypeDefinition potentialDisplayClass)
		{
			if (potentialDisplayClass == null || !potentialDisplayClass.IsCompilerGeneratedOrIsInCompilerGeneratedClass())
				return false;
			// Walk outwards through the declaring types looking for the current type.
			for (TypeDefinition enclosing = potentialDisplayClass; enclosing != null; enclosing = enclosing.DeclaringType) {
				if (enclosing == context.CurrentType)
					return true;
			}
			return false;
		}
		
		/// <summary>
		/// When expression tree decompilation is enabled, tries to turn an invocation that builds an
		/// expression tree back into a lambda. On success, the converted node replaces the invocation
		/// and is visited instead. Otherwise the invocation is visited normally.
		/// </summary>
		public override object VisitInvocationExpression(InvocationExpression invocationExpression, object data)
		{
			if (!context.Settings.ExpressionTrees || !ExpressionTreeConverter.CouldBeExpressionTree(invocationExpression))
				return base.VisitInvocationExpression(invocationExpression, data);
			
			Expression converted = ExpressionTreeConverter.TryConvert(context, invocationExpression);
			if (converted == null)
				return base.VisitInvocationExpression(invocationExpression, data);
			
			invocationExpression.ReplaceWith(converted);
			return converted.AcceptVisitor(this, data);
		}
		
		#region Track current variables
		// These overrides maintain currentlyUsedVariableNames. Top-level members seed the list with
		// their parameter names and clear it on exit. Declarations append to it; VisitBlockStatement
		// truncates the list again when the enclosing block ends.
		
		/// <summary>Appends the names of <paramref name="parameters"/> to the names in use.</summary>
		void TrackParameterNames(IEnumerable<ParameterDeclaration> parameters)
		{
			foreach (ParameterDeclaration parameter in parameters)
				currentlyUsedVariableNames.Add(parameter.Name);
		}
		
		public override object VisitMethodDeclaration(MethodDeclaration methodDeclaration, object data)
		{
			// The list must be empty when a top-level member is entered.
			Debug.Assert(currentlyUsedVariableNames.Count == 0);
			try {
				TrackParameterNames(methodDeclaration.Parameters);
				return base.VisitMethodDeclaration(methodDeclaration, data);
			} finally {
				currentlyUsedVariableNames.Clear();
			}
		}
		
		public override object VisitOperatorDeclaration(OperatorDeclaration operatorDeclaration, object data)
		{
			Debug.Assert(currentlyUsedVariableNames.Count == 0);
			try {
				TrackParameterNames(operatorDeclaration.Parameters);
				return base.VisitOperatorDeclaration(operatorDeclaration, data);
			} finally {
				currentlyUsedVariableNames.Clear();
			}
		}
		
		public override object VisitConstructorDeclaration(ConstructorDeclaration constructorDeclaration, object data)
		{
			Debug.Assert(currentlyUsedVariableNames.Count == 0);
			try {
				TrackParameterNames(constructorDeclaration.Parameters);
				return base.VisitConstructorDeclaration(constructorDeclaration, data);
			} finally {
				currentlyUsedVariableNames.Clear();
			}
		}
		
		public override object VisitIndexerDeclaration(IndexerDeclaration indexerDeclaration, object data)
		{
			Debug.Assert(currentlyUsedVariableNames.Count == 0);
			try {
				TrackParameterNames(indexerDeclaration.Parameters);
				return base.VisitIndexerDeclaration(indexerDeclaration, data);
			} finally {
				currentlyUsedVariableNames.Clear();
			}
		}
		
		/// <summary>
		/// Reserves the implicit "value" parameter while the accessor is visited. This is done for
		/// every accessor kind, getters included.
		/// </summary>
		public override object VisitAccessor(Accessor accessor, object data)
		{
			try {
				currentlyUsedVariableNames.Add("value");
				return base.VisitAccessor(accessor, data);
			} finally {
				// drop the "value" entry added above
				currentlyUsedVariableNames.RemoveAt(currentlyUsedVariableNames.Count - 1);
			}
		}
		
		public override object VisitVariableDeclarationStatement(VariableDeclarationStatement variableDeclarationStatement, object data)
		{
			currentlyUsedVariableNames.AddRange(variableDeclarationStatement.Variables.Select(v => v.Name));
			return base.VisitVariableDeclarationStatement(variableDeclarationStatement, data);
		}
		
		public override object VisitFixedStatement(FixedStatement fixedStatement, object data)
		{
			currentlyUsedVariableNames.AddRange(fixedStatement.Variables.Select(v => v.Name));
			return base.VisitFixedStatement(fixedStatement, data);
		}
		#endregion
		
		/// <summary>
		/// Matches a statement of the form <c>identifier = new SomeType();</c>. The assigned
		/// identifier is captured as "variable" and the created type as "type".
		/// </summary>
		static readonly ExpressionStatement displayClassAssignmentPattern = CreateDisplayClassAssignmentPattern();
		
		/// <summary>Builds the pattern stored in displayClassAssignmentPattern.</summary>
		static ExpressionStatement CreateDisplayClassAssignmentPattern()
		{
			var assignedVariable = new NamedNode("variable", new IdentifierExpression(Pattern.AnyString));
			var creation = new ObjectCreateExpression { Type = new AnyNode("type") };
			return new ExpressionStatement(new AssignmentExpression(assignedVariable, creation));
		}
		
		/// <summary>
		/// Visits the block's children first. It then inlines every display class ("closure") instance
		/// created in this block by a statement <c>variable = new DisplayClass();</c>:
		/// <list type="bullet">
		/// <item>The variable must be used only for field accesses; otherwise the creation is left alone.</item>
		/// <item>The variable's declaration and the creating assignment are removed. Directly following
		/// assignments are removed too, and their fields are replaced by the assigned expression, when they
		/// copy one of these into a field: 'this', a parameter that occurs nowhere else in the outermost
		/// block, or a parent display class reference.</item>
		/// <item>Every remaining instance field becomes a new local variable, declared at the start of
		/// this block. Existing variables with a conflicting name are renamed.</item>
		/// </list>
		/// Finally, names declared inside this block are removed from currentlyUsedVariableNames.
		/// </summary>
		public override object VisitBlockStatement(BlockStatement blockStatement, object data)
		{
			// Names added while inside this block are dropped again at the end of this method.
			int numberOfVariablesOutsideBlock = currentlyUsedVariableNames.Count;
			base.VisitBlockStatement(blockStatement, data);
			// Iterate over a snapshot, because the loop body removes and inserts statements.
			foreach (ExpressionStatement stmt in blockStatement.Statements.OfType<ExpressionStatement>().ToArray()) {
				Match displayClassAssignmentMatch = displayClassAssignmentPattern.Match(stmt);
				if (!displayClassAssignmentMatch.Success)
					continue;
				
				ILVariable variable = displayClassAssignmentMatch.Get<AstNode>("variable").Single().Annotation<ILVariable>();
				if (variable == null)
					continue;
				TypeDefinition type = variable.Type.ResolveWithinSameModule();
				if (!IsPotentialClosure(context, type))
					continue;
				// The created type must be exactly the variable's display class type.
				if (displayClassAssignmentMatch.Get<AstType>("type").Single().Annotation<TypeReference>().ResolveWithinSameModule() != type)
					continue;
				
				// Looks like we found a display class creation. Now let's verify that the variable is used only for field accesses:
				bool ok = true;
				foreach (var identExpr in blockStatement.Descendants.OfType<IdentifierExpression>()) {
					if (identExpr.Identifier == variable.Name && identExpr != displayClassAssignmentMatch.Get("variable").Single()) {
						if (!(identExpr.Parent is MemberReferenceExpression && identExpr.Parent.Annotation<FieldReference>() != null))
							ok = false;
					}
				}
				if (!ok)
					continue;
				// Maps each display class field to the expression that replaces accesses to it.
				Dictionary<FieldReference, AstNode> dict = new Dictionary<FieldReference, AstNode>();
				
				// Delete the variable declaration statement:
				VariableDeclarationStatement displayClassVarDecl = PatternStatementTransform.FindVariableDeclaration(stmt, variable.Name);
				if (displayClassVarDecl != null)
					displayClassVarDecl.Remove();
				
				// Delete the assignment statement:
				AstNode cur = stmt.NextSibling;
				stmt.Remove();
				
				// Delete any following statements as long as they assign parameters to the display class
				BlockStatement rootBlock = blockStatement.Ancestors.OfType<BlockStatement>().LastOrDefault() ?? blockStatement;
				List<ILVariable> parameterOccurrances = rootBlock.Descendants.OfType<IdentifierExpression>()
					.Select(n => n.Annotation<ILVariable>()).Where(p => p != null && p.IsParameter).ToList();
				AstNode next;
				for (; cur != null; cur = next) {
					next = cur.NextSibling;
					
					// Test for the pattern:
					// "variableName.MemberName = right;"
					ExpressionStatement closureFieldAssignmentPattern = new ExpressionStatement(
						new AssignmentExpression(
							new NamedNode("left", new MemberReferenceExpression {
								Target = new IdentifierExpression(variable.Name),
								MemberName = Pattern.AnyString
							}),
							new AnyNode("right")
						)
					);
					Match m = closureFieldAssignmentPattern.Match(cur);
					if (m.Success) {
						FieldDefinition fieldDef = m.Get<MemberReferenceExpression>("left").Single().Annotation<FieldReference>().ResolveWithinSameModule();
						AstNode right = m.Get<AstNode>("right").Single();
						bool isParameter = false;
						bool isDisplayClassParentPointerAssignment = false;
						if (right is ThisReferenceExpression) {
							isParameter = true;
						} else if (right is IdentifierExpression) {
							// handle parameters only if the whole method contains no other occurrence except for 'right'
							ILVariable v = right.Annotation<ILVariable>();
							// An identifier without an ILVariable annotation cannot be classified; keep the
							// assignment (previously this dereferenced null and threw).
							if (v != null) {
								isParameter = v.IsParameter && parameterOccurrances.Count(c => c == v) == 1;
								if (!isParameter && IsPotentialClosure(context, v.Type.ResolveWithinSameModule())) {
									// parent display class within the same method
									// (closure2.localsX = closure1;)
									isDisplayClassParentPointerAssignment = true;
								}
							}
						} else if (right is MemberReferenceExpression) {
							// copy of parent display class reference from an outer lambda
							// closure2.localsX = this.localsY
							MemberReferenceExpression mre = m.Get<MemberReferenceExpression>("right").Single();
							do {
								// descend into the targets of the mre as long as the field types are closures
								FieldDefinition fieldDef2 = mre.Annotation<FieldReference>().ResolveWithinSameModule();
								if (fieldDef2 == null || !IsPotentialClosure(context, fieldDef2.FieldType.ResolveWithinSameModule())) {
									break;
								}
								// if we finally get to a this reference, it's copying a display class parent pointer
								if (mre.Target is ThisReferenceExpression) {
									isDisplayClassParentPointerAssignment = true;
								}
								mre = mre.Target as MemberReferenceExpression;
							} while (mre != null);
						}
						if (isParameter || isDisplayClassParentPointerAssignment) {
							dict[fieldDef] = right;
							cur.Remove();
						} else {
							break;
						}
					} else {
						break;
					}
				}
				
				// Now create variables for all fields of the display class (except for those that we already handled as parameters)
				// Fields named "$VB$Local_<name>" are mapped back to "<name>".
				const string vbLocalPrefix = "$VB$Local_";
				List<Tuple<AstType, ILVariable>> variablesToDeclare = new List<Tuple<AstType, ILVariable>>();
				foreach (FieldDefinition field in type.Fields) {
					if (field.IsStatic)
						continue; // skip static fields
					if (dict.ContainsKey(field)) // skip field if it already was handled as parameter
						continue;
					string capturedVariableName = field.Name;
					if (capturedVariableName.StartsWith(vbLocalPrefix, StringComparison.Ordinal) && capturedVariableName.Length > vbLocalPrefix.Length)
						capturedVariableName = capturedVariableName.Substring(vbLocalPrefix.Length);
					EnsureVariableNameIsAvailable(blockStatement, capturedVariableName);
					currentlyUsedVariableNames.Add(capturedVariableName);
					ILVariable ilVar = new ILVariable
					{
						IsGenerated = true,
						Name = capturedVariableName,
						Type = field.FieldType,
					};
					variablesToDeclare.Add(Tuple.Create(AstBuilder.ConvertType(field.FieldType, field), ilVar));
					dict[field] = new IdentifierExpression(capturedVariableName).WithAnnotation(ilVar);
				}
				
				// Now figure out where the closure was accessed and use the simpler replacement expression there:
				foreach (var identExpr in blockStatement.Descendants.OfType<IdentifierExpression>()) {
					if (identExpr.Identifier == variable.Name) {
						// Safe cast: the check above ensured every use is the target of a field access.
						MemberReferenceExpression mre = (MemberReferenceExpression)identExpr.Parent;
						AstNode replacement;
						if (dict.TryGetValue(mre.Annotation<FieldReference>().ResolveWithinSameModule(), out replacement)) {
							mre.ReplaceWith(replacement.Clone());
						}
					}
				}
				// Now insert the variable declarations (we can do this after the replacements only so that the scope detection works):
				Statement insertionPoint = blockStatement.Statements.FirstOrDefault();
				foreach (var tuple in variablesToDeclare) {
					var newVarDecl = new VariableDeclarationStatement(tuple.Item1, tuple.Item2.Name);
					newVarDecl.Variables.Single().AddAnnotation(new CapturedVariableAnnotation());
					newVarDecl.Variables.Single().AddAnnotation(tuple.Item2);
					blockStatement.Statements.InsertBefore(insertionPoint, newVarDecl);
				}
			}
			currentlyUsedVariableNames.RemoveRange(numberOfVariablesOutsideBlock, currentlyUsedVariableNames.Count - numberOfVariablesOutsideBlock);
			return null;
		}

		/// <summary>
		/// Makes <paramref name="name"/> available for a newly introduced variable or parameter.
		/// If the name is already in use, the existing variable is renamed to a fresh alternative, so
		/// the new variable can keep the name from metadata. The rename covers identifiers and variable
		/// declarators with that name inside the outermost enclosing block.
		/// </summary>
		/// <param name="currentNode">Node where the name is introduced. The variable and parameter names
		/// declared below it are also treated as taken when choosing the alternative.</param>
		/// <param name="name">The name that must become free.</param>
		void EnsureVariableNameIsAvailable(AstNode currentNode, string name)
		{
			int conflictIndex = currentlyUsedVariableNames.IndexOf(name);
			if (conflictIndex == -1)
				return; // nobody uses the name yet
			
			// Naming conflict: pick a replacement that collides with no tracked name and with no
			// variable or lambda parameter declared below currentNode.
			NameVariables nameGenerator = new NameVariables();
			IEnumerable<string> takenNames = currentlyUsedVariableNames
				.Concat(currentNode.Descendants.OfType<VariableInitializer>().Select(vi => vi.Name))
				.Concat(currentNode.Descendants.OfType<ParameterDeclaration>().Select(pd => pd.Name));
			foreach (string taken in takenNames)
				nameGenerator.AddExistingName(taken);
			
			string newName = nameGenerator.GetAlternativeName(name);
			currentlyUsedVariableNames[conflictIndex] = newName;
			
			// Rename within the outermost enclosing block so every reference is updated.
			AstNode scope = currentNode.Ancestors.OfType<BlockStatement>().LastOrDefault() ?? currentNode;
			
			foreach (IdentifierExpression identifier in scope.Descendants.OfType<IdentifierExpression>().Where(ie => ie.Identifier == name)) {
				identifier.Identifier = newName;
				ILVariable ilVariable = identifier.Annotation<ILVariable>();
				if (ilVariable != null)
					ilVariable.Name = newName;
			}
			foreach (VariableInitializer declarator in scope.Descendants.OfType<VariableInitializer>().Where(vi => vi.Name == name)) {
				declarator.Name = newName;
				ILVariable ilVariable = declarator.Annotation<ILVariable>();
				if (ilVariable != null)
					ilVariable.Name = newName;
			}
		}
	}
}