diff options
Diffstat (limited to 'ICSharpCode.Decompiler/Ast/Transforms')
18 files changed, 5314 insertions, 0 deletions
diff --git a/ICSharpCode.Decompiler/Ast/Transforms/AddCheckedBlocks.cs b/ICSharpCode.Decompiler/Ast/Transforms/AddCheckedBlocks.cs new file mode 100644 index 00000000..dc018eb1 --- /dev/null +++ b/ICSharpCode.Decompiler/Ast/Transforms/AddCheckedBlocks.cs @@ -0,0 +1,368 @@ +// Copyright (c) 2011 AlphaSierraPapa for the SharpDevelop Team +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this +// software and associated documentation files (the "Software"), to deal in the Software +// without restriction, including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons +// to whom the Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR +// PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE +// FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +// OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +using System; +using System.Linq; +using ICSharpCode.Decompiler.ILAst; +using ICSharpCode.NRefactory.CSharp; + +namespace ICSharpCode.Decompiler.Ast.Transforms +{ + /// <summary> + /// Add checked/unchecked blocks. + /// </summary> + public class AddCheckedBlocks : IAstTransform + { + #region Annotation + sealed class CheckedUncheckedAnnotation { + /// <summary> + /// true=checked, false=unchecked + /// </summary> + public bool IsChecked; + } + + public static readonly object CheckedAnnotation = new CheckedUncheckedAnnotation { IsChecked = true }; + public static readonly object UncheckedAnnotation = new CheckedUncheckedAnnotation { IsChecked = false }; + #endregion + + /* + We treat placing checked/unchecked blocks as an optimization problem, with the following goals: + 1. Use minimum number of checked blocks+expressions + 2. Prefer checked expressions over checked blocks + 3. Make the scope of checked expressions as small as possible + 4. Open checked blocks as late as possible, and close checked blocks as late as possible + (where goal 1 has the highest priority) + + Goal 4a (open checked blocks as late as possible) is necessary so that we don't move variable declarations + into checked blocks, as the variable might still be used after the checked block. + (this could cause DeclareVariables to omit the variable declaration, producing incorrect code) + Goal 4b (close checked blocks as late as possible) makes the code look nicer in this case: + checked { + int c = a + b; + int r = a + c; + return r; + } + If the checked block was closed as early as possible, the variable r would have to be declared outside + (this would work, but look badly) + */ + + #region struct Cost + struct Cost + { + // highest possible cost so that the Blocks+Expressions addition doesn't overflow + public static readonly Cost Infinite = new Cost(0x3fffffff, 0x3fffffff); + + public readonly int Blocks; + public readonly int Expressions; + + public Cost(int blocks, int expressions) + { + Blocks = blocks; + Expressions = expressions; + } + + public static bool operator <(Cost a, Cost b) + { + return a.Blocks + a.Expressions < b.Blocks + b.Expressions + || a.Blocks + a.Expressions == b.Blocks + b.Expressions && a.Blocks < b.Blocks; + } + + public static bool operator >(Cost a, Cost b) + { + return a.Blocks + a.Expressions > b.Blocks + b.Expressions + || a.Blocks + a.Expressions == b.Blocks + b.Expressions && a.Blocks > b.Blocks; + } + + public static bool operator <=(Cost a, Cost b) + { + return a.Blocks + a.Expressions < b.Blocks + b.Expressions + || a.Blocks + a.Expressions == b.Blocks + b.Expressions && a.Blocks <= b.Blocks; + } + + public static bool operator >=(Cost a, Cost b) + { + return a.Blocks + a.Expressions > b.Blocks + b.Expressions + || a.Blocks + a.Expressions == b.Blocks + b.Expressions && a.Blocks >= b.Blocks; + } + + public static Cost operator +(Cost a, Cost b) + { + return new Cost(a.Blocks + b.Blocks, a.Expressions + b.Expressions); + } + + public override string ToString() + { + return string.Format("[{0} + {1}]", Blocks, Expressions); + } + } + #endregion + + #region class InsertedNode + /// <summary> + /// Holds the blocks and expressions that should be inserted + /// </summary> + abstract class InsertedNode + { + public static InsertedNode operator +(InsertedNode a, InsertedNode b) + { + if (a == null) + return b; + if (b == null) + return a; + return new InsertedNodeList(a, b); + } + + public abstract void Insert(); + } + + class InsertedNodeList : InsertedNode + { + readonly InsertedNode child1, child2; + + public InsertedNodeList(InsertedNode child1, InsertedNode child2) + { + this.child1 = child1; + this.child2 = child2; + } + + public override void Insert() + { + child1.Insert(); + child2.Insert(); + } + } + + class InsertedExpression : InsertedNode + { + readonly Expression expression; + readonly bool isChecked; + + public InsertedExpression(Expression expression, bool isChecked) + { + this.expression = expression; + this.isChecked = isChecked; + } + + public override void Insert() + { + if (isChecked) + expression.ReplaceWith(e => new CheckedExpression { Expression = e }); + else + expression.ReplaceWith(e => new UncheckedExpression { Expression = e }); + } + } + + class ConvertCompoundAssignment : InsertedNode + { + readonly Expression expression; + readonly bool isChecked; + + public ConvertCompoundAssignment(Expression expression, bool isChecked) + { + this.expression = expression; + this.isChecked = isChecked; + } + + public override void Insert() + { + AssignmentExpression assign = expression.Annotation<ReplaceMethodCallsWithOperators.RestoreOriginalAssignOperatorAnnotation>().Restore(expression); + expression.ReplaceWith(assign); + if (isChecked) + assign.Right = new CheckedExpression { Expression = assign.Right.Detach() }; + else + assign.Right = new UncheckedExpression { Expression = assign.Right.Detach() }; + } + } + + class InsertedBlock : InsertedNode + { + readonly Statement firstStatement; // inclusive + readonly Statement lastStatement; // exclusive + readonly bool isChecked; + + public InsertedBlock(Statement firstStatement, Statement lastStatement, bool isChecked) + { + this.firstStatement = firstStatement; + this.lastStatement = lastStatement; + this.isChecked = isChecked; + } + + public override void Insert() + { + BlockStatement newBlock = new BlockStatement(); + // Move all statements except for the first + Statement next; + for (Statement stmt = firstStatement.GetNextStatement(); stmt != lastStatement; stmt = next) { + next = stmt.GetNextStatement(); + newBlock.Add(stmt.Detach()); + } + // Replace the first statement with the new (un)checked block + if (isChecked) + firstStatement.ReplaceWith(new CheckedStatement { Body = newBlock }); + else + firstStatement.ReplaceWith(new UncheckedStatement { Body = newBlock }); + // now also move the first node into the new block + newBlock.Statements.InsertAfter(null, firstStatement); + } + } + #endregion + + #region class Result + /// <summary> + /// Holds the result of an insertion operation. + /// </summary> + class Result + { + public Cost CostInCheckedContext; + public InsertedNode NodesToInsertInCheckedContext; + public Cost CostInUncheckedContext; + public InsertedNode NodesToInsertInUncheckedContext; + } + #endregion + + public void Run(AstNode node) + { + BlockStatement block = node as BlockStatement; + if (block == null) { + for (AstNode child = node.FirstChild; child != null; child = child.NextSibling) { + Run(child); + } + } else { + Result r = GetResultFromBlock(block); + if (r.NodesToInsertInUncheckedContext != null) + r.NodesToInsertInUncheckedContext.Insert(); + } + } + + Result GetResultFromBlock(BlockStatement block) + { + // For a block, we are tracking 4 possibilities: + // a) context is checked, no unchecked block open + Cost costCheckedContext = new Cost(0, 0); + InsertedNode nodesCheckedContext = null; + // b) context is checked, an unchecked block is open + Cost costCheckedContextUncheckedBlockOpen = Cost.Infinite; + InsertedNode nodesCheckedContextUncheckedBlockOpen = null; + Statement uncheckedBlockStart = null; + // c) context is unchecked, no checked block open + Cost costUncheckedContext = new Cost(0, 0); + InsertedNode nodesUncheckedContext = null; + // d) context is unchecked, a checked block is open + Cost costUncheckedContextCheckedBlockOpen = Cost.Infinite; + InsertedNode nodesUncheckedContextCheckedBlockOpen = null; + Statement checkedBlockStart = null; + + Statement statement = block.Statements.FirstOrDefault(); + while (true) { + // Blocks can be closed 'for free'. We use '<=' so that blocks are closed as late as possible (goal 4b) + if (costCheckedContextUncheckedBlockOpen <= costCheckedContext) { + costCheckedContext = costCheckedContextUncheckedBlockOpen; + nodesCheckedContext = nodesCheckedContextUncheckedBlockOpen + new InsertedBlock(uncheckedBlockStart, statement, false); + } + if (costUncheckedContextCheckedBlockOpen <= costUncheckedContext) { + costUncheckedContext = costUncheckedContextCheckedBlockOpen; + nodesUncheckedContext = nodesUncheckedContextCheckedBlockOpen + new InsertedBlock(checkedBlockStart, statement, true); + } + if (statement == null) + break; + // Now try opening blocks. We use '<=' so that blocks are opened as late as possible. (goal 4a) + if (costCheckedContext + new Cost(1, 0) <= costCheckedContextUncheckedBlockOpen) { + costCheckedContextUncheckedBlockOpen = costCheckedContext + new Cost(1, 0); + nodesCheckedContextUncheckedBlockOpen = nodesCheckedContext; + uncheckedBlockStart = statement; + } + if (costUncheckedContext + new Cost(1, 0) <= costUncheckedContextCheckedBlockOpen) { + costUncheckedContextCheckedBlockOpen = costUncheckedContext + new Cost(1, 0); + nodesUncheckedContextCheckedBlockOpen = nodesUncheckedContext; + checkedBlockStart = statement; + } + // Now handle the statement + Result stmtResult = GetResult(statement); + + costCheckedContext += stmtResult.CostInCheckedContext; + nodesCheckedContext += stmtResult.NodesToInsertInCheckedContext; + costCheckedContextUncheckedBlockOpen += stmtResult.CostInUncheckedContext; + nodesCheckedContextUncheckedBlockOpen += stmtResult.NodesToInsertInUncheckedContext; + costUncheckedContext += stmtResult.CostInUncheckedContext; + nodesUncheckedContext += stmtResult.NodesToInsertInUncheckedContext; + costUncheckedContextCheckedBlockOpen += stmtResult.CostInCheckedContext; + nodesUncheckedContextCheckedBlockOpen += stmtResult.NodesToInsertInCheckedContext; + + statement = statement.GetNextStatement(); + } + + return new Result { + CostInCheckedContext = costCheckedContext, NodesToInsertInCheckedContext = nodesCheckedContext, + CostInUncheckedContext = costUncheckedContext, NodesToInsertInUncheckedContext = nodesUncheckedContext + }; + } + + Result GetResult(AstNode node) + { + if (node is BlockStatement) + return GetResultFromBlock((BlockStatement)node); + Result result = new Result(); + for (AstNode child = node.FirstChild; child != null; child = child.NextSibling) { + Result childResult = GetResult(child); + result.CostInCheckedContext += childResult.CostInCheckedContext; + result.NodesToInsertInCheckedContext += childResult.NodesToInsertInCheckedContext; + result.CostInUncheckedContext += childResult.CostInUncheckedContext; + result.NodesToInsertInUncheckedContext += childResult.NodesToInsertInUncheckedContext; + } + Expression expr = node as Expression; + if (expr != null) { + CheckedUncheckedAnnotation annotation = expr.Annotation<CheckedUncheckedAnnotation>(); + if (annotation != null) { + // If the annotation requires this node to be in a specific context, add a huge cost to the other context + // That huge cost gives us the option to ignore a required checked/unchecked expression when there wouldn't be any + // solution otherwise. (e.g. "for (checked(M().x += 1); true; unchecked(M().x += 2)) {}") + if (annotation.IsChecked) + result.CostInUncheckedContext += new Cost(10000, 0); + else + result.CostInCheckedContext += new Cost(10000, 0); + } + // Embed this node in an checked/unchecked expression: + if (expr.Parent is ExpressionStatement) { + // We cannot use checked/unchecked for top-level-expressions. + // However, we could try converting a compound assignment (checked(a+=b);) or unary operator (checked(a++);) + // back to its old form. + if (expr.Annotation<ReplaceMethodCallsWithOperators.RestoreOriginalAssignOperatorAnnotation>() != null) { + // We use '<' so that expressions are introduced on the deepest level possible (goal 3) + if (result.CostInCheckedContext + new Cost(1, 1) < result.CostInUncheckedContext) { + result.CostInUncheckedContext = result.CostInCheckedContext + new Cost(1, 1); + result.NodesToInsertInUncheckedContext = result.NodesToInsertInCheckedContext + new ConvertCompoundAssignment(expr, true); + } else if (result.CostInUncheckedContext + new Cost(1, 1) < result.CostInCheckedContext) { + result.CostInCheckedContext = result.CostInUncheckedContext + new Cost(1, 1); + result.NodesToInsertInCheckedContext = result.NodesToInsertInUncheckedContext + new ConvertCompoundAssignment(expr, false); + } + } + } else if (expr.Role.IsValid(Expression.Null)) { + // We use '<' so that expressions are introduced on the deepest level possible (goal 3) + if (result.CostInCheckedContext + new Cost(0, 1) < result.CostInUncheckedContext) { + result.CostInUncheckedContext = result.CostInCheckedContext + new Cost(0, 1); + result.NodesToInsertInUncheckedContext = result.NodesToInsertInCheckedContext + new InsertedExpression(expr, true); + } else if (result.CostInUncheckedContext + new Cost(0, 1) < result.CostInCheckedContext) { + result.CostInCheckedContext = result.CostInUncheckedContext + new Cost(0, 1); + result.NodesToInsertInCheckedContext = result.NodesToInsertInUncheckedContext + new InsertedExpression(expr, false); + } + } + } + return result; + } + } +} diff --git a/ICSharpCode.Decompiler/Ast/Transforms/CombineQueryExpressions.cs b/ICSharpCode.Decompiler/Ast/Transforms/CombineQueryExpressions.cs new file mode 100644 index 00000000..a0d1ca8c --- /dev/null +++ b/ICSharpCode.Decompiler/Ast/Transforms/CombineQueryExpressions.cs @@ -0,0 +1,179 @@ +// Copyright (c) 2011 AlphaSierraPapa for the SharpDevelop Team +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this +// software and associated documentation files (the "Software"), to deal in the Software +// without restriction, including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons +// to whom the Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR +// PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE +// FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +// OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +using System; +using System.Linq; +using ICSharpCode.NRefactory.CSharp; +using ICSharpCode.NRefactory.PatternMatching; + +namespace ICSharpCode.Decompiler.Ast.Transforms +{ + /// <summary> + /// Combines query expressions and removes transparent identifiers. + /// </summary> + public class CombineQueryExpressions : IAstTransform + { + readonly DecompilerContext context; + + public CombineQueryExpressions(DecompilerContext context) + { + this.context = context; + } + + public void Run(AstNode compilationUnit) + { + if (!context.Settings.QueryExpressions) + return; + CombineQueries(compilationUnit); + } + + static readonly InvocationExpression castPattern = new InvocationExpression { + Target = new MemberReferenceExpression { + Target = new AnyNode("inExpr"), + MemberName = "Cast", + TypeArguments = { new AnyNode("targetType") } + }}; + + void CombineQueries(AstNode node) + { + for (AstNode child = node.FirstChild; child != null; child = child.NextSibling) { + CombineQueries(child); + } + QueryExpression query = node as QueryExpression; + if (query != null) { + QueryFromClause fromClause = (QueryFromClause)query.Clauses.First(); + QueryExpression innerQuery = fromClause.Expression as QueryExpression; + if (innerQuery != null) { + if (TryRemoveTransparentIdentifier(query, fromClause, innerQuery)) { + RemoveTransparentIdentifierReferences(query); + } else { + QueryContinuationClause continuation = new QueryContinuationClause(); + continuation.PrecedingQuery = innerQuery.Detach(); + continuation.Identifier = fromClause.Identifier; + fromClause.ReplaceWith(continuation); + } + } else { + Match m = castPattern.Match(fromClause.Expression); + if (m.Success) { + fromClause.Type = m.Get<AstType>("targetType").Single().Detach(); + fromClause.Expression = m.Get<Expression>("inExpr").Single().Detach(); + } + } + } + } + + static readonly QuerySelectClause selectTransparentIdentifierPattern = new QuerySelectClause { + Expression = new Choice { + new AnonymousTypeCreateExpression { + Initializers = { + new NamedExpression { + Name = Pattern.AnyString, + Expression = new IdentifierExpression(Pattern.AnyString) + }.WithName("nae1"), + new NamedExpression { + Name = Pattern.AnyString, + Expression = new AnyNode("nae2Expr") + }.WithName("nae2") + } + }, + new AnonymousTypeCreateExpression { + Initializers = { + new NamedNode("identifier", new IdentifierExpression(Pattern.AnyString)), + new AnyNode("nae2Expr") + } + } + }}; + + bool IsTransparentIdentifier(string identifier) + { + return identifier.StartsWith("<>", StringComparison.Ordinal) && identifier.Contains("TransparentIdentifier"); + } + + bool TryRemoveTransparentIdentifier(QueryExpression query, QueryFromClause fromClause, QueryExpression innerQuery) + { + if (!IsTransparentIdentifier(fromClause.Identifier)) + return false; + Match match = selectTransparentIdentifierPattern.Match(innerQuery.Clauses.Last()); + if (!match.Success) + return false; + QuerySelectClause selectClause = (QuerySelectClause)innerQuery.Clauses.Last(); + NamedExpression nae1 = match.Get<NamedExpression>("nae1").SingleOrDefault(); + NamedExpression nae2 = match.Get<NamedExpression>("nae2").SingleOrDefault(); + if (nae1 != null && nae1.Name != ((IdentifierExpression)nae1.Expression).Identifier) + return false; + Expression nae2Expr = match.Get<Expression>("nae2Expr").Single(); + IdentifierExpression nae2IdentExpr = nae2Expr as IdentifierExpression; + if (nae2IdentExpr != null && (nae2 == null || nae2.Name == nae2IdentExpr.Identifier)) { + // from * in (from x in ... select new { x = x, y = y }) ... + // => + // from x in ... ... + fromClause.Remove(); + selectClause.Remove(); + // Move clauses from innerQuery to query + QueryClause insertionPos = null; + foreach (var clause in innerQuery.Clauses) { + query.Clauses.InsertAfter(insertionPos, insertionPos = clause.Detach()); + } + } else { + // from * in (from x in ... select new { x = x, y = expr }) ... + // => + // from x in ... let y = expr ... + fromClause.Remove(); + selectClause.Remove(); + // Move clauses from innerQuery to query + QueryClause insertionPos = null; + foreach (var clause in innerQuery.Clauses) { + query.Clauses.InsertAfter(insertionPos, insertionPos = clause.Detach()); + } + string ident; + if (nae2 != null) + ident = nae2.Name; + else if (nae2Expr is IdentifierExpression) + ident = ((IdentifierExpression)nae2Expr).Identifier; + else if (nae2Expr is MemberReferenceExpression) + ident = ((MemberReferenceExpression)nae2Expr).MemberName; + else + throw new InvalidOperationException("Could not infer name from initializer in AnonymousTypeCreateExpression"); + query.Clauses.InsertAfter(insertionPos, new QueryLetClause { Identifier = ident, Expression = nae2Expr.Detach() }); + } + return true; + } + + /// <summary> + /// Removes all occurrences of transparent identifiers + /// </summary> + void RemoveTransparentIdentifierReferences(AstNode node) + { + foreach (AstNode child in node.Children) { + RemoveTransparentIdentifierReferences(child); + } + MemberReferenceExpression mre = node as MemberReferenceExpression; + if (mre != null) { + IdentifierExpression ident = mre.Target as IdentifierExpression; + if (ident != null && IsTransparentIdentifier(ident.Identifier)) { + IdentifierExpression newIdent = new IdentifierExpression(mre.MemberName); + mre.TypeArguments.MoveTo(newIdent.TypeArguments); + newIdent.CopyAnnotationsFrom(mre); + newIdent.RemoveAnnotations<PropertyDeclaration>(); // remove the reference to the property of the anonymous type + mre.ReplaceWith(newIdent); + return; + } + } + } + } +} diff --git a/ICSharpCode.Decompiler/Ast/Transforms/ContextTrackingVisitor.cs b/ICSharpCode.Decompiler/Ast/Transforms/ContextTrackingVisitor.cs new file mode 100644 index 00000000..1d1f3d08 --- /dev/null +++ b/ICSharpCode.Decompiler/Ast/Transforms/ContextTrackingVisitor.cs @@ -0,0 +1,111 @@ +// Copyright (c) 2011 AlphaSierraPapa for the SharpDevelop Team +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this +// software and associated documentation files (the "Software"), to deal in the Software +// without restriction, including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons +// to whom the Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR +// PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE +// FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +// OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +using System; +using System.Diagnostics; +using ICSharpCode.NRefactory.CSharp; +using Mono.Cecil; + +namespace ICSharpCode.Decompiler.Ast.Transforms +{ + /// <summary> + /// Base class for AST visitors that need the current type/method context info. + /// </summary> + public abstract class ContextTrackingVisitor<TResult> : DepthFirstAstVisitor<object, TResult>, IAstTransform + { + protected readonly DecompilerContext context; + + protected ContextTrackingVisitor(DecompilerContext context) + { + if (context == null) + throw new ArgumentNullException("context"); + this.context = context; + } + + public override TResult VisitTypeDeclaration(TypeDeclaration typeDeclaration, object data) + { + TypeDefinition oldType = context.CurrentType; + try { + context.CurrentType = typeDeclaration.Annotation<TypeDefinition>(); + return base.VisitTypeDeclaration(typeDeclaration, data); + } finally { + context.CurrentType = oldType; + } + } + + public override TResult VisitMethodDeclaration(MethodDeclaration methodDeclaration, object data) + { + Debug.Assert(context.CurrentMethod == null); + try { + context.CurrentMethod = methodDeclaration.Annotation<MethodDefinition>(); + return base.VisitMethodDeclaration(methodDeclaration, data); + } finally { + context.CurrentMethod = null; + } + } + + public override TResult VisitConstructorDeclaration(ConstructorDeclaration constructorDeclaration, object data) + { + Debug.Assert(context.CurrentMethod == null); + try { + context.CurrentMethod = constructorDeclaration.Annotation<MethodDefinition>(); + return base.VisitConstructorDeclaration(constructorDeclaration, data); + } finally { + context.CurrentMethod = null; + } + } + + public override TResult VisitDestructorDeclaration(DestructorDeclaration destructorDeclaration, object data) + { + Debug.Assert(context.CurrentMethod == null); + try { + context.CurrentMethod = destructorDeclaration.Annotation<MethodDefinition>(); + return base.VisitDestructorDeclaration(destructorDeclaration, data); + } finally { + context.CurrentMethod = null; + } + } + + public override TResult VisitOperatorDeclaration(OperatorDeclaration operatorDeclaration, object data) + { + Debug.Assert(context.CurrentMethod == null); + try { + context.CurrentMethod = operatorDeclaration.Annotation<MethodDefinition>(); + return base.VisitOperatorDeclaration(operatorDeclaration, data); + } finally { + context.CurrentMethod = null; + } + } + + public override TResult VisitAccessor(Accessor accessor, object data) + { + Debug.Assert(context.CurrentMethod == null); + try { + context.CurrentMethod = accessor.Annotation<MethodDefinition>(); + return base.VisitAccessor(accessor, data); + } finally { + context.CurrentMethod = null; + } + } + + void IAstTransform.Run(AstNode node) + { + node.AcceptVisitor(this, null); + } + } +} diff --git a/ICSharpCode.Decompiler/Ast/Transforms/ConvertConstructorCallIntoInitializer.cs b/ICSharpCode.Decompiler/Ast/Transforms/ConvertConstructorCallIntoInitializer.cs new file mode 100644 index 00000000..36811c18 --- /dev/null +++ b/ICSharpCode.Decompiler/Ast/Transforms/ConvertConstructorCallIntoInitializer.cs @@ -0,0 +1,183 @@ +// Copyright (c) 2011 AlphaSierraPapa for the SharpDevelop Team +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this +// software and associated documentation files (the "Software"), to deal in the Software +// without restriction, including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons +// to whom the Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR +// PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE +// FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +// OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +using System; +using System.Collections.Generic; +using System.Linq; +using ICSharpCode.NRefactory.CSharp; +using ICSharpCode.NRefactory.PatternMatching; +using Mono.Cecil; + +namespace ICSharpCode.Decompiler.Ast.Transforms +{ + /// <summary> + /// If the first element of a constructor is a chained constructor call, convert it into a constructor initializer. + /// </summary> + public class ConvertConstructorCallIntoInitializer : DepthFirstAstVisitor<object, object>, IAstTransform + { + public override object VisitConstructorDeclaration(ConstructorDeclaration constructorDeclaration, object data) + { + ExpressionStatement stmt = constructorDeclaration.Body.Statements.FirstOrDefault() as ExpressionStatement; + if (stmt == null) + return null; + InvocationExpression invocation = stmt.Expression as InvocationExpression; + if (invocation == null) + return null; + MemberReferenceExpression mre = invocation.Target as MemberReferenceExpression; + if (mre != null && mre.MemberName == ".ctor") { + ConstructorInitializer ci = new ConstructorInitializer(); + if (mre.Target is ThisReferenceExpression) + ci.ConstructorInitializerType = ConstructorInitializerType.This; + else if (mre.Target is BaseReferenceExpression) + ci.ConstructorInitializerType = ConstructorInitializerType.Base; + else + return null; + // Move arguments from invocation to initializer: + invocation.Arguments.MoveTo(ci.Arguments); + // Add the initializer: (unless it is the default 'base()') + if (!(ci.ConstructorInitializerType == ConstructorInitializerType.Base && ci.Arguments.Count == 0)) + constructorDeclaration.Initializer = ci.WithAnnotation(invocation.Annotation<MethodReference>()); + // Remove the statement: + stmt.Remove(); + } + return null; + } + + static readonly ExpressionStatement fieldInitializerPattern = new ExpressionStatement { + Expression = new AssignmentExpression { + Left = new NamedNode("fieldAccess", new MemberReferenceExpression { + Target = new ThisReferenceExpression(), + MemberName = Pattern.AnyString + }), + Operator = AssignmentOperatorType.Assign, + Right = new AnyNode("initializer") + } + }; + + static readonly AstNode thisCallPattern = new ExpressionStatement(new ThisReferenceExpression().Invoke(".ctor", new Repeat(new AnyNode()))); + + public override object VisitTypeDeclaration(TypeDeclaration typeDeclaration, object data) + { + // Handle initializers on instance fields + HandleInstanceFieldInitializers(typeDeclaration.Members); + + // Now convert base constructor calls to initializers: + base.VisitTypeDeclaration(typeDeclaration, data); + + // Remove single empty constructor: + RemoveSingleEmptyConstructor(typeDeclaration); + + // Handle initializers on static fields: + HandleStaticFieldInitializers(typeDeclaration.Members); + return null; + } + + void HandleInstanceFieldInitializers(IEnumerable<AstNode> members) + { + var instanceCtors = members.OfType<ConstructorDeclaration>().Where(c => (c.Modifiers & Modifiers.Static) == 0).ToArray(); + var instanceCtorsNotChainingWithThis = instanceCtors.Where(ctor => !thisCallPattern.IsMatch(ctor.Body.Statements.FirstOrDefault())).ToArray(); + if (instanceCtorsNotChainingWithThis.Length > 0) { + MethodDefinition ctorMethodDef = instanceCtorsNotChainingWithThis[0].Annotation<MethodDefinition>(); + if (ctorMethodDef != null && ctorMethodDef.DeclaringType.IsValueType) + return; + + // Recognize field initializers: + // Convert first statement in all ctors (if all ctors have the same statement) into a field initializer. + bool allSame; + do { + Match m = fieldInitializerPattern.Match(instanceCtorsNotChainingWithThis[0].Body.FirstOrDefault()); + if (!m.Success) + break; + + FieldDefinition fieldDef = m.Get<AstNode>("fieldAccess").Single().Annotation<FieldReference>().ResolveWithinSameModule(); + if (fieldDef == null) + break; + AstNode fieldOrEventDecl = members.FirstOrDefault(f => f.Annotation<FieldDefinition>() == fieldDef); + if (fieldOrEventDecl == null) + break; + Expression initializer = m.Get<Expression>("initializer").Single(); + // 'this'/'base' cannot be used in field initializers + if (initializer.DescendantsAndSelf.Any(n => n is ThisReferenceExpression || n is BaseReferenceExpression)) + break; + + allSame = true; + for (int i = 1; i < instanceCtorsNotChainingWithThis.Length; i++) { + if (!instanceCtors[0].Body.First().IsMatch(instanceCtorsNotChainingWithThis[i].Body.FirstOrDefault())) + allSame = false; + } + if (allSame) { + foreach (var ctor in instanceCtorsNotChainingWithThis) + ctor.Body.First().Remove(); + fieldOrEventDecl.GetChildrenByRole(Roles.Variable).Single().Initializer = initializer.Detach(); + } + } while (allSame); + } + } + + void RemoveSingleEmptyConstructor(TypeDeclaration typeDeclaration) + { + var instanceCtors = typeDeclaration.Members.OfType<ConstructorDeclaration>().Where(c => (c.Modifiers & Modifiers.Static) == 0).ToArray(); + if (instanceCtors.Length == 1) { + ConstructorDeclaration emptyCtor = new ConstructorDeclaration(); + emptyCtor.Modifiers = ((typeDeclaration.Modifiers & Modifiers.Abstract) == Modifiers.Abstract ? Modifiers.Protected : Modifiers.Public); + emptyCtor.Body = new BlockStatement(); + if (emptyCtor.IsMatch(instanceCtors[0])) + instanceCtors[0].Remove(); + } + } + + void HandleStaticFieldInitializers(IEnumerable<AstNode> members) + { + // Convert static constructor into field initializers if the class is BeforeFieldInit + var staticCtor = members.OfType<ConstructorDeclaration>().FirstOrDefault(c => (c.Modifiers & Modifiers.Static) == Modifiers.Static); + if (staticCtor != null) { + MethodDefinition ctorMethodDef = staticCtor.Annotation<MethodDefinition>(); + if (ctorMethodDef != null && ctorMethodDef.DeclaringType.IsBeforeFieldInit) { + while (true) { + ExpressionStatement es = staticCtor.Body.Statements.FirstOrDefault() as ExpressionStatement; + if (es == null) + break; + AssignmentExpression assignment = es.Expression as AssignmentExpression; + if (assignment == null || assignment.Operator != AssignmentOperatorType.Assign) + break; + FieldDefinition fieldDef = assignment.Left.Annotation<FieldReference>().ResolveWithinSameModule(); + if (fieldDef == null || !fieldDef.IsStatic) + break; + FieldDeclaration fieldDecl = members.OfType<FieldDeclaration>().FirstOrDefault(f => f.Annotation<FieldDefinition>() == fieldDef); + if (fieldDecl == null) + break; + fieldDecl.Variables.Single().Initializer = assignment.Right.Detach(); + es.Remove(); + } + if (staticCtor.Body.Statements.Count == 0) + staticCtor.Remove(); + } + } + } + + void IAstTransform.Run(AstNode node) + { + // If we're viewing some set of members (fields are direct children of CompilationUnit), + // we also need to handle those: + HandleInstanceFieldInitializers(node.Children); + HandleStaticFieldInitializers(node.Children); + + node.AcceptVisitor(this, null); + } + } +} diff --git a/ICSharpCode.Decompiler/Ast/Transforms/CustomPatterns.cs b/ICSharpCode.Decompiler/Ast/Transforms/CustomPatterns.cs new file mode 100644 index 00000000..b80c56af --- /dev/null +++ b/ICSharpCode.Decompiler/Ast/Transforms/CustomPatterns.cs @@ -0,0 +1,109 @@ +// Copyright (c) 2011 AlphaSierraPapa for the SharpDevelop Team +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this +// software and associated documentation files (the "Software"), to deal in the Software +// without restriction, including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons +// to whom the Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR +// PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE +// FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +// OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +using System; +using System.Linq; +using System.Reflection; +using ICSharpCode.NRefactory.CSharp; +using ICSharpCode.NRefactory.PatternMatching; +using Mono.Cecil; + +namespace ICSharpCode.Decompiler.Ast.Transforms +{ + internal sealed class TypePattern : Pattern + { + readonly string ns; + readonly string name; + + public TypePattern(Type type) + { + ns = type.Namespace; + name = type.Name; + } + + public override bool DoMatch(INode other, Match match) + { + ComposedType ct = other as ComposedType; + AstType o; + if (ct != null && !ct.HasNullableSpecifier && ct.PointerRank == 0 && !ct.ArraySpecifiers.Any()) { + // Special case: ILSpy sometimes produces a ComposedType but then removed all array specifiers + // from it. In that case, we need to look at the base type for the annotations. + o = ct.BaseType; + } else { + o = other as AstType; + if (o == null) + return false; + } + TypeReference tr = o.Annotation<TypeReference>(); + return tr != null && tr.Namespace == ns && tr.Name == name; + } + + public override string ToString() + { + return name; + } + } + + internal sealed class LdTokenPattern : Pattern + { + AnyNode childNode; + + public LdTokenPattern(string groupName) + { + childNode = new AnyNode(groupName); + } + + public override bool DoMatch(INode other, Match match) + { + InvocationExpression ie = other as InvocationExpression; + if (ie != null && ie.Annotation<LdTokenAnnotation>() != null && ie.Arguments.Count == 1) { + return childNode.DoMatch(ie.Arguments.Single(), match); + } + return false; + } + + public override string ToString() + { + return "ldtoken(...)"; + } + } + + /// <summary> + /// typeof-Pattern that applies on the expanded form of typeof (prior to ReplaceMethodCallsWithOperators) + /// </summary> + internal sealed class TypeOfPattern : Pattern + { + INode childNode; + + public TypeOfPattern(string groupName) + { + childNode = new TypePattern(typeof(Type)).ToType().Invoke( + "GetTypeFromHandle", new TypeOfExpression(new AnyNode(groupName)).Member("TypeHandle")); + } + + public override bool DoMatch(INode other, Match match) + { + return childNode.DoMatch(other, match); + } + + public override string ToString() + { + return "typeof(...)"; + } + } +} diff --git a/ICSharpCode.Decompiler/Ast/Transforms/DecimalConstantTransform.cs b/ICSharpCode.Decompiler/Ast/Transforms/DecimalConstantTransform.cs new file mode 100644 index 00000000..298682af --- /dev/null +++ b/ICSharpCode.Decompiler/Ast/Transforms/DecimalConstantTransform.cs @@ -0,0 +1,58 @@ +// Copyright (c) 2011 AlphaSierraPapa for the SharpDevelop Team +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this +// software and associated documentation files (the "Software"), to deal in the Software +// without restriction, including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons +// to whom the Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR +// PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE +// FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +// OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +using System; +using ICSharpCode.NRefactory.CSharp; +using ICSharpCode.NRefactory.PatternMatching; +using Mono.Cecil; + +namespace ICSharpCode.Decompiler.Ast.Transforms +{ + /// <summary> + /// Transforms decimal constant fields. + /// </summary> + public class DecimalConstantTransform : DepthFirstAstVisitor<object, object>, IAstTransform + { + static readonly PrimitiveType decimalType = new PrimitiveType("decimal"); + + public override object VisitFieldDeclaration(FieldDeclaration fieldDeclaration, object data) + { + const Modifiers staticReadOnly = Modifiers.Static | Modifiers.Readonly; + if ((fieldDeclaration.Modifiers & staticReadOnly) == staticReadOnly && decimalType.IsMatch(fieldDeclaration.ReturnType)) { + foreach (var attributeSection in fieldDeclaration.Attributes) { + foreach (var attribute in attributeSection.Attributes) { + TypeReference tr = attribute.Type.Annotation<TypeReference>(); + if (tr != null && tr.Name == "DecimalConstantAttribute" && tr.Namespace == "System.Runtime.CompilerServices") { + attribute.Remove(); + if (attributeSection.Attributes.Count == 0) + attributeSection.Remove(); + fieldDeclaration.Modifiers = (fieldDeclaration.Modifiers & ~staticReadOnly) | Modifiers.Const; + return null; + } + } + } + } + return null; + } + + public void Run(AstNode compilationUnit) + { + compilationUnit.AcceptVisitor(this, null); + } + } +} diff --git a/ICSharpCode.Decompiler/Ast/Transforms/DeclareVariables.cs b/ICSharpCode.Decompiler/Ast/Transforms/DeclareVariables.cs new file mode 100644 index 00000000..a27f30b4 --- /dev/null +++ b/ICSharpCode.Decompiler/Ast/Transforms/DeclareVariables.cs @@ -0,0 +1,368 @@ +// Copyright (c) 2011 AlphaSierraPapa for the SharpDevelop Team +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this +// software and associated documentation files (the "Software"), to deal in the Software +// without restriction, including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons +// to whom the Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR +// PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE +// FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +// OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Threading; +using ICSharpCode.Decompiler.ILAst; +using ICSharpCode.NRefactory.CSharp; +using ICSharpCode.NRefactory.CSharp.Analysis; + +namespace ICSharpCode.Decompiler.Ast.Transforms +{ + /// <summary> + /// Moves variable declarations to improved positions. + /// </summary> + public class DeclareVariables : IAstTransform + { + sealed class VariableToDeclare + { + public AstType Type; + public string Name; + public ILVariable ILVariable; + + public AssignmentExpression ReplacedAssignment; + public Statement InsertionPoint; + } + + readonly CancellationToken cancellationToken; + List<VariableToDeclare> variablesToDeclare = new List<VariableToDeclare>(); + + public DeclareVariables(DecompilerContext context) + { + cancellationToken = context.CancellationToken; + } + + public void Run(AstNode node) + { + Run(node, null); + // Declare all the variables at the end, after all the logic has run. + // This is done so that definite assignment analysis can work on a single representation and doesn't have to be updated + // when we change the AST. + foreach (var v in variablesToDeclare) { + if (v.ReplacedAssignment == null) { + BlockStatement block = (BlockStatement)v.InsertionPoint.Parent; + var decl = new VariableDeclarationStatement((AstType)v.Type.Clone(), v.Name); + if (v.ILVariable != null) + decl.Variables.Single().AddAnnotation(v.ILVariable); + block.Statements.InsertBefore( + v.InsertionPoint, + decl); + } + } + // First do all the insertions, then do all the replacements. This is necessary because a replacement might remove our reference point from the AST. + foreach (var v in variablesToDeclare) { + if (v.ReplacedAssignment != null) { + // We clone the right expression so that it doesn't get removed from the old ExpressionStatement, + // which might be still in use by the definite assignment graph. + VariableInitializer initializer = new VariableInitializer(v.Name, v.ReplacedAssignment.Right.Detach()).CopyAnnotationsFrom(v.ReplacedAssignment).WithAnnotation(v.ILVariable); + VariableDeclarationStatement varDecl = new VariableDeclarationStatement { + Type = (AstType)v.Type.Clone(), + Variables = { initializer } + }; + ExpressionStatement es = v.ReplacedAssignment.Parent as ExpressionStatement; + if (es != null) { + // Note: if this crashes with 'Cannot replace the root node', check whether two variables were assigned the same name + es.ReplaceWith(varDecl.CopyAnnotationsFrom(es)); + } else { + v.ReplacedAssignment.ReplaceWith(varDecl); + } + } + } + variablesToDeclare = null; + } + + void Run(AstNode node, DefiniteAssignmentAnalysis daa) + { + BlockStatement block = node as BlockStatement; + if (block != null) { + var variables = block.Statements.TakeWhile(stmt => stmt is VariableDeclarationStatement) + .Cast<VariableDeclarationStatement>().ToList(); + if (variables.Count > 0) { + // remove old variable declarations: + foreach (VariableDeclarationStatement varDecl in variables) { + Debug.Assert(varDecl.Variables.Single().Initializer.IsNull); + varDecl.Remove(); + } + if (daa == null) { + // If possible, reuse the DefiniteAssignmentAnalysis that was created for the parent block + daa = new DefiniteAssignmentAnalysis(block, cancellationToken); + } + foreach (VariableDeclarationStatement varDecl in variables) { + VariableInitializer initializer = varDecl.Variables.Single(); + string variableName = initializer.Name; + ILVariable v = initializer.Annotation<ILVariable>(); + bool allowPassIntoLoops = initializer.Annotation<DelegateConstruction.CapturedVariableAnnotation>() == null; + DeclareVariableInBlock(daa, block, varDecl.Type, variableName, v, allowPassIntoLoops); + } + } + } + for (AstNode child = node.FirstChild; child != null; child = child.NextSibling) { + Run(child, daa); + } + } + + void DeclareVariableInBlock(DefiniteAssignmentAnalysis daa, BlockStatement block, AstType type, string variableName, ILVariable v, bool allowPassIntoLoops) + { + // declarationPoint: The point where the variable would be declared, if we decide to declare it in this block + Statement declarationPoint = null; + // Check whether we can move down the variable into the sub-blocks + bool canMoveVariableIntoSubBlocks = FindDeclarationPoint(daa, variableName, allowPassIntoLoops, block, out declarationPoint); + if (declarationPoint == null) { + // The variable isn't used at all + return; + } + if (canMoveVariableIntoSubBlocks) { + // Declare the variable within the sub-blocks + foreach (Statement stmt in block.Statements) { + ForStatement forStmt = stmt as ForStatement; + if (forStmt != null && forStmt.Initializers.Count == 1) { + // handle the special case of moving a variable into the for initializer + if (TryConvertAssignmentExpressionIntoVariableDeclaration(forStmt.Initializers.Single(), type, variableName)) + continue; + } + UsingStatement usingStmt = stmt as UsingStatement; + if (usingStmt != null && usingStmt.ResourceAcquisition is AssignmentExpression) { + // handle the special case of moving a variable into a using statement + if (TryConvertAssignmentExpressionIntoVariableDeclaration((Expression)usingStmt.ResourceAcquisition, type, variableName)) + continue; + } + IfElseStatement ies = stmt as IfElseStatement; + if (ies != null) { + foreach (var child in IfElseChainChildren(ies)) { + BlockStatement subBlock = child as BlockStatement; + if (subBlock != null) + DeclareVariableInBlock(daa, subBlock, type, variableName, v, allowPassIntoLoops); + } + continue; + } + foreach (AstNode child in stmt.Children) { + BlockStatement subBlock = child as BlockStatement; + if (subBlock != null) { + DeclareVariableInBlock(daa, subBlock, type, variableName, v, allowPassIntoLoops); + } else if (HasNestedBlocks(child)) { + foreach (BlockStatement nestedSubBlock in child.Children.OfType<BlockStatement>()) { + DeclareVariableInBlock(daa, nestedSubBlock, type, variableName, v, allowPassIntoLoops); + } + } + } + } + } else { + // Try converting an assignment expression into a VariableDeclarationStatement + if (!TryConvertAssignmentExpressionIntoVariableDeclaration(declarationPoint, type, variableName)) { + // Declare the variable in front of declarationPoint + variablesToDeclare.Add(new VariableToDeclare { Type = type, Name = variableName, ILVariable = v, InsertionPoint = declarationPoint }); + } + } + } + + bool TryConvertAssignmentExpressionIntoVariableDeclaration(Statement declarationPoint, AstType type, string variableName) + { + // convert the declarationPoint into a VariableDeclarationStatement + ExpressionStatement es = declarationPoint as ExpressionStatement; + if (es != null) { + return TryConvertAssignmentExpressionIntoVariableDeclaration(es.Expression, type, variableName); + } + return false; + } + + bool TryConvertAssignmentExpressionIntoVariableDeclaration(Expression expression, AstType type, string variableName) + { + AssignmentExpression ae = expression as AssignmentExpression; + if (ae != null && ae.Operator == AssignmentOperatorType.Assign) { + IdentifierExpression ident = ae.Left as IdentifierExpression; + if (ident != null && ident.Identifier == variableName) { + variablesToDeclare.Add(new VariableToDeclare { Type = type, Name = variableName, ILVariable = ident.Annotation<ILVariable>(), ReplacedAssignment = ae }); + return true; + } + } + return false; + } + + /// <summary> + /// Finds the declaration point for the variable within the specified block. + /// </summary> + /// <param name="daa"> + /// Definite assignment analysis, must be prepared for 'block' or one of its parents. + /// </param> + /// <param name="varDecl">The variable to declare</param> + /// <param name="block">The block in which the variable should be declared</param> + /// <param name="declarationPoint"> + /// Output parameter: the first statement within 'block' where the variable needs to be declared. + /// </param> + /// <returns> + /// Returns whether it is possible to move the variable declaration into sub-blocks. + /// </returns> + public static bool FindDeclarationPoint(DefiniteAssignmentAnalysis daa, VariableDeclarationStatement varDecl, BlockStatement block, out Statement declarationPoint) + { + string variableName = varDecl.Variables.Single().Name; + bool allowPassIntoLoops = varDecl.Variables.Single().Annotation<DelegateConstruction.CapturedVariableAnnotation>() == null; + return FindDeclarationPoint(daa, variableName, allowPassIntoLoops, block, out declarationPoint); + } + + static bool FindDeclarationPoint(DefiniteAssignmentAnalysis daa, string variableName, bool allowPassIntoLoops, BlockStatement block, out Statement declarationPoint) + { + // declarationPoint: The point where the variable would be declared, if we decide to declare it in this block + declarationPoint = null; + foreach (Statement stmt in block.Statements) { + if (UsesVariable(stmt, variableName)) { + if (declarationPoint == null) + declarationPoint = stmt; + if (!CanMoveVariableUseIntoSubBlock(stmt, variableName, allowPassIntoLoops)) { + // If it's not possible to move the variable use into a nested block, + // we need to declare the variable in this block + return false; + } + // If we can move the variable into the sub-block, we need to ensure that the remaining code + // does not use the value that was assigned by the first sub-block + Statement nextStatement = stmt.GetNextStatement(); + if (nextStatement != null) { + // Analyze the range from the next statement to the end of the block + daa.SetAnalyzedRange(nextStatement, block); + daa.Analyze(variableName); + if (daa.UnassignedVariableUses.Count > 0) { + return false; + } + } + } + } + return true; + } + + static bool CanMoveVariableUseIntoSubBlock(Statement stmt, string variableName, bool allowPassIntoLoops) + { + if (!allowPassIntoLoops && (stmt is ForStatement || stmt is ForeachStatement || stmt is DoWhileStatement || stmt is WhileStatement)) + return false; + + ForStatement forStatement = stmt as ForStatement; + if (forStatement != null && forStatement.Initializers.Count == 1) { + // for-statement is special case: we can move variable declarations into the initializer + ExpressionStatement es = forStatement.Initializers.Single() as ExpressionStatement; + if (es != null) { + AssignmentExpression ae = es.Expression as AssignmentExpression; + if (ae != null && ae.Operator == AssignmentOperatorType.Assign) { + IdentifierExpression ident = ae.Left as IdentifierExpression; + if (ident != null && ident.Identifier == variableName) { + return !UsesVariable(ae.Right, variableName); + } + } + } + } + + UsingStatement usingStatement = stmt as UsingStatement; + if (usingStatement != null) { + // using-statement is special case: we can move variable declarations into the initializer + AssignmentExpression ae = usingStatement.ResourceAcquisition as AssignmentExpression; + if (ae != null && ae.Operator == AssignmentOperatorType.Assign) { + IdentifierExpression ident = ae.Left as IdentifierExpression; + if (ident != null && ident.Identifier == variableName) { + return !UsesVariable(ae.Right, variableName); + } + } + } + + IfElseStatement ies = stmt as IfElseStatement; + if (ies != null) { + foreach (var child in IfElseChainChildren(ies)) { + if (!(child is BlockStatement) && UsesVariable(child, variableName)) + return false; + } + return true; + } + + // We can move the variable into a sub-block only if the variable is used in only that sub-block (and not in expressions such as the loop condition) + for (AstNode child = stmt.FirstChild; child != null; child = child.NextSibling) { + if (!(child is BlockStatement) && UsesVariable(child, variableName)) { + if (HasNestedBlocks(child)) { + // catch clauses/switch sections can contain nested blocks + for (AstNode grandchild = child.FirstChild; grandchild != null; grandchild = grandchild.NextSibling) { + if (!(grandchild is BlockStatement) && UsesVariable(grandchild, variableName)) + return false; + } + } else { + return false; + } + } + } + return true; + } + + static IEnumerable<AstNode> IfElseChainChildren(IfElseStatement ies) + { + IfElseStatement prev; + do { + yield return ies.Condition; + yield return ies.TrueStatement; + prev = ies; + ies = ies.FalseStatement as IfElseStatement; + } while (ies != null); + if (!prev.FalseStatement.IsNull) + yield return prev.FalseStatement; + } + + static bool HasNestedBlocks(AstNode node) + { + return node is CatchClause || node is SwitchSection; + } + + static bool UsesVariable(AstNode node, string variableName) + { + IdentifierExpression ie = node as IdentifierExpression; + if (ie != null && ie.Identifier == variableName) + return true; + + FixedStatement fixedStatement = node as FixedStatement; + if (fixedStatement != null) { + foreach (VariableInitializer v in fixedStatement.Variables) { + if (v.Name == variableName) + return false; // no need to introduce the variable here + } + } + + ForeachStatement foreachStatement = node as ForeachStatement; + if (foreachStatement != null) { + if (foreachStatement.VariableName == variableName) + return false; // no need to introduce the variable here + } + + UsingStatement usingStatement = node as UsingStatement; + if (usingStatement != null) { + VariableDeclarationStatement varDecl = usingStatement.ResourceAcquisition as VariableDeclarationStatement; + if (varDecl != null) { + foreach (VariableInitializer v in varDecl.Variables) { + if (v.Name == variableName) + return false; // no need to introduce the variable here + } + } + } + + CatchClause catchClause = node as CatchClause; + if (catchClause != null && catchClause.VariableName == variableName) { + return false; // no need to introduce the variable here + } + + for (AstNode child = node.FirstChild; child != null; child = child.NextSibling) { + if (UsesVariable(child, variableName)) + return true; + } + return false; + } + } +} diff --git a/ICSharpCode.Decompiler/Ast/Transforms/DelegateConstruction.cs b/ICSharpCode.Decompiler/Ast/Transforms/DelegateConstruction.cs new file mode 100644 index 00000000..04b2293d --- /dev/null +++ b/ICSharpCode.Decompiler/Ast/Transforms/DelegateConstruction.cs @@ -0,0 +1,502 @@ +// Copyright (c) 2011 AlphaSierraPapa for the SharpDevelop Team +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this +// software and associated documentation files (the "Software"), to deal in the Software +// without restriction, including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons +// to whom the Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR +// PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE +// FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +// OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Threading; +using ICSharpCode.Decompiler; +using ICSharpCode.Decompiler.ILAst; +using ICSharpCode.NRefactory.CSharp; +using ICSharpCode.NRefactory.PatternMatching; +using Mono.Cecil; + +namespace ICSharpCode.Decompiler.Ast.Transforms +{ + /// <summary> + /// Converts "new Action(obj, ldftn(func))" into "new Action(obj.func)". + /// For anonymous methods, creates an AnonymousMethodExpression. + /// Also gets rid of any "Display Classes" left over after inlining an anonymous method. + /// </summary> + public class DelegateConstruction : ContextTrackingVisitor<object> + { + internal sealed class Annotation + { + /// <summary> + /// ldftn or ldvirtftn? + /// </summary> + public readonly bool IsVirtual; + + public Annotation(bool isVirtual) + { + IsVirtual = isVirtual; + } + } + + internal sealed class CapturedVariableAnnotation + { + } + + List<string> currentlyUsedVariableNames = new List<string>(); + + public DelegateConstruction(DecompilerContext context) : base(context) + { + } + + public override object VisitObjectCreateExpression(ObjectCreateExpression objectCreateExpression, object data) + { + if (objectCreateExpression.Arguments.Count == 2) { + Expression obj = objectCreateExpression.Arguments.First(); + Expression func = objectCreateExpression.Arguments.Last(); + Annotation annotation = func.Annotation<Annotation>(); + if (annotation != null) { + IdentifierExpression methodIdent = (IdentifierExpression)((InvocationExpression)func).Arguments.Single(); + MethodReference method = methodIdent.Annotation<MethodReference>(); + if (method != null) { + if (HandleAnonymousMethod(objectCreateExpression, obj, method)) + return null; + // Perform the transformation to "new Action(obj.func)". + obj.Remove(); + methodIdent.Remove(); + if (!annotation.IsVirtual && obj is ThisReferenceExpression) { + // maybe it's getting the pointer of a base method? + if (method.DeclaringType.GetElementType() != context.CurrentType) { + obj = new BaseReferenceExpression(); + } + } + if (!annotation.IsVirtual && obj is NullReferenceExpression && !method.HasThis) { + // We're loading a static method. + // However it is possible to load extension methods with an instance, so we compare the number of arguments: + bool isExtensionMethod = false; + TypeReference delegateType = objectCreateExpression.Type.Annotation<TypeReference>(); + if (delegateType != null) { + TypeDefinition delegateTypeDef = delegateType.Resolve(); + if (delegateTypeDef != null) { + MethodDefinition invokeMethod = delegateTypeDef.Methods.FirstOrDefault(m => m.Name == "Invoke"); + if (invokeMethod != null) { + isExtensionMethod = (invokeMethod.Parameters.Count + 1 == method.Parameters.Count); + } + } + } + if (!isExtensionMethod) { + obj = new TypeReferenceExpression { Type = AstBuilder.ConvertType(method.DeclaringType) }; + } + } + // now transform the identifier into a member reference + MemberReferenceExpression mre = new MemberReferenceExpression(); + mre.Target = obj; + mre.MemberName = methodIdent.Identifier; + methodIdent.TypeArguments.MoveTo(mre.TypeArguments); + mre.AddAnnotation(method); + objectCreateExpression.Arguments.Clear(); + objectCreateExpression.Arguments.Add(mre); + return null; + } + } + } + return base.VisitObjectCreateExpression(objectCreateExpression, data); + } + + internal static bool IsAnonymousMethod(DecompilerContext context, MethodDefinition method) + { + if (method == null || !(method.HasGeneratedName() || method.Name.Contains("$"))) + return false; + if (!(method.IsCompilerGenerated() || IsPotentialClosure(context, method.DeclaringType))) + return false; + return true; + } + + bool HandleAnonymousMethod(ObjectCreateExpression objectCreateExpression, Expression target, MethodReference methodRef) + { + if (!context.Settings.AnonymousMethods) + return false; // anonymous method decompilation is disabled + if (target != null && !(target is IdentifierExpression || target is ThisReferenceExpression || target is NullReferenceExpression)) + return false; // don't copy arbitrary expressions, deal with identifiers only + + // Anonymous methods are defined in the same assembly + MethodDefinition method = methodRef.ResolveWithinSameModule(); + if (!IsAnonymousMethod(context, method)) + return false; + + // Create AnonymousMethodExpression and prepare parameters + AnonymousMethodExpression ame = new AnonymousMethodExpression(); + ame.CopyAnnotationsFrom(objectCreateExpression); // copy ILRanges etc. + ame.RemoveAnnotations<MethodReference>(); // remove reference to delegate ctor + ame.AddAnnotation(method); // add reference to anonymous method + ame.Parameters.AddRange(AstBuilder.MakeParameters(method, isLambda: true)); + ame.HasParameterList = true; + + // rename variables so that they don't conflict with the parameters: + foreach (ParameterDeclaration pd in ame.Parameters) { + EnsureVariableNameIsAvailable(objectCreateExpression, pd.Name); + } + + // Decompile the anonymous method: + + DecompilerContext subContext = context.Clone(); + subContext.CurrentMethod = method; + subContext.CurrentMethodIsAsync = false; + subContext.ReservedVariableNames.AddRange(currentlyUsedVariableNames); + BlockStatement body = AstMethodBodyBuilder.CreateMethodBody(method, subContext, ame.Parameters); + TransformationPipeline.RunTransformationsUntil(body, v => v is DelegateConstruction, subContext); + body.AcceptVisitor(this, null); + + + bool isLambda = false; + if (ame.Parameters.All(p => p.ParameterModifier == ParameterModifier.None)) { + isLambda = (body.Statements.Count == 1 && body.Statements.Single() is ReturnStatement); + } + // Remove the parameter list from an AnonymousMethodExpression if the original method had no names, + // and the parameters are not used in the method body + if (!isLambda && method.Parameters.All(p => string.IsNullOrEmpty(p.Name))) { + var parameterReferencingIdentifiers = + from ident in body.Descendants.OfType<IdentifierExpression>() + let v = ident.Annotation<ILVariable>() + where v != null && v.IsParameter && method.Parameters.Contains(v.OriginalParameter) + select ident; + if (!parameterReferencingIdentifiers.Any()) { + ame.Parameters.Clear(); + ame.HasParameterList = false; + } + } + + // Replace all occurrences of 'this' in the method body with the delegate's target: + foreach (AstNode node in body.Descendants) { + if (node is ThisReferenceExpression) + node.ReplaceWith(target.Clone()); + } + Expression replacement; + if (isLambda) { + LambdaExpression lambda = new LambdaExpression(); + lambda.CopyAnnotationsFrom(ame); + ame.Parameters.MoveTo(lambda.Parameters); + Expression returnExpr = ((ReturnStatement)body.Statements.Single()).Expression; + returnExpr.Remove(); + lambda.Body = returnExpr; + replacement = lambda; + } else { + ame.Body = body; + replacement = ame; + } + var expectedType = objectCreateExpression.Annotation<TypeInformation>().ExpectedType.Resolve(); + if (expectedType != null && !expectedType.IsDelegate()) { + var simplifiedDelegateCreation = (ObjectCreateExpression)objectCreateExpression.Clone(); + simplifiedDelegateCreation.Arguments.Clear(); + simplifiedDelegateCreation.Arguments.Add(replacement); + replacement = simplifiedDelegateCreation; + } + objectCreateExpression.ReplaceWith(replacement); + return true; + } + + internal static bool IsPotentialClosure(DecompilerContext context, TypeDefinition potentialDisplayClass) + { + if (potentialDisplayClass == null || !potentialDisplayClass.IsCompilerGeneratedOrIsInCompilerGeneratedClass()) + return false; + // check that methodContainingType is within containingType + while (potentialDisplayClass != context.CurrentType) { + potentialDisplayClass = potentialDisplayClass.DeclaringType; + if (potentialDisplayClass == null) + return false; + } + return true; + } + + public override object VisitInvocationExpression(InvocationExpression invocationExpression, object data) + { + if (context.Settings.ExpressionTrees && ExpressionTreeConverter.CouldBeExpressionTree(invocationExpression)) { + Expression converted = ExpressionTreeConverter.TryConvert(context, invocationExpression); + if (converted != null) { + invocationExpression.ReplaceWith(converted); + return converted.AcceptVisitor(this, data); + } + } + return base.VisitInvocationExpression(invocationExpression, data); + } + + #region Track current variables + public override object VisitMethodDeclaration(MethodDeclaration methodDeclaration, object data) + { + Debug.Assert(currentlyUsedVariableNames.Count == 0); + try { + currentlyUsedVariableNames.AddRange(methodDeclaration.Parameters.Select(p => p.Name)); + return base.VisitMethodDeclaration(methodDeclaration, data); + } finally { + currentlyUsedVariableNames.Clear(); + } + } + + public override object VisitOperatorDeclaration(OperatorDeclaration operatorDeclaration, object data) + { + Debug.Assert(currentlyUsedVariableNames.Count == 0); + try { + currentlyUsedVariableNames.AddRange(operatorDeclaration.Parameters.Select(p => p.Name)); + return base.VisitOperatorDeclaration(operatorDeclaration, data); + } finally { + currentlyUsedVariableNames.Clear(); + } + } + + public override object VisitConstructorDeclaration(ConstructorDeclaration constructorDeclaration, object data) + { + Debug.Assert(currentlyUsedVariableNames.Count == 0); + try { + currentlyUsedVariableNames.AddRange(constructorDeclaration.Parameters.Select(p => p.Name)); + return base.VisitConstructorDeclaration(constructorDeclaration, data); + } finally { + currentlyUsedVariableNames.Clear(); + } + } + + public override object VisitIndexerDeclaration(IndexerDeclaration indexerDeclaration, object data) + { + Debug.Assert(currentlyUsedVariableNames.Count == 0); + try { + currentlyUsedVariableNames.AddRange(indexerDeclaration.Parameters.Select(p => p.Name)); + return base.VisitIndexerDeclaration(indexerDeclaration, data); + } finally { + currentlyUsedVariableNames.Clear(); + } + } + + public override object VisitAccessor(Accessor accessor, object data) + { + try { + currentlyUsedVariableNames.Add("value"); + return base.VisitAccessor(accessor, data); + } finally { + currentlyUsedVariableNames.RemoveAt(currentlyUsedVariableNames.Count - 1); + } + } + + public override object VisitVariableDeclarationStatement(VariableDeclarationStatement variableDeclarationStatement, object data) + { + foreach (VariableInitializer v in variableDeclarationStatement.Variables) + currentlyUsedVariableNames.Add(v.Name); + return base.VisitVariableDeclarationStatement(variableDeclarationStatement, data); + } + + public override object VisitFixedStatement(FixedStatement fixedStatement, object data) + { + foreach (VariableInitializer v in fixedStatement.Variables) + currentlyUsedVariableNames.Add(v.Name); + return base.VisitFixedStatement(fixedStatement, data); + } + #endregion + + static readonly ExpressionStatement displayClassAssignmentPattern = + new ExpressionStatement(new AssignmentExpression( + new NamedNode("variable", new IdentifierExpression(Pattern.AnyString)), + new ObjectCreateExpression { Type = new AnyNode("type") } + )); + + public override object VisitBlockStatement(BlockStatement blockStatement, object data) + { + int numberOfVariablesOutsideBlock = currentlyUsedVariableNames.Count; + base.VisitBlockStatement(blockStatement, data); + foreach (ExpressionStatement stmt in blockStatement.Statements.OfType<ExpressionStatement>().ToArray()) { + Match displayClassAssignmentMatch = displayClassAssignmentPattern.Match(stmt); + if (!displayClassAssignmentMatch.Success) + continue; + + ILVariable variable = displayClassAssignmentMatch.Get<AstNode>("variable").Single().Annotation<ILVariable>(); + if (variable == null) + continue; + TypeDefinition type = variable.Type.ResolveWithinSameModule(); + if (!IsPotentialClosure(context, type)) + continue; + if (displayClassAssignmentMatch.Get<AstType>("type").Single().Annotation<TypeReference>().ResolveWithinSameModule() != type) + continue; + + // Looks like we found a display class creation. Now let's verify that the variable is used only for field accesses: + bool ok = true; + foreach (var identExpr in blockStatement.Descendants.OfType<IdentifierExpression>()) { + if (identExpr.Identifier == variable.Name && identExpr != displayClassAssignmentMatch.Get("variable").Single()) { + if (!(identExpr.Parent is MemberReferenceExpression && identExpr.Parent.Annotation<FieldReference>() != null)) + ok = false; + } + } + if (!ok) + continue; + Dictionary<FieldReference, AstNode> dict = new Dictionary<FieldReference, AstNode>(); + + // Delete the variable declaration statement: + VariableDeclarationStatement displayClassVarDecl = PatternStatementTransform.FindVariableDeclaration(stmt, variable.Name); + if (displayClassVarDecl != null) + displayClassVarDecl.Remove(); + + // Delete the assignment statement: + AstNode cur = stmt.NextSibling; + stmt.Remove(); + + // Delete any following statements as long as they assign parameters to the display class + BlockStatement rootBlock = blockStatement.Ancestors.OfType<BlockStatement>().LastOrDefault() ?? blockStatement; + List<ILVariable> parameterOccurrances = rootBlock.Descendants.OfType<IdentifierExpression>() + .Select(n => n.Annotation<ILVariable>()).Where(p => p != null && p.IsParameter).ToList(); + AstNode next; + for (; cur != null; cur = next) { + next = cur.NextSibling; + + // Test for the pattern: + // "variableName.MemberName = right;" + ExpressionStatement closureFieldAssignmentPattern = new ExpressionStatement( + new AssignmentExpression( + new NamedNode("left", new MemberReferenceExpression { + Target = new IdentifierExpression(variable.Name), + MemberName = Pattern.AnyString + }), + new AnyNode("right") + ) + ); + Match m = closureFieldAssignmentPattern.Match(cur); + if (m.Success) { + FieldDefinition fieldDef = m.Get<MemberReferenceExpression>("left").Single().Annotation<FieldReference>().ResolveWithinSameModule(); + AstNode right = m.Get<AstNode>("right").Single(); + bool isParameter = false; + bool isDisplayClassParentPointerAssignment = false; + if (right is ThisReferenceExpression) { + isParameter = true; + } else if (right is IdentifierExpression) { + // handle parameters only if the whole method contains no other occurrence except for 'right' + ILVariable v = right.Annotation<ILVariable>(); + isParameter = v.IsParameter && parameterOccurrances.Count(c => c == v) == 1; + if (!isParameter && IsPotentialClosure(context, v.Type.ResolveWithinSameModule())) { + // parent display class within the same method + // (closure2.localsX = closure1;) + isDisplayClassParentPointerAssignment = true; + } + } else if (right is MemberReferenceExpression) { + // copy of parent display class reference from an outer lambda + // closure2.localsX = this.localsY + MemberReferenceExpression mre = m.Get<MemberReferenceExpression>("right").Single(); + do { + // descend into the targets of the mre as long as the field types are closures + FieldDefinition fieldDef2 = mre.Annotation<FieldReference>().ResolveWithinSameModule(); + if (fieldDef2 == null || !IsPotentialClosure(context, fieldDef2.FieldType.ResolveWithinSameModule())) { + break; + } + // if we finally get to a this reference, it's copying a display class parent pointer + if (mre.Target is ThisReferenceExpression) { + isDisplayClassParentPointerAssignment = true; + } + mre = mre.Target as MemberReferenceExpression; + } while (mre != null); + } + if (isParameter || isDisplayClassParentPointerAssignment) { + dict[fieldDef] = right; + cur.Remove(); + } else { + break; + } + } else { + break; + } + } + + // Now create variables for all fields of the display class (except for those that we already handled as parameters) + List<Tuple<AstType, ILVariable>> variablesToDeclare = new List<Tuple<AstType, ILVariable>>(); + foreach (FieldDefinition field in type.Fields) { + if (field.IsStatic) + continue; // skip static fields + if (dict.ContainsKey(field)) // skip field if it already was handled as parameter + continue; + string capturedVariableName = field.Name; + if (capturedVariableName.StartsWith("$VB$Local_", StringComparison.Ordinal) && capturedVariableName.Length > 10) + capturedVariableName = capturedVariableName.Substring(10); + EnsureVariableNameIsAvailable(blockStatement, capturedVariableName); + currentlyUsedVariableNames.Add(capturedVariableName); + ILVariable ilVar = new ILVariable + { + IsGenerated = true, + Name = capturedVariableName, + Type = field.FieldType, + }; + variablesToDeclare.Add(Tuple.Create(AstBuilder.ConvertType(field.FieldType, field), ilVar)); + dict[field] = new IdentifierExpression(capturedVariableName).WithAnnotation(ilVar); + } + + // Now figure out where the closure was accessed and use the simpler replacement expression there: + foreach (var identExpr in blockStatement.Descendants.OfType<IdentifierExpression>()) { + if (identExpr.Identifier == variable.Name) { + MemberReferenceExpression mre = (MemberReferenceExpression)identExpr.Parent; + AstNode replacement; + if (dict.TryGetValue(mre.Annotation<FieldReference>().ResolveWithinSameModule(), out replacement)) { + mre.ReplaceWith(replacement.Clone()); + } + } + } + // Now insert the variable declarations (we can do this after the replacements only so that the scope detection works): + Statement insertionPoint = blockStatement.Statements.FirstOrDefault(); + foreach (var tuple in variablesToDeclare) { + var newVarDecl = new VariableDeclarationStatement(tuple.Item1, tuple.Item2.Name); + newVarDecl.Variables.Single().AddAnnotation(new CapturedVariableAnnotation()); + newVarDecl.Variables.Single().AddAnnotation(tuple.Item2); + blockStatement.Statements.InsertBefore(insertionPoint, newVarDecl); + } + } + currentlyUsedVariableNames.RemoveRange(numberOfVariablesOutsideBlock, currentlyUsedVariableNames.Count - numberOfVariablesOutsideBlock); + return null; + } + + void EnsureVariableNameIsAvailable(AstNode currentNode, string name) + { + int pos = currentlyUsedVariableNames.IndexOf(name); + if (pos < 0) { + // name is still available + return; + } + // Naming conflict. Let's rename the existing variable so that the field keeps the name from metadata. + NameVariables nv = new NameVariables(); + // Add currently used variable and parameter names + foreach (string nameInUse in currentlyUsedVariableNames) + nv.AddExistingName(nameInUse); + // variables declared in child nodes of this block + foreach (VariableInitializer vi in currentNode.Descendants.OfType<VariableInitializer>()) + nv.AddExistingName(vi.Name); + // parameters in child lambdas + foreach (ParameterDeclaration pd in currentNode.Descendants.OfType<ParameterDeclaration>()) + nv.AddExistingName(pd.Name); + + string newName = nv.GetAlternativeName(name); + currentlyUsedVariableNames[pos] = newName; + + // find top-most block + AstNode topMostBlock = currentNode.Ancestors.OfType<BlockStatement>().LastOrDefault() ?? currentNode; + + // rename identifiers + foreach (IdentifierExpression ident in topMostBlock.Descendants.OfType<IdentifierExpression>()) { + if (ident.Identifier == name) { + ident.Identifier = newName; + ILVariable v = ident.Annotation<ILVariable>(); + if (v != null) + v.Name = newName; + } + } + // rename variable declarations + foreach (VariableInitializer vi in topMostBlock.Descendants.OfType<VariableInitializer>()) { + if (vi.Name == name) { + vi.Name = newName; + ILVariable v = vi.Annotation<ILVariable>(); + if (v != null) + v.Name = newName; + } + } + } + } +} diff --git a/ICSharpCode.Decompiler/Ast/Transforms/ExpressionTreeConverter.cs b/ICSharpCode.Decompiler/Ast/Transforms/ExpressionTreeConverter.cs new file mode 100644 index 00000000..2327200d --- /dev/null +++ b/ICSharpCode.Decompiler/Ast/Transforms/ExpressionTreeConverter.cs @@ -0,0 +1,875 @@ +// Copyright (c) 2011 AlphaSierraPapa for the SharpDevelop Team +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this +// software and associated documentation files (the "Software"), to deal in the Software +// without restriction, including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons +// to whom the Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR +// PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE +// FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +// OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Reflection; +using ICSharpCode.Decompiler.ILAst; +using ICSharpCode.NRefactory.CSharp; +using ICSharpCode.NRefactory.PatternMatching; +using Mono.Cecil; + +namespace ICSharpCode.Decompiler.Ast.Transforms +{ + public class ExpressionTreeConverter + { + #region static TryConvert method + public static bool CouldBeExpressionTree(InvocationExpression expr) + { + if (expr != null && expr.Arguments.Count == 2) { + MethodReference mr = expr.Annotation<MethodReference>(); + return mr != null && mr.Name == "Lambda" && mr.DeclaringType.FullName == "System.Linq.Expressions.Expression"; + } + return false; + } + + public static Expression TryConvert(DecompilerContext context, Expression expr) + { + Expression converted = new ExpressionTreeConverter(context).Convert(expr); + if (converted != null) { + converted.AddAnnotation(new ExpressionTreeLambdaAnnotation()); + } + return converted; + } + #endregion + + readonly DecompilerContext context; + Stack<LambdaExpression> activeLambdas = new Stack<LambdaExpression>(); + + ExpressionTreeConverter(DecompilerContext context) + { + this.context = context; + } + + #region Main Convert method + Expression Convert(Expression expr) + { + InvocationExpression invocation = expr as InvocationExpression; + if (invocation != null) { + MethodReference mr = invocation.Annotation<MethodReference>(); + if (mr != null && mr.DeclaringType.FullName == "System.Linq.Expressions.Expression") { + switch (mr.Name) { + case "Add": + return ConvertBinaryOperator(invocation, BinaryOperatorType.Add, false); + case "AddChecked": + return ConvertBinaryOperator(invocation, BinaryOperatorType.Add, true); + case "AddAssign": + return ConvertAssignmentOperator(invocation, AssignmentOperatorType.Add, false); + case "AddAssignChecked": + return ConvertAssignmentOperator(invocation, AssignmentOperatorType.Add, true); + case "And": + return ConvertBinaryOperator(invocation, BinaryOperatorType.BitwiseAnd); + case "AndAlso": + return ConvertBinaryOperator(invocation, BinaryOperatorType.ConditionalAnd); + case "AndAssign": + return ConvertAssignmentOperator(invocation, AssignmentOperatorType.BitwiseAnd); + case "ArrayAccess": + case "ArrayIndex": + return ConvertArrayIndex(invocation); + case "ArrayLength": + return ConvertArrayLength(invocation); + case "Assign": + return ConvertAssignmentOperator(invocation, AssignmentOperatorType.Assign); + case "Call": + return ConvertCall(invocation); + case "Coalesce": + return ConvertBinaryOperator(invocation, BinaryOperatorType.NullCoalescing); + case "Condition": + return ConvertCondition(invocation); + case "Constant": + if (invocation.Arguments.Count >= 1) + return invocation.Arguments.First().Clone(); + else + return NotSupported(expr); + case "Convert": + return ConvertCast(invocation, false); + case "ConvertChecked": + return ConvertCast(invocation, true); + case "Divide": + return ConvertBinaryOperator(invocation, BinaryOperatorType.Divide); + case "DivideAssign": + return ConvertAssignmentOperator(invocation, AssignmentOperatorType.Divide); + case "Equal": + return ConvertBinaryOperator(invocation, BinaryOperatorType.Equality); + case "ExclusiveOr": + return ConvertBinaryOperator(invocation, BinaryOperatorType.ExclusiveOr); + case "ExclusiveOrAssign": + return ConvertAssignmentOperator(invocation, AssignmentOperatorType.ExclusiveOr); + case "Field": + return ConvertField(invocation); + case "GreaterThan": + return ConvertBinaryOperator(invocation, BinaryOperatorType.GreaterThan); + case "GreaterThanOrEqual": + return ConvertBinaryOperator(invocation, BinaryOperatorType.GreaterThanOrEqual); + case "Invoke": + return ConvertInvoke(invocation); + case "Lambda": + return ConvertLambda(invocation); + case "LeftShift": + return ConvertBinaryOperator(invocation, BinaryOperatorType.ShiftLeft); + case "LeftShiftAssign": + return ConvertAssignmentOperator(invocation, AssignmentOperatorType.ShiftLeft); + case "LessThan": + return ConvertBinaryOperator(invocation, BinaryOperatorType.LessThan); + case "LessThanOrEqual": + return ConvertBinaryOperator(invocation, BinaryOperatorType.LessThanOrEqual); + case "ListInit": + return ConvertListInit(invocation); + case "MemberInit": + return ConvertMemberInit(invocation); + case "Modulo": + return ConvertBinaryOperator(invocation, BinaryOperatorType.Modulus); + case "ModuloAssign": + return ConvertAssignmentOperator(invocation, AssignmentOperatorType.Modulus); + case "Multiply": + return ConvertBinaryOperator(invocation, BinaryOperatorType.Multiply, false); + case "MultiplyChecked": + return ConvertBinaryOperator(invocation, BinaryOperatorType.Multiply, true); + case "MultiplyAssign": + return ConvertAssignmentOperator(invocation, AssignmentOperatorType.Multiply, false); + case "MultiplyAssignChecked": + return ConvertAssignmentOperator(invocation, AssignmentOperatorType.Multiply, true); + case "Negate": + return ConvertUnaryOperator(invocation, UnaryOperatorType.Minus, false); + case "NegateChecked": + return ConvertUnaryOperator(invocation, UnaryOperatorType.Minus, true); + case "New": + return ConvertNewObject(invocation); + case "NewArrayBounds": + return ConvertNewArrayBounds(invocation); + case "NewArrayInit": + return ConvertNewArrayInit(invocation); + case "Not": + return ConvertUnaryOperator(invocation, UnaryOperatorType.Not); + case "NotEqual": + return ConvertBinaryOperator(invocation, BinaryOperatorType.InEquality); + case "OnesComplement": + return ConvertUnaryOperator(invocation, UnaryOperatorType.BitNot); + case "Or": + return ConvertBinaryOperator(invocation, BinaryOperatorType.BitwiseOr); + case "OrAssign": + return ConvertAssignmentOperator(invocation, AssignmentOperatorType.BitwiseOr); + case "OrElse": + return ConvertBinaryOperator(invocation, BinaryOperatorType.ConditionalOr); + case "Property": + return ConvertProperty(invocation); + case "Quote": + if (invocation.Arguments.Count == 1) + return Convert(invocation.Arguments.Single()); + else + return NotSupported(invocation); + case "RightShift": + return ConvertBinaryOperator(invocation, BinaryOperatorType.ShiftRight); + case "RightShiftAssign": + return ConvertAssignmentOperator(invocation, AssignmentOperatorType.ShiftRight); + case "Subtract": + return ConvertBinaryOperator(invocation, BinaryOperatorType.Subtract, false); + case "SubtractChecked": + return ConvertBinaryOperator(invocation, BinaryOperatorType.Subtract, true); + case "SubtractAssign": + return ConvertAssignmentOperator(invocation, AssignmentOperatorType.Subtract, false); + case "SubtractAssignChecked": + return ConvertAssignmentOperator(invocation, AssignmentOperatorType.Subtract, true); + case "TypeAs": + return ConvertTypeAs(invocation); + case "TypeIs": + return ConvertTypeIs(invocation); + } + } + } + IdentifierExpression ident = expr as IdentifierExpression; + if (ident != null) { + ILVariable v = ident.Annotation<ILVariable>(); + if (v != null) { + foreach (LambdaExpression lambda in activeLambdas) { + foreach (ParameterDeclaration p in lambda.Parameters) { + if (p.Annotation<ILVariable>() == v) + return new IdentifierExpression(p.Name).WithAnnotation(v); + } + } + } + } + return NotSupported(expr); + } + + Expression NotSupported(Expression expr) + { + Debug.WriteLine("Expression Tree Conversion Failed: '" + expr + "' is not supported"); + return null; + } + #endregion + + #region Convert Lambda + static readonly Expression emptyArrayPattern = new ArrayCreateExpression { + Type = new AnyNode(), + Arguments = { new PrimitiveExpression(0) } + }; + + Expression ConvertLambda(InvocationExpression invocation) + { + if (invocation.Arguments.Count != 2) + return NotSupported(invocation); + LambdaExpression lambda = new LambdaExpression(); + Expression body = invocation.Arguments.First(); + ArrayCreateExpression parameterArray = invocation.Arguments.Last() as ArrayCreateExpression; + if (parameterArray == null) + return NotSupported(invocation); + + var annotation = body.Annotation<ParameterDeclarationAnnotation>(); + if (annotation != null) { + lambda.Parameters.AddRange(annotation.Parameters); + } else { + // No parameter declaration annotation found. + if (!emptyArrayPattern.IsMatch(parameterArray)) + return null; + } + + activeLambdas.Push(lambda); + Expression convertedBody = Convert(body); + activeLambdas.Pop(); + if (convertedBody == null) + return null; + lambda.Body = convertedBody; + return lambda; + } + #endregion + + #region Convert Field + static readonly Expression getFieldFromHandlePattern = + new TypePattern(typeof(FieldInfo)).ToType().Invoke( + "GetFieldFromHandle", + new LdTokenPattern("field").ToExpression().Member("FieldHandle"), + new OptionalNode(new TypeOfExpression(new AnyNode("declaringType")).Member("TypeHandle")) + ); + + Expression ConvertField(InvocationExpression invocation) + { + if (invocation.Arguments.Count != 2) + return NotSupported(invocation); + + Expression fieldInfoExpr = invocation.Arguments.ElementAt(1); + Match m = getFieldFromHandlePattern.Match(fieldInfoExpr); + if (!m.Success) + return NotSupported(invocation); + + FieldReference fr = m.Get<AstNode>("field").Single().Annotation<FieldReference>(); + if (fr == null) + return null; + + Expression target = invocation.Arguments.ElementAt(0); + Expression convertedTarget; + if (target is NullReferenceExpression) { + if (m.Has("declaringType")) + convertedTarget = new TypeReferenceExpression(m.Get<AstType>("declaringType").Single().Clone()); + else + convertedTarget = new TypeReferenceExpression(AstBuilder.ConvertType(fr.DeclaringType)); + } else { + convertedTarget = Convert(target); + if (convertedTarget == null) + return null; + } + + return convertedTarget.Member(fr.Name).WithAnnotation(fr); + } + #endregion + + #region Convert Property + static readonly Expression getMethodFromHandlePattern = + new TypePattern(typeof(MethodBase)).ToType().Invoke( + "GetMethodFromHandle", + new LdTokenPattern("method").ToExpression().Member("MethodHandle"), + new OptionalNode(new TypeOfExpression(new AnyNode("declaringType")).Member("TypeHandle")) + ).CastTo(new TypePattern(typeof(MethodInfo))); + + Expression ConvertProperty(InvocationExpression invocation) + { + if (invocation.Arguments.Count != 2) + return NotSupported(invocation); + + Match m = getMethodFromHandlePattern.Match(invocation.Arguments.ElementAt(1)); + if (!m.Success) + return NotSupported(invocation); + + MethodReference mr = m.Get<AstNode>("method").Single().Annotation<MethodReference>(); + if (mr == null) + return null; + + Expression target = invocation.Arguments.ElementAt(0); + Expression convertedTarget; + if (target is NullReferenceExpression) { + if (m.Has("declaringType")) + convertedTarget = new TypeReferenceExpression(m.Get<AstType>("declaringType").Single().Clone()); + else + convertedTarget = new TypeReferenceExpression(AstBuilder.ConvertType(mr.DeclaringType)); + } else { + convertedTarget = Convert(target); + if (convertedTarget == null) + return null; + } + + return convertedTarget.Member(GetPropertyName(mr)).WithAnnotation(mr); + } + + string GetPropertyName(MethodReference accessor) + { + string name = accessor.Name; + if (name.StartsWith("get_", StringComparison.Ordinal) || name.StartsWith("set_", StringComparison.Ordinal)) + name = name.Substring(4); + return name; + } + #endregion + + #region Convert Call + Expression ConvertCall(InvocationExpression invocation) + { + if (invocation.Arguments.Count < 2) + return NotSupported(invocation); + + Expression target; + int firstArgumentPosition; + + Match m = getMethodFromHandlePattern.Match(invocation.Arguments.ElementAt(0)); + if (m.Success) { + target = null; + firstArgumentPosition = 1; + } else { + m = getMethodFromHandlePattern.Match(invocation.Arguments.ElementAt(1)); + if (!m.Success) + return NotSupported(invocation); + target = invocation.Arguments.ElementAt(0); + firstArgumentPosition = 2; + } + + MethodReference mr = m.Get<AstNode>("method").Single().Annotation<MethodReference>(); + if (mr == null) + return null; + + Expression convertedTarget; + if (target == null || target is NullReferenceExpression) { + // static method + if (m.Has("declaringType")) + convertedTarget = new TypeReferenceExpression(m.Get<AstType>("declaringType").Single().Clone()); + else + convertedTarget = new TypeReferenceExpression(AstBuilder.ConvertType(mr.DeclaringType)); + } else { + convertedTarget = Convert(target); + if (convertedTarget == null) + return null; + } + + MemberReferenceExpression mre = convertedTarget.Member(mr.Name); + GenericInstanceMethod gim = mr as GenericInstanceMethod; + if (gim != null) { + foreach (TypeReference tr in gim.GenericArguments) { + mre.TypeArguments.Add(AstBuilder.ConvertType(tr)); + } + } + IList<Expression> arguments = null; + if (invocation.Arguments.Count == firstArgumentPosition + 1) { + Expression argumentArray = invocation.Arguments.ElementAt(firstArgumentPosition); + arguments = ConvertExpressionsArray(argumentArray); + } + if (arguments == null) { + arguments = new List<Expression>(); + foreach (Expression argument in invocation.Arguments.Skip(firstArgumentPosition)) { + Expression convertedArgument = Convert(argument); + if (convertedArgument == null) + return null; + arguments.Add(convertedArgument); + } + } + MethodDefinition methodDef = mr.Resolve(); + if (methodDef != null && methodDef.IsGetter) { + PropertyDefinition indexer = AstMethodBodyBuilder.GetIndexer(methodDef); + if (indexer != null) + return new IndexerExpression(mre.Target.Detach(), arguments).WithAnnotation(indexer); + } + return new InvocationExpression(mre, arguments).WithAnnotation(mr); + } + + Expression ConvertInvoke(InvocationExpression invocation) + { + if (invocation.Arguments.Count != 2) + return NotSupported(invocation); + + Expression convertedTarget = Convert(invocation.Arguments.ElementAt(0)); + IList<Expression> convertedArguments = ConvertExpressionsArray(invocation.Arguments.ElementAt(1)); + if (convertedTarget != null && convertedArguments != null) + return new InvocationExpression(convertedTarget, convertedArguments); + else + return null; + } + #endregion + + #region Convert Binary Operator + static readonly Pattern trueOrFalse = new Choice { + new PrimitiveExpression(true), + new PrimitiveExpression(false) + }; + + Expression ConvertBinaryOperator(InvocationExpression invocation, BinaryOperatorType op, bool? isChecked = null) + { + if (invocation.Arguments.Count < 2) + return NotSupported(invocation); + + Expression left = Convert(invocation.Arguments.ElementAt(0)); + if (left == null) + return null; + Expression right = Convert(invocation.Arguments.ElementAt(1)); + if (right == null) + return null; + + BinaryOperatorExpression boe = new BinaryOperatorExpression(left, op, right); + if (isChecked != null) + boe.AddAnnotation(isChecked.Value ? AddCheckedBlocks.CheckedAnnotation : AddCheckedBlocks.UncheckedAnnotation); + + switch (invocation.Arguments.Count) { + case 2: + return boe; + case 3: + Match m = getMethodFromHandlePattern.Match(invocation.Arguments.ElementAt(2)); + if (m.Success) + return boe.WithAnnotation(m.Get<AstNode>("method").Single().Annotation<MethodReference>()); + else + return null; + case 4: + if (!trueOrFalse.IsMatch(invocation.Arguments.ElementAt(2))) + return null; + m = getMethodFromHandlePattern.Match(invocation.Arguments.ElementAt(3)); + if (m.Success) + return boe.WithAnnotation(m.Get<AstNode>("method").Single().Annotation<MethodReference>()); + else + return null; + default: + return NotSupported(invocation); + } + } + #endregion + + #region Convert Assignment Operator + Expression ConvertAssignmentOperator(InvocationExpression invocation, AssignmentOperatorType op, bool? isChecked = null) + { + return NotSupported(invocation); + } + #endregion + + #region Convert Unary Operator + Expression ConvertUnaryOperator(InvocationExpression invocation, UnaryOperatorType op, bool? isChecked = null) + { + if (invocation.Arguments.Count < 1) + return NotSupported(invocation); + + Expression expr = Convert(invocation.Arguments.ElementAt(0)); + if (expr == null) + return null; + + UnaryOperatorExpression uoe = new UnaryOperatorExpression(op, expr); + if (isChecked != null) + uoe.AddAnnotation(isChecked.Value ? AddCheckedBlocks.CheckedAnnotation : AddCheckedBlocks.UncheckedAnnotation); + + switch (invocation.Arguments.Count) { + case 1: + return uoe; + case 2: + Match m = getMethodFromHandlePattern.Match(invocation.Arguments.ElementAt(1)); + if (m.Success) + return uoe.WithAnnotation(m.Get<AstNode>("method").Single().Annotation<MethodReference>()); + else + return null; + default: + return NotSupported(invocation); + } + } + #endregion + + #region Convert Condition Operator + Expression ConvertCondition(InvocationExpression invocation) + { + if (invocation.Arguments.Count != 3) + return NotSupported(invocation); + + Expression condition = Convert(invocation.Arguments.ElementAt(0)); + Expression trueExpr = Convert(invocation.Arguments.ElementAt(1)); + Expression falseExpr = Convert(invocation.Arguments.ElementAt(2)); + if (condition != null && trueExpr != null && falseExpr != null) + return new ConditionalExpression(condition, trueExpr, falseExpr); + else + return null; + } + #endregion + + #region Convert New Object + static readonly Expression newObjectCtorPattern = new TypePattern(typeof(MethodBase)).ToType().Invoke + ( + "GetMethodFromHandle", + new LdTokenPattern("ctor").ToExpression().Member("MethodHandle"), + new OptionalNode(new TypeOfExpression(new AnyNode("declaringType")).Member("TypeHandle")) + ).CastTo(new TypePattern(typeof(ConstructorInfo))); + + Expression ConvertNewObject(InvocationExpression invocation) + { + if (invocation.Arguments.Count < 1 || invocation.Arguments.Count > 3) + return NotSupported(invocation); + + Match m = newObjectCtorPattern.Match(invocation.Arguments.First()); + if (!m.Success) + return NotSupported(invocation); + + MethodReference ctor = m.Get<AstNode>("ctor").Single().Annotation<MethodReference>(); + if (ctor == null) + return null; + + AstType declaringTypeNode; + TypeReference declaringType; + if (m.Has("declaringType")) { + declaringTypeNode = m.Get<AstType>("declaringType").Single().Clone(); + declaringType = declaringTypeNode.Annotation<TypeReference>(); + } else { + declaringTypeNode = AstBuilder.ConvertType(ctor.DeclaringType); + declaringType = ctor.DeclaringType; + } + if (declaringTypeNode == null) + return null; + + ObjectCreateExpression oce = new ObjectCreateExpression(declaringTypeNode); + if (invocation.Arguments.Count >= 2) { + IList<Expression> arguments = ConvertExpressionsArray(invocation.Arguments.ElementAtOrDefault(1)); + if (arguments == null) + return null; + oce.Arguments.AddRange(arguments); + } + if (invocation.Arguments.Count >= 3 && declaringType.IsAnonymousType()) { + MethodDefinition resolvedCtor = ctor.Resolve(); + if (resolvedCtor == null || resolvedCtor.Parameters.Count != oce.Arguments.Count) + return null; + AnonymousTypeCreateExpression atce = new AnonymousTypeCreateExpression(); + var arguments = oce.Arguments.ToArray(); + if (AstMethodBodyBuilder.CanInferAnonymousTypePropertyNamesFromArguments(arguments, resolvedCtor.Parameters)) { + oce.Arguments.MoveTo(atce.Initializers); + } else { + for (int i = 0; i < resolvedCtor.Parameters.Count; i++) { + atce.Initializers.Add( + new NamedExpression { + Name = resolvedCtor.Parameters[i].Name, + Expression = arguments[i].Detach() + }); + } + } + return atce; + } + + return oce; + } + #endregion + + #region Convert ListInit + static readonly Pattern elementInitArrayPattern = ArrayInitializationPattern( + typeof(System.Linq.Expressions.ElementInit), + new TypePattern(typeof(System.Linq.Expressions.Expression)).ToType().Invoke("ElementInit", new AnyNode("methodInfos"), new AnyNode("addArgumentsArrays")) + ); + + Expression ConvertListInit(InvocationExpression invocation) + { + if (invocation.Arguments.Count != 2) + return NotSupported(invocation); + ObjectCreateExpression oce = Convert(invocation.Arguments.ElementAt(0)) as ObjectCreateExpression; + if (oce == null) + return null; + Expression elementsArray = invocation.Arguments.ElementAt(1); + ArrayInitializerExpression initializer = ConvertElementInit(elementsArray); + if (initializer != null) { + oce.Initializer = initializer; + return oce; + } else { + return null; + } + } + + ArrayInitializerExpression ConvertElementInit(Expression elementsArray) + { + IList<Expression> elements = ConvertExpressionsArray(elementsArray); + if (elements != null) { + return new ArrayInitializerExpression(elements); + } + Match m = elementInitArrayPattern.Match(elementsArray); + if (!m.Success) + return null; + ArrayInitializerExpression result = new ArrayInitializerExpression(); + foreach (var elementInit in m.Get<Expression>("addArgumentsArrays")) { + IList<Expression> arguments = ConvertExpressionsArray(elementInit); + if (arguments == null) + return null; + result.Elements.Add(new ArrayInitializerExpression(arguments)); + } + return result; + } + #endregion + + #region Convert MemberInit + Expression ConvertMemberInit(InvocationExpression invocation) + { + if (invocation.Arguments.Count != 2) + return NotSupported(invocation); + ObjectCreateExpression oce = Convert(invocation.Arguments.ElementAt(0)) as ObjectCreateExpression; + if (oce == null) + return null; + Expression elementsArray = invocation.Arguments.ElementAt(1); + ArrayInitializerExpression bindings = ConvertMemberBindings(elementsArray); + if (bindings == null) + return null; + oce.Initializer = bindings; + return oce; + } + + static readonly Pattern memberBindingArrayPattern = ArrayInitializationPattern(typeof(System.Linq.Expressions.MemberBinding), new AnyNode("binding")); + static readonly INode expressionTypeReference = new TypeReferenceExpression(new TypePattern(typeof(System.Linq.Expressions.Expression))); + + ArrayInitializerExpression ConvertMemberBindings(Expression elementsArray) + { + Match m = memberBindingArrayPattern.Match(elementsArray); + if (!m.Success) + return null; + ArrayInitializerExpression result = new ArrayInitializerExpression(); + foreach (var binding in m.Get<Expression>("binding")) { + InvocationExpression bindingInvocation = binding as InvocationExpression; + if (bindingInvocation == null || bindingInvocation.Arguments.Count != 2) + return null; + MemberReferenceExpression bindingMRE = bindingInvocation.Target as MemberReferenceExpression; + if (bindingMRE == null || !expressionTypeReference.IsMatch(bindingMRE.Target)) + return null; + + Expression bindingTarget = bindingInvocation.Arguments.ElementAt(0); + Expression bindingValue = bindingInvocation.Arguments.ElementAt(1); + + string memberName; + Match m2 = getMethodFromHandlePattern.Match(bindingTarget); + if (m2.Success) { + MethodReference setter = m2.Get<AstNode>("method").Single().Annotation<MethodReference>(); + if (setter == null) + return null; + memberName = GetPropertyName(setter); + } else { + return null; + } + + Expression convertedValue; + switch (bindingMRE.MemberName) { + case "Bind": + convertedValue = Convert(bindingValue); + break; + case "MemberBind": + convertedValue = ConvertMemberBindings(bindingValue); + break; + case "ListBind": + convertedValue = ConvertElementInit(bindingValue); + break; + default: + return null; + } + if (convertedValue == null) + return null; + result.Elements.Add(new NamedExpression(memberName, convertedValue)); + } + return result; + } + #endregion + + #region Convert Cast + Expression ConvertCast(InvocationExpression invocation, bool isChecked) + { + if (invocation.Arguments.Count < 2) + return null; + Expression converted = Convert(invocation.Arguments.ElementAt(0)); + AstType type = ConvertTypeReference(invocation.Arguments.ElementAt(1)); + if (converted != null && type != null) { + CastExpression cast = converted.CastTo(type); + cast.AddAnnotation(isChecked ? AddCheckedBlocks.CheckedAnnotation : AddCheckedBlocks.UncheckedAnnotation); + switch (invocation.Arguments.Count) { + case 2: + return cast; + case 3: + Match m = getMethodFromHandlePattern.Match(invocation.Arguments.ElementAt(2)); + if (m.Success) + return cast.WithAnnotation(m.Get<AstNode>("method").Single().Annotation<MethodReference>()); + else + return null; + } + } + return null; + } + #endregion + + #region ConvertExpressionsArray + static Pattern ArrayInitializationPattern(Type arrayElementType, INode elementPattern) + { + return new Choice { + new ArrayCreateExpression { + Type = new TypePattern(arrayElementType), + Arguments = { new PrimitiveExpression(0) } + }, + new ArrayCreateExpression { + Type = new TypePattern(arrayElementType), + AdditionalArraySpecifiers = { new ArraySpecifier() }, + Initializer = new ArrayInitializerExpression { + Elements = { new Repeat(elementPattern) } + } + } + }; + } + + static readonly Pattern expressionArrayPattern = ArrayInitializationPattern(typeof(System.Linq.Expressions.Expression), new AnyNode("elements")); + + IList<Expression> ConvertExpressionsArray(Expression arrayExpression) + { + Match m = expressionArrayPattern.Match(arrayExpression); + if (m.Success) { + List<Expression> result = new List<Expression>(); + foreach (Expression expr in m.Get<Expression>("elements")) { + Expression converted = Convert(expr); + if (converted == null) + return null; + result.Add(converted); + } + return result; + } + return null; + } + #endregion + + #region Convert TypeAs/TypeIs + static readonly TypeOfPattern typeOfPattern = new TypeOfPattern("type"); + + AstType ConvertTypeReference(Expression typeOfExpression) + { + Match m = typeOfPattern.Match(typeOfExpression); + if (m.Success) + return m.Get<AstType>("type").Single().Clone(); + else + return null; + } + + Expression ConvertTypeAs(InvocationExpression invocation) + { + if (invocation.Arguments.Count != 2) + return null; + Expression converted = Convert(invocation.Arguments.ElementAt(0)); + AstType type = ConvertTypeReference(invocation.Arguments.ElementAt(1)); + if (converted != null && type != null) + return new AsExpression(converted, type); + return null; + } + + Expression ConvertTypeIs(InvocationExpression invocation) + { + if (invocation.Arguments.Count != 2) + return null; + Expression converted = Convert(invocation.Arguments.ElementAt(0)); + AstType type = ConvertTypeReference(invocation.Arguments.ElementAt(1)); + if (converted != null && type != null) + return new IsExpression { Expression = converted, Type = type }; + return null; + } + #endregion + + #region Convert Array + Expression ConvertArrayIndex(InvocationExpression invocation) + { + if (invocation.Arguments.Count != 2) + return NotSupported(invocation); + + Expression targetConverted = Convert(invocation.Arguments.First()); + if (targetConverted == null) + return null; + + Expression index = invocation.Arguments.ElementAt(1); + Expression indexConverted = Convert(index); + if (indexConverted != null) { + return new IndexerExpression(targetConverted, indexConverted); + } + IList<Expression> indexesConverted = ConvertExpressionsArray(index); + if (indexesConverted != null) { + return new IndexerExpression(targetConverted, indexesConverted); + } + return null; + } + + Expression ConvertArrayLength(InvocationExpression invocation) + { + if (invocation.Arguments.Count != 1) + return NotSupported(invocation); + + Expression targetConverted = Convert(invocation.Arguments.Single()); + if (targetConverted != null) + return targetConverted.Member("Length"); + else + return null; + } + + Expression ConvertNewArrayInit(InvocationExpression invocation) + { + if (invocation.Arguments.Count != 2) + return NotSupported(invocation); + + AstType elementType = ConvertTypeReference(invocation.Arguments.ElementAt(0)); + IList<Expression> elements = ConvertExpressionsArray(invocation.Arguments.ElementAt(1)); + if (elementType != null && elements != null) { + if (ContainsAnonymousType(elementType)) { + elementType = null; + } + return new ArrayCreateExpression { + Type = elementType, + AdditionalArraySpecifiers = { new ArraySpecifier() }, + Initializer = new ArrayInitializerExpression(elements) + }; + } + return null; + } + + Expression ConvertNewArrayBounds(InvocationExpression invocation) + { + if (invocation.Arguments.Count != 2) + return NotSupported(invocation); + + AstType elementType = ConvertTypeReference(invocation.Arguments.ElementAt(0)); + IList<Expression> arguments = ConvertExpressionsArray(invocation.Arguments.ElementAt(1)); + if (elementType != null && arguments != null) { + if (ContainsAnonymousType(elementType)) { + elementType = null; + } + ArrayCreateExpression ace = new ArrayCreateExpression(); + ace.Type = elementType; + ace.Arguments.AddRange(arguments); + return ace; + } + return null; + } + + bool ContainsAnonymousType(AstType type) + { + foreach (AstType t in type.DescendantsAndSelf.OfType<AstType>()) { + TypeReference tr = t.Annotation<TypeReference>(); + if (tr != null && tr.IsAnonymousType()) + return true; + } + return false; + } + #endregion + } +} diff --git a/ICSharpCode.Decompiler/Ast/Transforms/FlattenSwitchBlocks.cs b/ICSharpCode.Decompiler/Ast/Transforms/FlattenSwitchBlocks.cs new file mode 100644 index 00000000..9595e81b --- /dev/null +++ b/ICSharpCode.Decompiler/Ast/Transforms/FlattenSwitchBlocks.cs @@ -0,0 +1,27 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using ICSharpCode.NRefactory.CSharp; + +namespace ICSharpCode.Decompiler.Ast.Transforms +{ + internal class FlattenSwitchBlocks : IAstTransform + { + public void Run(AstNode compilationUnit) + { + foreach (var switchSection in compilationUnit.Descendants.OfType<SwitchSection>()) + { + if (switchSection.Statements.Count != 1) + continue; + + var blockStatement = switchSection.Statements.First() as BlockStatement; + if (blockStatement == null || blockStatement.Statements.Any(st => st is VariableDeclarationStatement)) + continue; + + blockStatement.Remove(); + blockStatement.Statements.MoveTo(switchSection.Statements); + } + } + } +} diff --git a/ICSharpCode.Decompiler/Ast/Transforms/IntroduceExtensionMethods.cs b/ICSharpCode.Decompiler/Ast/Transforms/IntroduceExtensionMethods.cs new file mode 100644 index 00000000..9f05285e --- /dev/null +++ b/ICSharpCode.Decompiler/Ast/Transforms/IntroduceExtensionMethods.cs @@ -0,0 +1,66 @@ +// Copyright (c) 2011 AlphaSierraPapa for the SharpDevelop Team +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this +// software and associated documentation files (the "Software"), to deal in the Software +// without restriction, including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons +// to whom the Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR +// PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE +// FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +// OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +using System; +using System.Linq; +using ICSharpCode.NRefactory.CSharp; +using Mono.Cecil; + +namespace ICSharpCode.Decompiler.Ast.Transforms +{ + /// <summary> + /// Converts extension method calls into infix syntax. + /// </summary> + public class IntroduceExtensionMethods : IAstTransform + { + readonly DecompilerContext context; + + public IntroduceExtensionMethods(DecompilerContext context) + { + this.context = context; + } + + public void Run(AstNode compilationUnit) + { + foreach (InvocationExpression invocation in compilationUnit.Descendants.OfType<InvocationExpression>()) { + MemberReferenceExpression mre = invocation.Target as MemberReferenceExpression; + MethodReference methodReference = invocation.Annotation<MethodReference>(); + if (mre != null && mre.Target is TypeReferenceExpression && methodReference != null && invocation.Arguments.Any()) { + MethodDefinition d = methodReference.Resolve(); + if (d != null) { + foreach (var ca in d.CustomAttributes) { + if (ca.AttributeType.Name == "ExtensionAttribute" && ca.AttributeType.Namespace == "System.Runtime.CompilerServices") { + var firstArgument = invocation.Arguments.First(); + if (firstArgument is NullReferenceExpression) + firstArgument = firstArgument.ReplaceWith(expr => expr.CastTo(AstBuilder.ConvertType(d.Parameters.First().ParameterType))); + else + mre.Target = firstArgument.Detach(); + if (invocation.Arguments.Any()) { + // HACK: removing type arguments should be done indepently from whether a method is an extension method, + // just by testing whether the arguments can be inferred + mre.TypeArguments.Clear(); + } + break; + } + } + } + } + } + } + } +} diff --git a/ICSharpCode.Decompiler/Ast/Transforms/IntroduceQueryExpressions.cs b/ICSharpCode.Decompiler/Ast/Transforms/IntroduceQueryExpressions.cs new file mode 100644 index 00000000..0da56fe9 --- /dev/null +++ b/ICSharpCode.Decompiler/Ast/Transforms/IntroduceQueryExpressions.cs @@ -0,0 +1,295 @@ +// Copyright (c) 2011 AlphaSierraPapa for the SharpDevelop Team +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this +// software and associated documentation files (the "Software"), to deal in the Software +// without restriction, including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons +// to whom the Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR +// PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE +// FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +// OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +using System; +using System.Diagnostics; +using System.Linq; +using ICSharpCode.NRefactory.CSharp; + +namespace ICSharpCode.Decompiler.Ast.Transforms +{ + /// <summary> + /// Decompiles query expressions. + /// Based on C# 4.0 spec, §7.16.2 Query expression translation + /// </summary> + public class IntroduceQueryExpressions : IAstTransform + { + readonly DecompilerContext context; + + public IntroduceQueryExpressions(DecompilerContext context) + { + this.context = context; + } + + public void Run(AstNode compilationUnit) + { + if (!context.Settings.QueryExpressions) + return; + DecompileQueries(compilationUnit); + // After all queries were decompiled, detect degenerate queries (queries not property terminated with 'select' or 'group') + // and fix them, either by adding a degenerate select, or by combining them with another query. + foreach (QueryExpression query in compilationUnit.Descendants.OfType<QueryExpression>()) { + QueryFromClause fromClause = (QueryFromClause)query.Clauses.First(); + if (IsDegenerateQuery(query)) { + // introduce select for degenerate query + query.Clauses.Add(new QuerySelectClause { Expression = new IdentifierExpression(fromClause.Identifier) }); + } + // See if the data source of this query is a degenerate query, + // and combine the queries if possible. + QueryExpression innerQuery = fromClause.Expression as QueryExpression; + while (IsDegenerateQuery(innerQuery)) { + QueryFromClause innerFromClause = (QueryFromClause)innerQuery.Clauses.First(); + if (fromClause.Identifier != innerFromClause.Identifier) + break; + // Replace the fromClause with all clauses from the inner query + fromClause.Remove(); + QueryClause insertionPos = null; + foreach (var clause in innerQuery.Clauses) { + query.Clauses.InsertAfter(insertionPos, insertionPos = clause.Detach()); + } + fromClause = innerFromClause; + innerQuery = fromClause.Expression as QueryExpression; + } + } + } + + bool IsDegenerateQuery(QueryExpression query) + { + if (query == null) + return false; + var lastClause = query.Clauses.LastOrDefault(); + return !(lastClause is QuerySelectClause || lastClause is QueryGroupClause); + } + + void DecompileQueries(AstNode node) + { + QueryExpression query = DecompileQuery(node as InvocationExpression); + if (query != null) + node.ReplaceWith(query); + for (AstNode child = (query ?? node).FirstChild; child != null; child = child.NextSibling) { + DecompileQueries(child); + } + } + + QueryExpression DecompileQuery(InvocationExpression invocation) + { + if (invocation == null) + return null; + MemberReferenceExpression mre = invocation.Target as MemberReferenceExpression; + if (mre == null) + return null; + switch (mre.MemberName) { + case "Select": + { + if (invocation.Arguments.Count != 1) + return null; + string parameterName; + Expression body; + if (MatchSimpleLambda(invocation.Arguments.Single(), out parameterName, out body)) { + QueryExpression query = new QueryExpression(); + query.Clauses.Add(new QueryFromClause { Identifier = parameterName, Expression = mre.Target.Detach() }); + query.Clauses.Add(new QuerySelectClause { Expression = body.Detach() }); + return query; + } + return null; + } + case "GroupBy": + { + if (invocation.Arguments.Count == 2) { + string parameterName1, parameterName2; + Expression keySelector, elementSelector; + if (MatchSimpleLambda(invocation.Arguments.ElementAt(0), out parameterName1, out keySelector) + && MatchSimpleLambda(invocation.Arguments.ElementAt(1), out parameterName2, out elementSelector) + && parameterName1 == parameterName2) + { + QueryExpression query = new QueryExpression(); + query.Clauses.Add(new QueryFromClause { Identifier = parameterName1, Expression = mre.Target.Detach() }); + query.Clauses.Add(new QueryGroupClause { Projection = elementSelector.Detach(), Key = keySelector.Detach() }); + return query; + } + } else if (invocation.Arguments.Count == 1) { + string parameterName; + Expression keySelector; + if (MatchSimpleLambda(invocation.Arguments.Single(), out parameterName, out keySelector)) { + QueryExpression query = new QueryExpression(); + query.Clauses.Add(new QueryFromClause { Identifier = parameterName, Expression = mre.Target.Detach() }); + query.Clauses.Add(new QueryGroupClause { Projection = new IdentifierExpression(parameterName), Key = keySelector.Detach() }); + return query; + } + } + return null; + } + case "SelectMany": + { + if (invocation.Arguments.Count != 2) + return null; + string parameterName; + Expression collectionSelector; + if (!MatchSimpleLambda(invocation.Arguments.ElementAt(0), out parameterName, out collectionSelector)) + return null; + LambdaExpression lambda = invocation.Arguments.ElementAt(1) as LambdaExpression; + if (lambda != null && lambda.Parameters.Count == 2 && lambda.Body is Expression) { + ParameterDeclaration p1 = lambda.Parameters.ElementAt(0); + ParameterDeclaration p2 = lambda.Parameters.ElementAt(1); + if (p1.Name == parameterName) { + QueryExpression query = new QueryExpression(); + query.Clauses.Add(new QueryFromClause { Identifier = p1.Name, Expression = mre.Target.Detach() }); + query.Clauses.Add(new QueryFromClause { Identifier = p2.Name, Expression = collectionSelector.Detach() }); + query.Clauses.Add(new QuerySelectClause { Expression = ((Expression)lambda.Body).Detach() }); + return query; + } + } + return null; + } + case "Where": + { + if (invocation.Arguments.Count != 1) + return null; + string parameterName; + Expression body; + if (MatchSimpleLambda(invocation.Arguments.Single(), out parameterName, out body)) { + QueryExpression query = new QueryExpression(); + query.Clauses.Add(new QueryFromClause { Identifier = parameterName, Expression = mre.Target.Detach() }); + query.Clauses.Add(new QueryWhereClause { Condition = body.Detach() }); + return query; + } + return null; + } + case "OrderBy": + case "OrderByDescending": + case "ThenBy": + case "ThenByDescending": + { + if (invocation.Arguments.Count != 1) + return null; + string parameterName; + Expression orderExpression; + if (MatchSimpleLambda(invocation.Arguments.Single(), out parameterName, out orderExpression)) { + if (ValidateThenByChain(invocation, parameterName)) { + QueryOrderClause orderClause = new QueryOrderClause(); + InvocationExpression tmp = invocation; + while (mre.MemberName == "ThenBy" || mre.MemberName == "ThenByDescending") { + // insert new ordering at beginning + orderClause.Orderings.InsertAfter( + null, new QueryOrdering { + Expression = orderExpression.Detach(), + Direction = (mre.MemberName == "ThenBy" ? QueryOrderingDirection.None : QueryOrderingDirection.Descending) + }); + + tmp = (InvocationExpression)mre.Target; + mre = (MemberReferenceExpression)tmp.Target; + MatchSimpleLambda(tmp.Arguments.Single(), out parameterName, out orderExpression); + } + // insert new ordering at beginning + orderClause.Orderings.InsertAfter( + null, new QueryOrdering { + Expression = orderExpression.Detach(), + Direction = (mre.MemberName == "OrderBy" ? QueryOrderingDirection.None : QueryOrderingDirection.Descending) + }); + + QueryExpression query = new QueryExpression(); + query.Clauses.Add(new QueryFromClause { Identifier = parameterName, Expression = mre.Target.Detach() }); + query.Clauses.Add(orderClause); + return query; + } + } + return null; + } + case "Join": + case "GroupJoin": + { + if (invocation.Arguments.Count != 4) + return null; + Expression source1 = mre.Target; + Expression source2 = invocation.Arguments.ElementAt(0); + string elementName1, elementName2; + Expression key1, key2; + if (!MatchSimpleLambda(invocation.Arguments.ElementAt(1), out elementName1, out key1)) + return null; + if (!MatchSimpleLambda(invocation.Arguments.ElementAt(2), out elementName2, out key2)) + return null; + LambdaExpression lambda = invocation.Arguments.ElementAt(3) as LambdaExpression; + if (lambda != null && lambda.Parameters.Count == 2 && lambda.Body is Expression) { + ParameterDeclaration p1 = lambda.Parameters.ElementAt(0); + ParameterDeclaration p2 = lambda.Parameters.ElementAt(1); + if (p1.Name == elementName1 && (p2.Name == elementName2 || mre.MemberName == "GroupJoin")) { + QueryExpression query = new QueryExpression(); + query.Clauses.Add(new QueryFromClause { Identifier = elementName1, Expression = source1.Detach() }); + QueryJoinClause joinClause = new QueryJoinClause(); + joinClause.JoinIdentifier = elementName2; // join elementName2 + joinClause.InExpression = source2.Detach(); // in source2 + joinClause.OnExpression = key1.Detach(); // on key1 + joinClause.EqualsExpression = key2.Detach(); // equals key2 + if (mre.MemberName == "GroupJoin") { + joinClause.IntoIdentifier = p2.Name; // into p2.Name + } + query.Clauses.Add(joinClause); + query.Clauses.Add(new QuerySelectClause { Expression = ((Expression)lambda.Body).Detach() }); + return query; + } + } + return null; + } + default: + return null; + } + } + + /// <summary> + /// Ensure that all ThenBy's are correct, and that the list of ThenBy's is terminated by an 'OrderBy' invocation. + /// </summary> + bool ValidateThenByChain(InvocationExpression invocation, string expectedParameterName) + { + if (invocation == null || invocation.Arguments.Count != 1) + return false; + MemberReferenceExpression mre = invocation.Target as MemberReferenceExpression; + if (mre == null) + return false; + string parameterName; + Expression body; + if (!MatchSimpleLambda(invocation.Arguments.Single(), out parameterName, out body)) + return false; + if (parameterName != expectedParameterName) + return false; + + if (mre.MemberName == "OrderBy" || mre.MemberName == "OrderByDescending") + return true; + else if (mre.MemberName == "ThenBy" || mre.MemberName == "ThenByDescending") + return ValidateThenByChain(mre.Target as InvocationExpression, expectedParameterName); + else + return false; + } + + /// <summary>Matches simple lambdas of the form "a => b"</summary> + bool MatchSimpleLambda(Expression expr, out string parameterName, out Expression body) + { + LambdaExpression lambda = expr as LambdaExpression; + if (lambda != null && lambda.Parameters.Count == 1 && lambda.Body is Expression) { + ParameterDeclaration p = lambda.Parameters.Single(); + if (p.ParameterModifier == ParameterModifier.None) { + parameterName = p.Name; + body = (Expression)lambda.Body; + return true; + } + } + parameterName = null; + body = null; + return false; + } + } +} diff --git a/ICSharpCode.Decompiler/Ast/Transforms/IntroduceUnsafeModifier.cs b/ICSharpCode.Decompiler/Ast/Transforms/IntroduceUnsafeModifier.cs new file mode 100644 index 00000000..43548e38 --- /dev/null +++ b/ICSharpCode.Decompiler/Ast/Transforms/IntroduceUnsafeModifier.cs @@ -0,0 +1,106 @@ +// Copyright (c) 2011 AlphaSierraPapa for the SharpDevelop Team +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this +// software and associated documentation files (the "Software"), to deal in the Software +// without restriction, including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons +// to whom the Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR +// PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE +// FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +// OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +using System; +using ICSharpCode.NRefactory.CSharp; + +namespace ICSharpCode.Decompiler.Ast.Transforms +{ + public class IntroduceUnsafeModifier : DepthFirstAstVisitor<object, bool>, IAstTransform + { + public static readonly object PointerArithmeticAnnotation = new PointerArithmetic(); + + sealed class PointerArithmetic {} + + public void Run(AstNode compilationUnit) + { + compilationUnit.AcceptVisitor(this, null); + } + + protected override bool VisitChildren(AstNode node, object data) + { + bool result = false; + for (AstNode child = node.FirstChild; child != null; child = child.NextSibling) { + result |= child.AcceptVisitor(this, data); + } + if (result && node is EntityDeclaration && !(node is Accessor)) { + ((EntityDeclaration)node).Modifiers |= Modifiers.Unsafe; + return false; + } + return result; + } + + public override bool VisitPointerReferenceExpression(PointerReferenceExpression pointerReferenceExpression, object data) + { + base.VisitPointerReferenceExpression(pointerReferenceExpression, data); + return true; + } + + public override bool VisitComposedType(ComposedType composedType, object data) + { + if (composedType.PointerRank > 0) + return true; + else + return base.VisitComposedType(composedType, data); + } + + public override bool VisitUnaryOperatorExpression(UnaryOperatorExpression unaryOperatorExpression, object data) + { + bool result = base.VisitUnaryOperatorExpression(unaryOperatorExpression, data); + if (unaryOperatorExpression.Operator == UnaryOperatorType.Dereference) { + BinaryOperatorExpression bop = unaryOperatorExpression.Expression as BinaryOperatorExpression; + if (bop != null && bop.Operator == BinaryOperatorType.Add && bop.Annotation<PointerArithmetic>() != null) { + // transform "*(ptr + int)" to "ptr[int]" + IndexerExpression indexer = new IndexerExpression(); + indexer.Target = bop.Left.Detach(); + indexer.Arguments.Add(bop.Right.Detach()); + indexer.CopyAnnotationsFrom(unaryOperatorExpression); + indexer.CopyAnnotationsFrom(bop); + unaryOperatorExpression.ReplaceWith(indexer); + } + return true; + } else if (unaryOperatorExpression.Operator == UnaryOperatorType.AddressOf) { + return true; + } else { + return result; + } + } + + public override bool VisitMemberReferenceExpression(MemberReferenceExpression memberReferenceExpression, object data) + { + bool result = base.VisitMemberReferenceExpression(memberReferenceExpression, data); + UnaryOperatorExpression uoe = memberReferenceExpression.Target as UnaryOperatorExpression; + if (uoe != null && uoe.Operator == UnaryOperatorType.Dereference) { + PointerReferenceExpression pre = new PointerReferenceExpression(); + pre.Target = uoe.Expression.Detach(); + pre.MemberName = memberReferenceExpression.MemberName; + memberReferenceExpression.TypeArguments.MoveTo(pre.TypeArguments); + pre.CopyAnnotationsFrom(uoe); + pre.CopyAnnotationsFrom(memberReferenceExpression); + memberReferenceExpression.ReplaceWith(pre); + } + return result; + } + + public override bool VisitStackAllocExpression(StackAllocExpression stackAllocExpression, object data) + { + base.VisitStackAllocExpression(stackAllocExpression, data); + return true; + } + } +} diff --git a/ICSharpCode.Decompiler/Ast/Transforms/IntroduceUsingDeclarations.cs b/ICSharpCode.Decompiler/Ast/Transforms/IntroduceUsingDeclarations.cs new file mode 100644 index 00000000..6e9cc4f5 --- /dev/null +++ b/ICSharpCode.Decompiler/Ast/Transforms/IntroduceUsingDeclarations.cs @@ -0,0 +1,359 @@ +// Copyright (c) 2011 AlphaSierraPapa for the SharpDevelop Team +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this +// software and associated documentation files (the "Software"), to deal in the Software +// without restriction, including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons +// to whom the Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR +// PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE +// FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +// OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +using System; +using System.Collections.Generic; +using System.Linq; +using ICSharpCode.NRefactory.CSharp; +using Mono.Cecil; + +namespace ICSharpCode.Decompiler.Ast.Transforms +{ + /// <summary> + /// Introduces using declarations. + /// </summary> + public class IntroduceUsingDeclarations : IAstTransform + { + DecompilerContext context; + + public IntroduceUsingDeclarations(DecompilerContext context) + { + this.context = context; + } + + public void Run(AstNode compilationUnit) + { + // First determine all the namespaces that need to be imported: + compilationUnit.AcceptVisitor(new FindRequiredImports(this), null); + + importedNamespaces.Add("System"); // always import System, even when not necessary + + if (context.Settings.UsingDeclarations) { + // Now add using declarations for those namespaces: + foreach (string ns in importedNamespaces.OrderByDescending(n => n)) { + // we go backwards (OrderByDescending) through the list of namespaces because we insert them backwards + // (always inserting at the start of the list) + string[] parts = ns.Split('.'); + AstType nsType = new SimpleType(parts[0]); + for (int i = 1; i < parts.Length; i++) { + nsType = new MemberType { Target = nsType, MemberName = parts[i] }; + } + compilationUnit.InsertChildAfter(null, new UsingDeclaration { Import = nsType }, SyntaxTree.MemberRole); + } + } + + if (!context.Settings.FullyQualifyAmbiguousTypeNames) + return; + + FindAmbiguousTypeNames(context.CurrentModule, internalsVisible: true); + foreach (AssemblyNameReference r in context.CurrentModule.AssemblyReferences) { + AssemblyDefinition d = context.CurrentModule.AssemblyResolver.Resolve(r); + if (d != null) + FindAmbiguousTypeNames(d.MainModule, internalsVisible: false); + } + + // verify that the SimpleTypes refer to the correct type (no ambiguities) + compilationUnit.AcceptVisitor(new FullyQualifyAmbiguousTypeNamesVisitor(this), null); + } + + readonly HashSet<string> declaredNamespaces = new HashSet<string>() { string.Empty }; + readonly HashSet<string> importedNamespaces = new HashSet<string>(); + + // Note that we store type names with `n suffix, so we automatically disambiguate based on number of type parameters. + readonly HashSet<string> availableTypeNames = new HashSet<string>(); + readonly HashSet<string> ambiguousTypeNames = new HashSet<string>(); + + sealed class FindRequiredImports : DepthFirstAstVisitor<object, object> + { + readonly IntroduceUsingDeclarations transform; + string currentNamespace; + + public FindRequiredImports(IntroduceUsingDeclarations transform) + { + this.transform = transform; + currentNamespace = transform.context.CurrentType != null ? transform.context.CurrentType.Namespace : string.Empty; + } + + bool IsParentOfCurrentNamespace(string ns) + { + if (ns.Length == 0) + return true; + if (currentNamespace.StartsWith(ns, StringComparison.Ordinal)) { + if (currentNamespace.Length == ns.Length) + return true; + if (currentNamespace[ns.Length] == '.') + return true; + } + return false; + } + + public override object VisitSimpleType(SimpleType simpleType, object data) + { + TypeReference tr = simpleType.Annotation<TypeReference>(); + if (tr != null && !IsParentOfCurrentNamespace(tr.Namespace)) { + transform.importedNamespaces.Add(tr.Namespace); + } + return base.VisitSimpleType(simpleType, data); // also visit type arguments + } + + public override object VisitNamespaceDeclaration(NamespaceDeclaration namespaceDeclaration, object data) + { + string oldNamespace = currentNamespace; + foreach (string ident in namespaceDeclaration.Identifiers) { + currentNamespace = NamespaceDeclaration.BuildQualifiedName(currentNamespace, ident); + transform.declaredNamespaces.Add(currentNamespace); + } + base.VisitNamespaceDeclaration(namespaceDeclaration, data); + currentNamespace = oldNamespace; + return null; + } + } + + void FindAmbiguousTypeNames(ModuleDefinition module, bool internalsVisible) + { + foreach (TypeDefinition type in module.Types) { + if (internalsVisible || type.IsPublic) { + if (importedNamespaces.Contains(type.Namespace) || declaredNamespaces.Contains(type.Namespace)) { + if (!availableTypeNames.Add(type.Name)) + ambiguousTypeNames.Add(type.Name); + } + } + } + } + + sealed class FullyQualifyAmbiguousTypeNamesVisitor : DepthFirstAstVisitor<object, object> + { + readonly IntroduceUsingDeclarations transform; + string currentNamespace; + HashSet<string> currentMemberTypes; + Dictionary<string, MemberReference> currentMembers; + bool isWithinTypeReferenceExpression; + + public FullyQualifyAmbiguousTypeNamesVisitor(IntroduceUsingDeclarations transform) + { + this.transform = transform; + currentNamespace = transform.context.CurrentType != null ? transform.context.CurrentType.Namespace : string.Empty; + } + + public override object VisitNamespaceDeclaration(NamespaceDeclaration namespaceDeclaration, object data) + { + string oldNamespace = currentNamespace; + foreach (string ident in namespaceDeclaration.Identifiers) { + currentNamespace = NamespaceDeclaration.BuildQualifiedName(currentNamespace, ident); + } + base.VisitNamespaceDeclaration(namespaceDeclaration, data); + currentNamespace = oldNamespace; + return null; + } + + public override object VisitTypeDeclaration(TypeDeclaration typeDeclaration, object data) + { + HashSet<string> oldMemberTypes = currentMemberTypes; + currentMemberTypes = currentMemberTypes != null ? new HashSet<string>(currentMemberTypes) : new HashSet<string>(); + + Dictionary<string, MemberReference> oldMembers = currentMembers; + currentMembers = new Dictionary<string, MemberReference>(); + + TypeDefinition typeDef = typeDeclaration.Annotation<TypeDefinition>(); + bool privateMembersVisible = true; + ModuleDefinition internalMembersVisibleInModule = typeDef.Module; + while (typeDef != null) { + foreach (GenericParameter gp in typeDef.GenericParameters) { + currentMemberTypes.Add(gp.Name); + } + foreach (TypeDefinition t in typeDef.NestedTypes) { + if (privateMembersVisible || IsVisible(t, internalMembersVisibleInModule)) + currentMemberTypes.Add(t.Name.Substring(t.Name.LastIndexOf('+') + 1)); + } + + foreach (MethodDefinition method in typeDef.Methods) { + if (privateMembersVisible || IsVisible(method, internalMembersVisibleInModule)) + AddCurrentMember(method); + } + foreach (PropertyDefinition property in typeDef.Properties) { + if (privateMembersVisible || IsVisible(property.GetMethod, internalMembersVisibleInModule) || IsVisible(property.SetMethod, internalMembersVisibleInModule)) + AddCurrentMember(property); + } + foreach (EventDefinition ev in typeDef.Events) { + if (privateMembersVisible || IsVisible(ev.AddMethod, internalMembersVisibleInModule) || IsVisible(ev.RemoveMethod, internalMembersVisibleInModule)) + AddCurrentMember(ev); + } + foreach (FieldDefinition f in typeDef.Fields) { + if (privateMembersVisible || IsVisible(f, internalMembersVisibleInModule)) + AddCurrentMember(f); + } + // repeat with base class: + if (typeDef.BaseType != null) + typeDef = typeDef.BaseType.Resolve(); + else + typeDef = null; + privateMembersVisible = false; + } + + // Now add current members from outer classes: + if (oldMembers != null) { + foreach (var pair in oldMembers) { + // add members from outer classes only if the inner class doesn't define the member + if (!currentMembers.ContainsKey(pair.Key)) + currentMembers.Add(pair.Key, pair.Value); + } + } + + base.VisitTypeDeclaration(typeDeclaration, data); + currentMembers = oldMembers; + return null; + } + + void AddCurrentMember(MemberReference m) + { + MemberReference existingMember; + if (currentMembers.TryGetValue(m.Name, out existingMember)) { + // We keep the existing member assignment if it was from another class (=from a derived class), + // because members in derived classes have precedence over members in base classes. + if (existingMember != null && existingMember.DeclaringType == m.DeclaringType) { + // Use null as value to signalize multiple members with the same name + currentMembers[m.Name] = null; + } + } else { + currentMembers.Add(m.Name, m); + } + } + + bool IsVisible(MethodDefinition m, ModuleDefinition internalMembersVisibleInModule) + { + if (m == null) + return false; + switch (m.Attributes & MethodAttributes.MemberAccessMask) { + case MethodAttributes.FamANDAssem: + case MethodAttributes.Assembly: + return m.Module == internalMembersVisibleInModule; + case MethodAttributes.Family: + case MethodAttributes.FamORAssem: + case MethodAttributes.Public: + return true; + default: + return false; + } + } + + bool IsVisible(FieldDefinition f, ModuleDefinition internalMembersVisibleInModule) + { + if (f == null) + return false; + switch (f.Attributes & FieldAttributes.FieldAccessMask) { + case FieldAttributes.FamANDAssem: + case FieldAttributes.Assembly: + return f.Module == internalMembersVisibleInModule; + case FieldAttributes.Family: + case FieldAttributes.FamORAssem: + case FieldAttributes.Public: + return true; + default: + return false; + } + } + + bool IsVisible(TypeDefinition t, ModuleDefinition internalMembersVisibleInModule) + { + if (t == null) + return false; + switch (t.Attributes & TypeAttributes.VisibilityMask) { + case TypeAttributes.NotPublic: + case TypeAttributes.NestedAssembly: + case TypeAttributes.NestedFamANDAssem: + return t.Module == internalMembersVisibleInModule; + case TypeAttributes.NestedFamily: + case TypeAttributes.NestedFamORAssem: + case TypeAttributes.NestedPublic: + case TypeAttributes.Public: + return true; + default: + return false; + } + } + + public override object VisitSimpleType(SimpleType simpleType, object data) + { + // Handle type arguments first, so that the fixed-up type arguments get moved over to the MemberType, + // if we're also creating one here. + base.VisitSimpleType(simpleType, data); + TypeReference tr = simpleType.Annotation<TypeReference>(); + // Fully qualify any ambiguous type names. + if (tr != null && IsAmbiguous(tr.Namespace, tr.Name)) { + AstType ns; + if (string.IsNullOrEmpty(tr.Namespace)) { + ns = new SimpleType("global"); + } else { + string[] parts = tr.Namespace.Split('.'); + if (IsAmbiguous(string.Empty, parts[0])) { + // conflict between namespace and type name/member name + ns = new MemberType { Target = new SimpleType("global"), IsDoubleColon = true, MemberName = parts[0] }; + } else { + ns = new SimpleType(parts[0]); + } + for (int i = 1; i < parts.Length; i++) { + ns = new MemberType { Target = ns, MemberName = parts[i] }; + } + } + MemberType mt = new MemberType(); + mt.Target = ns; + mt.IsDoubleColon = string.IsNullOrEmpty(tr.Namespace); + mt.MemberName = simpleType.Identifier; + mt.CopyAnnotationsFrom(simpleType); + simpleType.TypeArguments.MoveTo(mt.TypeArguments); + simpleType.ReplaceWith(mt); + } + return null; + } + + public override object VisitTypeReferenceExpression(TypeReferenceExpression typeReferenceExpression, object data) + { + isWithinTypeReferenceExpression = true; + base.VisitTypeReferenceExpression(typeReferenceExpression, data); + isWithinTypeReferenceExpression = false; + return null; + } + + bool IsAmbiguous(string ns, string name) + { + // If the type name conflicts with an inner class/type parameter, we need to fully-qualify it: + if (currentMemberTypes != null && currentMemberTypes.Contains(name)) + return true; + // If the type name conflicts with a field/property etc. on the current class, we need to fully-qualify it, + // if we're inside an expression. + if (isWithinTypeReferenceExpression && currentMembers != null) { + MemberReference mr; + if (currentMembers.TryGetValue(name, out mr)) { + // However, in the special case where the member is a field or property with the same type + // as is requested, then we can use the short name (if it's not otherwise ambiguous) + PropertyDefinition prop = mr as PropertyDefinition; + FieldDefinition field = mr as FieldDefinition; + if (!(prop != null && prop.PropertyType.Namespace == ns && prop.PropertyType.Name == name) + && !(field != null && field.FieldType.Namespace == ns && field.FieldType.Name == name)) + return true; + } + } + // If the type is defined in the current namespace, + // then we can use the short name even if we imported type with same name from another namespace. + if (ns == currentNamespace && !string.IsNullOrEmpty(ns)) + return false; + return transform.ambiguousTypeNames.Contains(name); + } + } + } +} diff --git a/ICSharpCode.Decompiler/Ast/Transforms/PatternStatementTransform.cs b/ICSharpCode.Decompiler/Ast/Transforms/PatternStatementTransform.cs new file mode 100644 index 00000000..d3ae46c1 --- /dev/null +++ b/ICSharpCode.Decompiler/Ast/Transforms/PatternStatementTransform.cs @@ -0,0 +1,1123 @@ +// Copyright (c) 2011 AlphaSierraPapa for the SharpDevelop Team +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this +// software and associated documentation files (the "Software"), to deal in the Software +// without restriction, including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons +// to whom the Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR +// PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE +// FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +// OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; + +using ICSharpCode.Decompiler.ILAst; +using ICSharpCode.NRefactory.CSharp; +using ICSharpCode.NRefactory.CSharp.Analysis; +using ICSharpCode.NRefactory.PatternMatching; +using Mono.Cecil; + +namespace ICSharpCode.Decompiler.Ast.Transforms +{ + /// <summary> + /// Finds the expanded form of using statements using pattern matching and replaces it with a UsingStatement. + /// </summary> + public sealed class PatternStatementTransform : ContextTrackingVisitor<AstNode>, IAstTransform + { + public PatternStatementTransform(DecompilerContext context) : base(context) + { + } + + #region Visitor Overrides + protected override AstNode VisitChildren(AstNode node, object data) + { + // Go through the children, and keep visiting a node as long as it changes. + // Because some transforms delete/replace nodes before and after the node being transformed, we rely + // on the transform's return value to know where we need to keep iterating. + for (AstNode child = node.FirstChild; child != null; child = child.NextSibling) { + AstNode oldChild; + do { + oldChild = child; + child = child.AcceptVisitor(this, data); + Debug.Assert(child != null && child.Parent == node); + } while (child != oldChild); + } + return node; + } + + public override AstNode VisitExpressionStatement(ExpressionStatement expressionStatement, object data) + { + AstNode result; + if (context.Settings.UsingStatement) { + result = TransformUsings(expressionStatement); + if (result != null) + return result; + result = TransformNonGenericForEach(expressionStatement); + if (result != null) + return result; + } + result = TransformFor(expressionStatement); + if (result != null) + return result; + if (context.Settings.LockStatement) { + result = TransformLock(expressionStatement); + if (result != null) + return result; + } + return base.VisitExpressionStatement(expressionStatement, data); + } + + public override AstNode VisitUsingStatement(UsingStatement usingStatement, object data) + { + if (context.Settings.ForEachStatement) { + AstNode result = TransformForeach(usingStatement); + if (result != null) + return result; + } + return base.VisitUsingStatement(usingStatement, data); + } + + public override AstNode VisitWhileStatement(WhileStatement whileStatement, object data) + { + return TransformDoWhile(whileStatement) ?? base.VisitWhileStatement(whileStatement, data); + } + + public override AstNode VisitIfElseStatement(IfElseStatement ifElseStatement, object data) + { + if (context.Settings.SwitchStatementOnString) { + AstNode result = TransformSwitchOnString(ifElseStatement); + if (result != null) + return result; + } + AstNode simplifiedIfElse = SimplifyCascadingIfElseStatements(ifElseStatement); + if (simplifiedIfElse != null) + return simplifiedIfElse; + return base.VisitIfElseStatement(ifElseStatement, data); + } + + public override AstNode VisitPropertyDeclaration(PropertyDeclaration propertyDeclaration, object data) + { + if (context.Settings.AutomaticProperties) { + AstNode result = TransformAutomaticProperties(propertyDeclaration); + if (result != null) + return result; + } + return base.VisitPropertyDeclaration(propertyDeclaration, data); + } + + public override AstNode VisitCustomEventDeclaration(CustomEventDeclaration eventDeclaration, object data) + { + // first apply transforms to the accessor bodies + base.VisitCustomEventDeclaration(eventDeclaration, data); + if (context.Settings.AutomaticEvents) { + AstNode result = TransformAutomaticEvents(eventDeclaration); + if (result != null) + return result; + } + return eventDeclaration; + } + + public override AstNode VisitMethodDeclaration(MethodDeclaration methodDeclaration, object data) + { + return TransformDestructor(methodDeclaration) ?? base.VisitMethodDeclaration(methodDeclaration, data); + } + + public override AstNode VisitTryCatchStatement(TryCatchStatement tryCatchStatement, object data) + { + return TransformTryCatchFinally(tryCatchStatement) ?? base.VisitTryCatchStatement(tryCatchStatement, data); + } + #endregion + + /// <summary> + /// $variable = $initializer; + /// </summary> + static readonly AstNode variableAssignPattern = new ExpressionStatement( + new AssignmentExpression( + new NamedNode("variable", new IdentifierExpression(Pattern.AnyString)), + new AnyNode("initializer") + )); + + #region using + static Expression InvokeDispose(Expression identifier) + { + return new Choice { + identifier.Invoke("Dispose"), + identifier.Clone().CastTo(new TypePattern(typeof(IDisposable))).Invoke("Dispose") + }; + } + + static readonly AstNode usingTryCatchPattern = new TryCatchStatement { + TryBlock = new AnyNode(), + FinallyBlock = new BlockStatement { + new Choice { + { "valueType", + new ExpressionStatement(InvokeDispose(new NamedNode("ident", new IdentifierExpression(Pattern.AnyString)))) + }, + { "referenceType", + new IfElseStatement { + Condition = new BinaryOperatorExpression( + new NamedNode("ident", new IdentifierExpression(Pattern.AnyString)), + BinaryOperatorType.InEquality, + new NullReferenceExpression() + ), + TrueStatement = new BlockStatement { + new ExpressionStatement(InvokeDispose(new Backreference("ident"))) + } + } + } + }.ToStatement() + } + }; + + public UsingStatement TransformUsings(ExpressionStatement node) + { + Match m1 = variableAssignPattern.Match(node); + if (!m1.Success) return null; + TryCatchStatement tryCatch = node.NextSibling as TryCatchStatement; + Match m2 = usingTryCatchPattern.Match(tryCatch); + if (!m2.Success) return null; + string variableName = m1.Get<IdentifierExpression>("variable").Single().Identifier; + if (variableName != m2.Get<IdentifierExpression>("ident").Single().Identifier) + return null; + if (m2.Has("valueType")) { + // if there's no if(x!=null), then it must be a value type + ILVariable v = m1.Get<AstNode>("variable").Single().Annotation<ILVariable>(); + if (v == null || v.Type == null || !v.Type.IsValueType) + return null; + } + + // There are two variants of the using statement: + // "using (var a = init)" and "using (expr)". + // The former declares a read-only variable 'a', and the latter declares an unnamed read-only variable + // to store the original value of 'expr'. + // This means that in order to introduce a using statement, in both cases we need to detect a read-only + // variable that is used only within that block. + + if (HasAssignment(tryCatch, variableName)) + return null; + + VariableDeclarationStatement varDecl = FindVariableDeclaration(node, variableName); + if (varDecl == null || !(varDecl.Parent is BlockStatement)) + return null; + + // Validate that the variable is not used after the using statement: + if (!IsVariableValueUnused(varDecl, tryCatch)) + return null; + + node.Remove(); + + UsingStatement usingStatement = new UsingStatement(); + usingStatement.EmbeddedStatement = tryCatch.TryBlock.Detach(); + tryCatch.ReplaceWith(usingStatement); + + // If possible, we'll eliminate the variable completely: + if (usingStatement.EmbeddedStatement.Descendants.OfType<IdentifierExpression>().Any(ident => ident.Identifier == variableName)) { + // variable is used, so we'll create a variable declaration + usingStatement.ResourceAcquisition = new VariableDeclarationStatement { + Type = (AstType)varDecl.Type.Clone(), + Variables = { + new VariableInitializer { + Name = variableName, + Initializer = m1.Get<Expression>("initializer").Single().Detach() + }.CopyAnnotationsFrom(node.Expression) + .WithAnnotation(m1.Get<AstNode>("variable").Single().Annotation<ILVariable>()) + } + }.CopyAnnotationsFrom(node); + } else { + // the variable is never used; eliminate it: + usingStatement.ResourceAcquisition = m1.Get<Expression>("initializer").Single().Detach(); + } + return usingStatement; + } + + internal static VariableDeclarationStatement FindVariableDeclaration(AstNode node, string identifier) + { + while (node != null) { + while (node.PrevSibling != null) { + node = node.PrevSibling; + VariableDeclarationStatement varDecl = node as VariableDeclarationStatement; + if (varDecl != null && varDecl.Variables.Count == 1 && varDecl.Variables.Single().Name == identifier) { + return varDecl; + } + } + node = node.Parent; + } + return null; + } + + /// <summary> + /// Gets whether the old variable value (assigned inside 'targetStatement' or earlier) + /// is read anywhere in the remaining scope of the variable declaration. + /// </summary> + bool IsVariableValueUnused(VariableDeclarationStatement varDecl, Statement targetStatement) + { + Debug.Assert(targetStatement.Ancestors.Contains(varDecl.Parent)); + BlockStatement block = (BlockStatement)varDecl.Parent; + DefiniteAssignmentAnalysis daa = new DefiniteAssignmentAnalysis(block, context.CancellationToken); + daa.SetAnalyzedRange(targetStatement, block, startInclusive: false); + daa.Analyze(varDecl.Variables.Single().Name); + return daa.UnassignedVariableUses.Count == 0; + } + + // I used this in the first implementation of the using-statement transform, but now no longer + // because there were problems when multiple using statements were using the same variable + // - no single using statement could be transformed without making the C# code invalid, + // but transforming both would work. + // We now use 'IsVariableValueUnused' which will perform the transform + // even if it results in two variables with the same name and overlapping scopes. + // (this issue could be fixed later by renaming one of the variables) + + // I'm not sure whether the other consumers of 'CanMoveVariableDeclarationIntoStatement' should be changed the same way. + bool CanMoveVariableDeclarationIntoStatement(VariableDeclarationStatement varDecl, Statement targetStatement, out Statement declarationPoint) + { + Debug.Assert(targetStatement.Ancestors.Contains(varDecl.Parent)); + // Find all blocks between targetStatement and varDecl.Parent + List<BlockStatement> blocks = targetStatement.Ancestors.TakeWhile(block => block != varDecl.Parent).OfType<BlockStatement>().ToList(); + blocks.Add((BlockStatement)varDecl.Parent); // also handle the varDecl.Parent block itself + blocks.Reverse(); // go from parent blocks to child blocks + DefiniteAssignmentAnalysis daa = new DefiniteAssignmentAnalysis(blocks[0], context.CancellationToken); + declarationPoint = null; + foreach (BlockStatement block in blocks) { + if (!DeclareVariables.FindDeclarationPoint(daa, varDecl, block, out declarationPoint)) { + return false; + } + } + return true; + } + + /// <summary> + /// Gets whether there is an assignment to 'variableName' anywhere within the given node. + /// </summary> + bool HasAssignment(AstNode root, string variableName) + { + foreach (AstNode node in root.DescendantsAndSelf) { + IdentifierExpression ident = node as IdentifierExpression; + if (ident != null && ident.Identifier == variableName) { + if (ident.Parent is AssignmentExpression && ident.Role == AssignmentExpression.LeftRole + || ident.Parent is DirectionExpression) + { + return true; + } + } + } + return false; + } + #endregion + + #region foreach (generic) + static readonly UsingStatement genericForeachPattern = new UsingStatement { + ResourceAcquisition = new VariableDeclarationStatement { + Type = new AnyNode("enumeratorType"), + Variables = { + new NamedNode( + "enumeratorVariable", + new VariableInitializer { + Name = Pattern.AnyString, + Initializer = new AnyNode("collection").ToExpression().Invoke("GetEnumerator") + } + ) + } + }, + EmbeddedStatement = new BlockStatement { + new Repeat( + new VariableDeclarationStatement { Type = new AnyNode(), Variables = { new VariableInitializer(Pattern.AnyString) } }.WithName("variablesOutsideLoop") + ).ToStatement(), + new WhileStatement { + Condition = new IdentifierExpressionBackreference("enumeratorVariable").ToExpression().Invoke("MoveNext"), + EmbeddedStatement = new BlockStatement { + new Repeat( + new VariableDeclarationStatement { + Type = new AnyNode(), + Variables = { new VariableInitializer(Pattern.AnyString) } + }.WithName("variablesInsideLoop") + ).ToStatement(), + new AssignmentExpression { + Left = new IdentifierExpression(Pattern.AnyString).WithName("itemVariable"), + Operator = AssignmentOperatorType.Assign, + Right = new IdentifierExpressionBackreference("enumeratorVariable").ToExpression().Member("Current") + }, + new Repeat(new AnyNode("statement")).ToStatement() + } + }.WithName("loop") + }}; + + public ForeachStatement TransformForeach(UsingStatement node) + { + Match m = genericForeachPattern.Match(node); + if (!m.Success) + return null; + if (!(node.Parent is BlockStatement) && m.Has("variablesOutsideLoop")) { + // if there are variables outside the loop, we need to put those into the parent block, and that won't work if the direct parent isn't a block + return null; + } + VariableInitializer enumeratorVar = m.Get<VariableInitializer>("enumeratorVariable").Single(); + IdentifierExpression itemVar = m.Get<IdentifierExpression>("itemVariable").Single(); + WhileStatement loop = m.Get<WhileStatement>("loop").Single(); + + // Find the declaration of the item variable: + // Because we look only outside the loop, we won't make the mistake of moving a captured variable across the loop boundary + VariableDeclarationStatement itemVarDecl = FindVariableDeclaration(loop, itemVar.Identifier); + if (itemVarDecl == null || !(itemVarDecl.Parent is BlockStatement)) + return null; + + // Now verify that we can move the variable declaration in front of the loop: + Statement declarationPoint; + CanMoveVariableDeclarationIntoStatement(itemVarDecl, loop, out declarationPoint); + // We ignore the return value because we don't care whether we can move the variable into the loop + // (that is possible only with non-captured variables). + // We just care that we can move it in front of the loop: + if (declarationPoint != loop) + return null; + + BlockStatement newBody = new BlockStatement(); + foreach (Statement stmt in m.Get<Statement>("variablesInsideLoop")) + newBody.Add(stmt.Detach()); + foreach (Statement stmt in m.Get<Statement>("statement")) + newBody.Add(stmt.Detach()); + + ForeachStatement foreachStatement = new ForeachStatement { + VariableType = (AstType)itemVarDecl.Type.Clone(), + VariableName = itemVar.Identifier, + InExpression = m.Get<Expression>("collection").Single().Detach(), + EmbeddedStatement = newBody + }.WithAnnotation(itemVarDecl.Variables.Single().Annotation<ILVariable>()); + if (foreachStatement.InExpression is BaseReferenceExpression) { + foreachStatement.InExpression = new ThisReferenceExpression().CopyAnnotationsFrom(foreachStatement.InExpression); + } + node.ReplaceWith(foreachStatement); + foreach (Statement stmt in m.Get<Statement>("variablesOutsideLoop")) { + ((BlockStatement)foreachStatement.Parent).Statements.InsertAfter(null, stmt.Detach()); + } + return foreachStatement; + } + #endregion + + #region foreach (non-generic) + ExpressionStatement getEnumeratorPattern = new ExpressionStatement( + new AssignmentExpression( + new NamedNode("left", new IdentifierExpression(Pattern.AnyString)), + new AnyNode("collection").ToExpression().Invoke("GetEnumerator") + )); + + TryCatchStatement nonGenericForeachPattern = new TryCatchStatement { + TryBlock = new BlockStatement { + new WhileStatement { + Condition = new IdentifierExpression(Pattern.AnyString).WithName("enumerator").Invoke("MoveNext"), + EmbeddedStatement = new BlockStatement { + new AssignmentExpression( + new IdentifierExpression(Pattern.AnyString).WithName("itemVar"), + new Choice { + new Backreference("enumerator").ToExpression().Member("Current"), + new CastExpression { + Type = new AnyNode("castType"), + Expression = new Backreference("enumerator").ToExpression().Member("Current") + } + } + ), + new Repeat(new AnyNode("stmt")).ToStatement() + } + }.WithName("loop") + }, + FinallyBlock = new BlockStatement { + new AssignmentExpression( + new IdentifierExpression(Pattern.AnyString).WithName("disposable"), + new Backreference("enumerator").ToExpression().CastAs(new TypePattern(typeof(IDisposable))) + ), + new IfElseStatement { + Condition = new BinaryOperatorExpression { + Left = new Backreference("disposable"), + Operator = BinaryOperatorType.InEquality, + Right = new NullReferenceExpression() + }, + TrueStatement = new BlockStatement { + new Backreference("disposable").ToExpression().Invoke("Dispose") + } + } + }}; + + public ForeachStatement TransformNonGenericForEach(ExpressionStatement node) + { + Match m1 = getEnumeratorPattern.Match(node); + if (!m1.Success) return null; + AstNode tryCatch = node.NextSibling; + Match m2 = nonGenericForeachPattern.Match(tryCatch); + if (!m2.Success) return null; + + IdentifierExpression enumeratorVar = m2.Get<IdentifierExpression>("enumerator").Single(); + IdentifierExpression itemVar = m2.Get<IdentifierExpression>("itemVar").Single(); + WhileStatement loop = m2.Get<WhileStatement>("loop").Single(); + + // verify that the getEnumeratorPattern assigns to the same variable as the nonGenericForeachPattern is reading from + if (!enumeratorVar.IsMatch(m1.Get("left").Single())) + return null; + + VariableDeclarationStatement enumeratorVarDecl = FindVariableDeclaration(loop, enumeratorVar.Identifier); + if (enumeratorVarDecl == null || !(enumeratorVarDecl.Parent is BlockStatement)) + return null; + + // Find the declaration of the item variable: + // Because we look only outside the loop, we won't make the mistake of moving a captured variable across the loop boundary + VariableDeclarationStatement itemVarDecl = FindVariableDeclaration(loop, itemVar.Identifier); + if (itemVarDecl == null || !(itemVarDecl.Parent is BlockStatement)) + return null; + + // Now verify that we can move the variable declaration in front of the loop: + Statement declarationPoint; + CanMoveVariableDeclarationIntoStatement(itemVarDecl, loop, out declarationPoint); + // We ignore the return value because we don't care whether we can move the variable into the loop + // (that is possible only with non-captured variables). + // We just care that we can move it in front of the loop: + if (declarationPoint != loop) + return null; + + ForeachStatement foreachStatement = new ForeachStatement + { + VariableType = itemVarDecl.Type.Clone(), + VariableName = itemVar.Identifier, + }.WithAnnotation(itemVarDecl.Variables.Single().Annotation<ILVariable>()); + BlockStatement body = new BlockStatement(); + foreachStatement.EmbeddedStatement = body; + ((BlockStatement)node.Parent).Statements.InsertBefore(node, foreachStatement); + + body.Add(node.Detach()); + body.Add((Statement)tryCatch.Detach()); + + // Now that we moved the whole try-catch into the foreach loop; verify that we can + // move the enumerator into the foreach loop: + CanMoveVariableDeclarationIntoStatement(enumeratorVarDecl, foreachStatement, out declarationPoint); + if (declarationPoint != foreachStatement) { + // oops, the enumerator variable can't be moved into the foreach loop + // Undo our AST changes: + ((BlockStatement)foreachStatement.Parent).Statements.InsertBefore(foreachStatement, node.Detach()); + foreachStatement.ReplaceWith(tryCatch); + return null; + } + + // Now create the correct body for the foreach statement: + foreachStatement.InExpression = m1.Get<Expression>("collection").Single().Detach(); + if (foreachStatement.InExpression is BaseReferenceExpression) { + foreachStatement.InExpression = new ThisReferenceExpression().CopyAnnotationsFrom(foreachStatement.InExpression); + } + body.Statements.Clear(); + body.Statements.AddRange(m2.Get<Statement>("stmt").Select(stmt => stmt.Detach())); + + return foreachStatement; + } + #endregion + + #region for + static readonly WhileStatement forPattern = new WhileStatement { + Condition = new BinaryOperatorExpression { + Left = new NamedNode("ident", new IdentifierExpression(Pattern.AnyString)), + Operator = BinaryOperatorType.Any, + Right = new AnyNode("endExpr") + }, + EmbeddedStatement = new BlockStatement { + Statements = { + new Repeat(new AnyNode("statement")), + new NamedNode( + "increment", + new ExpressionStatement( + new AssignmentExpression { + Left = new Backreference("ident"), + Operator = AssignmentOperatorType.Any, + Right = new AnyNode() + })) + } + }}; + + public ForStatement TransformFor(ExpressionStatement node) + { + Match m1 = variableAssignPattern.Match(node); + if (!m1.Success) return null; + AstNode next = node.NextSibling; + Match m2 = forPattern.Match(next); + if (!m2.Success) return null; + // ensure the variable in the for pattern is the same as in the declaration + if (m1.Get<IdentifierExpression>("variable").Single().Identifier != m2.Get<IdentifierExpression>("ident").Single().Identifier) + return null; + WhileStatement loop = (WhileStatement)next; + node.Remove(); + BlockStatement newBody = new BlockStatement(); + foreach (Statement stmt in m2.Get<Statement>("statement")) + newBody.Add(stmt.Detach()); + ForStatement forStatement = new ForStatement(); + forStatement.Initializers.Add(node); + forStatement.Condition = loop.Condition.Detach(); + forStatement.Iterators.Add(m2.Get<Statement>("increment").Single().Detach()); + forStatement.EmbeddedStatement = newBody; + loop.ReplaceWith(forStatement); + return forStatement; + } + #endregion + + #region doWhile + static readonly WhileStatement doWhilePattern = new WhileStatement { + Condition = new PrimitiveExpression(true), + EmbeddedStatement = new BlockStatement { + Statements = { + new Repeat(new AnyNode("statement")), + new IfElseStatement { + Condition = new AnyNode("condition"), + TrueStatement = new BlockStatement { new BreakStatement() } + } + } + }}; + + public DoWhileStatement TransformDoWhile(WhileStatement whileLoop) + { + Match m = doWhilePattern.Match(whileLoop); + if (m.Success) { + DoWhileStatement doLoop = new DoWhileStatement(); + doLoop.Condition = new UnaryOperatorExpression(UnaryOperatorType.Not, m.Get<Expression>("condition").Single().Detach()); + doLoop.Condition.AcceptVisitor(new PushNegation(), null); + BlockStatement block = (BlockStatement)whileLoop.EmbeddedStatement; + block.Statements.Last().Remove(); // remove if statement + doLoop.EmbeddedStatement = block.Detach(); + whileLoop.ReplaceWith(doLoop); + + // we may have to extract variable definitions out of the loop if they were used in the condition: + foreach (var varDecl in block.Statements.OfType<VariableDeclarationStatement>()) { + VariableInitializer v = varDecl.Variables.Single(); + if (doLoop.Condition.DescendantsAndSelf.OfType<IdentifierExpression>().Any(i => i.Identifier == v.Name)) { + AssignmentExpression assign = new AssignmentExpression(new IdentifierExpression(v.Name), v.Initializer.Detach()); + // move annotations from v to assign: + assign.CopyAnnotationsFrom(v); + v.RemoveAnnotations<object>(); + // remove varDecl with assignment; and move annotations from varDecl to the ExpressionStatement: + varDecl.ReplaceWith(new ExpressionStatement(assign).CopyAnnotationsFrom(varDecl)); + varDecl.RemoveAnnotations<object>(); + + // insert the varDecl above the do-while loop: + doLoop.Parent.InsertChildBefore(doLoop, varDecl, BlockStatement.StatementRole); + } + } + return doLoop; + } + return null; + } + #endregion + + #region lock + static readonly AstNode lockFlagInitPattern = new ExpressionStatement( + new AssignmentExpression( + new NamedNode("variable", new IdentifierExpression(Pattern.AnyString)), + new PrimitiveExpression(false) + )); + + static readonly AstNode lockTryCatchPattern = new TryCatchStatement { + TryBlock = new BlockStatement { + new OptionalNode(new VariableDeclarationStatement()).ToStatement(), + new TypePattern(typeof(System.Threading.Monitor)).ToType().Invoke( + "Enter", new AnyNode("enter"), + new DirectionExpression { + FieldDirection = FieldDirection.Ref, + Expression = new NamedNode("flag", new IdentifierExpression(Pattern.AnyString)) + }), + new Repeat(new AnyNode()).ToStatement() + }, + FinallyBlock = new BlockStatement { + new IfElseStatement { + Condition = new Backreference("flag"), + TrueStatement = new BlockStatement { + new TypePattern(typeof(System.Threading.Monitor)).ToType().Invoke("Exit", new AnyNode("exit")) + } + } + }}; + + static readonly AstNode oldMonitorCallPattern = new ExpressionStatement( + new TypePattern(typeof(System.Threading.Monitor)).ToType().Invoke("Enter", new AnyNode("enter")) + ); + + static readonly AstNode oldLockTryCatchPattern = new TryCatchStatement + { + TryBlock = new BlockStatement { + new Repeat(new AnyNode()).ToStatement() + }, + FinallyBlock = new BlockStatement { + new TypePattern(typeof(System.Threading.Monitor)).ToType().Invoke("Exit", new AnyNode("exit")) + } + }; + + bool AnalyzeLockV2(ExpressionStatement node, out Expression enter, out Expression exit) + { + enter = null; + exit = null; + Match m1 = oldMonitorCallPattern.Match(node); + if (!m1.Success) return false; + Match m2 = oldLockTryCatchPattern.Match(node.NextSibling); + if (!m2.Success) return false; + enter = m1.Get<Expression>("enter").Single(); + exit = m2.Get<Expression>("exit").Single(); + return true; + } + + bool AnalyzeLockV4(ExpressionStatement node, out Expression enter, out Expression exit) + { + enter = null; + exit = null; + Match m1 = lockFlagInitPattern.Match(node); + if (!m1.Success) return false; + Match m2 = lockTryCatchPattern.Match(node.NextSibling); + if (!m2.Success) return false; + enter = m2.Get<Expression>("enter").Single(); + exit = m2.Get<Expression>("exit").Single(); + return m1.Get<IdentifierExpression>("variable").Single().Identifier == m2.Get<IdentifierExpression>("flag").Single().Identifier; + } + + public LockStatement TransformLock(ExpressionStatement node) + { + Expression enter, exit; + bool isV2 = AnalyzeLockV2(node, out enter, out exit); + if (isV2 || AnalyzeLockV4(node, out enter, out exit)) { + AstNode tryCatch = node.NextSibling; + if (!exit.IsMatch(enter)) { + // If exit and enter are not the same, then enter must be "exit = ..." + AssignmentExpression assign = enter as AssignmentExpression; + if (assign == null) + return null; + if (!exit.IsMatch(assign.Left)) + return null; + enter = assign.Right; + // TODO: verify that 'obj' variable can be removed + } + // TODO: verify that 'flag' variable can be removed + // transform the code into a lock statement: + LockStatement l = new LockStatement(); + l.Expression = enter.Detach(); + l.EmbeddedStatement = ((TryCatchStatement)tryCatch).TryBlock.Detach(); + if (!isV2) // Remove 'Enter()' call + ((BlockStatement)l.EmbeddedStatement).Statements.First().Remove(); + tryCatch.ReplaceWith(l); + node.Remove(); // remove flag variable + return l; + } + return null; + } + #endregion + + #region switch on strings + static readonly IfElseStatement switchOnStringPattern = new IfElseStatement { + Condition = new BinaryOperatorExpression { + Left = new AnyNode("switchExpr"), + Operator = BinaryOperatorType.InEquality, + Right = new NullReferenceExpression() + }, + TrueStatement = new BlockStatement { + new IfElseStatement { + Condition = new BinaryOperatorExpression { + Left = new AnyNode("cachedDict"), + Operator = BinaryOperatorType.Equality, + Right = new NullReferenceExpression() + }, + TrueStatement = new AnyNode("dictCreation") + }, + new IfElseStatement { + Condition = new Backreference("cachedDict").ToExpression().Invoke( + "TryGetValue", + new NamedNode("switchVar", new IdentifierExpression(Pattern.AnyString)), + new DirectionExpression { + FieldDirection = FieldDirection.Out, + Expression = new IdentifierExpression(Pattern.AnyString).WithName("intVar") + }), + TrueStatement = new BlockStatement { + Statements = { + new NamedNode( + "switch", new SwitchStatement { + Expression = new IdentifierExpressionBackreference("intVar"), + SwitchSections = { new Repeat(new AnyNode()) } + }) + } + } + }, + new Repeat(new AnyNode("nonNullDefaultStmt")).ToStatement() + }, + FalseStatement = new OptionalNode("nullStmt", new BlockStatement { Statements = { new Repeat(new AnyNode()) } }) + }; + + public SwitchStatement TransformSwitchOnString(IfElseStatement node) + { + Match m = switchOnStringPattern.Match(node); + if (!m.Success) + return null; + // switchVar must be the same as switchExpr; or switchExpr must be an assignment and switchVar the left side of that assignment + if (!m.Get("switchVar").Single().IsMatch(m.Get("switchExpr").Single())) { + AssignmentExpression assign = m.Get("switchExpr").Single() as AssignmentExpression; + if (!(assign != null && m.Get("switchVar").Single().IsMatch(assign.Left))) + return null; + } + FieldReference cachedDictField = m.Get<AstNode>("cachedDict").Single().Annotation<FieldReference>(); + if (cachedDictField == null) + return null; + List<Statement> dictCreation = m.Get<BlockStatement>("dictCreation").Single().Statements.ToList(); + List<KeyValuePair<string, int>> dict = BuildDictionary(dictCreation); + SwitchStatement sw = m.Get<SwitchStatement>("switch").Single(); + sw.Expression = m.Get<Expression>("switchExpr").Single().Detach(); + foreach (SwitchSection section in sw.SwitchSections) { + List<CaseLabel> labels = section.CaseLabels.ToList(); + section.CaseLabels.Clear(); + foreach (CaseLabel label in labels) { + PrimitiveExpression expr = label.Expression as PrimitiveExpression; + if (expr == null || !(expr.Value is int)) + continue; + int val = (int)expr.Value; + foreach (var pair in dict) { + if (pair.Value == val) + section.CaseLabels.Add(new CaseLabel { Expression = new PrimitiveExpression(pair.Key) }); + } + } + } + if (m.Has("nullStmt")) { + SwitchSection section = new SwitchSection(); + section.CaseLabels.Add(new CaseLabel { Expression = new NullReferenceExpression() }); + BlockStatement block = m.Get<BlockStatement>("nullStmt").Single(); + block.Statements.Add(new BreakStatement()); + section.Statements.Add(block.Detach()); + sw.SwitchSections.Add(section); + } else if (m.Has("nonNullDefaultStmt")) { + sw.SwitchSections.Add( + new SwitchSection { + CaseLabels = { new CaseLabel { Expression = new NullReferenceExpression() } }, + Statements = { new BlockStatement { new BreakStatement() } } + }); + } + if (m.Has("nonNullDefaultStmt")) { + SwitchSection section = new SwitchSection(); + section.CaseLabels.Add(new CaseLabel()); + BlockStatement block = new BlockStatement(); + block.Statements.AddRange(m.Get<Statement>("nonNullDefaultStmt").Select(s => s.Detach())); + block.Add(new BreakStatement()); + section.Statements.Add(block); + sw.SwitchSections.Add(section); + } + node.ReplaceWith(sw); + return sw; + } + + List<KeyValuePair<string, int>> BuildDictionary(List<Statement> dictCreation) + { + if (context.Settings.ObjectOrCollectionInitializers && dictCreation.Count == 1) + return BuildDictionaryFromInitializer(dictCreation[0]); + + return BuildDictionaryFromAddMethodCalls(dictCreation); + } + + static readonly Statement assignInitializedDictionary = new ExpressionStatement { + Expression = new AssignmentExpression { + Left = new AnyNode().ToExpression(), + Right = new ObjectCreateExpression { + Type = new AnyNode(), + Arguments = { new Repeat(new AnyNode()) }, + Initializer = new ArrayInitializerExpression { + Elements = { new Repeat(new AnyNode("dictJumpTable")) } + } + }, + }, + }; + + List<KeyValuePair<string, int>> BuildDictionaryFromInitializer(Statement statement) + { + List<KeyValuePair<string, int>> dict = new List<KeyValuePair<string, int>>(); + Match m = assignInitializedDictionary.Match(statement); + if (!m.Success) + return dict; + + foreach (ArrayInitializerExpression initializer in m.Get<ArrayInitializerExpression>("dictJumpTable")) { + KeyValuePair<string, int> pair; + if (TryGetPairFrom(initializer.Elements, out pair)) + dict.Add(pair); + } + + return dict; + } + + static List<KeyValuePair<string, int>> BuildDictionaryFromAddMethodCalls(List<Statement> dictCreation) + { + List<KeyValuePair<string, int>> dict = new List<KeyValuePair<string, int>>(); + for (int i = 0; i < dictCreation.Count; i++) { + ExpressionStatement es = dictCreation[i] as ExpressionStatement; + if (es == null) + continue; + InvocationExpression ie = es.Expression as InvocationExpression; + if (ie == null) + continue; + + KeyValuePair<string, int> pair; + if (TryGetPairFrom(ie.Arguments, out pair)) + dict.Add(pair); + } + return dict; + } + + static bool TryGetPairFrom(AstNodeCollection<Expression> expressions, out KeyValuePair<string, int> pair) + { + PrimitiveExpression arg1 = expressions.ElementAtOrDefault(0) as PrimitiveExpression; + PrimitiveExpression arg2 = expressions.ElementAtOrDefault(1) as PrimitiveExpression; + if (arg1 != null && arg2 != null && arg1.Value is string && arg2.Value is int) { + pair = new KeyValuePair<string, int>((string)arg1.Value, (int)arg2.Value); + return true; + } + + pair = default(KeyValuePair<string, int>); + return false; + } + + #endregion + + #region Automatic Properties + static readonly PropertyDeclaration automaticPropertyPattern = new PropertyDeclaration { + Attributes = { new Repeat(new AnyNode()) }, + Modifiers = Modifiers.Any, + ReturnType = new AnyNode(), + PrivateImplementationType = new OptionalNode(new AnyNode()), + Name = Pattern.AnyString, + Getter = new Accessor { + Attributes = { new Repeat(new AnyNode()) }, + Modifiers = Modifiers.Any, + Body = new BlockStatement { + new ReturnStatement { + Expression = new AnyNode("fieldReference") + } + } + }, + Setter = new Accessor { + Attributes = { new Repeat(new AnyNode()) }, + Modifiers = Modifiers.Any, + Body = new BlockStatement { + new AssignmentExpression { + Left = new Backreference("fieldReference"), + Right = new IdentifierExpression("value") + } + }}}; + + PropertyDeclaration TransformAutomaticProperties(PropertyDeclaration property) + { + PropertyDefinition cecilProperty = property.Annotation<PropertyDefinition>(); + if (cecilProperty == null || cecilProperty.GetMethod == null || cecilProperty.SetMethod == null) + return null; + if (!(cecilProperty.GetMethod.IsCompilerGenerated() && cecilProperty.SetMethod.IsCompilerGenerated())) + return null; + Match m = automaticPropertyPattern.Match(property); + if (m.Success) { + FieldDefinition field = m.Get<AstNode>("fieldReference").Single().Annotation<FieldReference>().ResolveWithinSameModule(); + if (field.IsCompilerGenerated() && field.DeclaringType == cecilProperty.DeclaringType) { + RemoveCompilerGeneratedAttribute(property.Getter.Attributes); + RemoveCompilerGeneratedAttribute(property.Setter.Attributes); + property.Getter.Body = null; + property.Setter.Body = null; + } + } + // Since the event instance is not changed, we can continue in the visitor as usual, so return null + return null; + } + + void RemoveCompilerGeneratedAttribute(AstNodeCollection<AttributeSection> attributeSections) + { + foreach (AttributeSection section in attributeSections) { + foreach (var attr in section.Attributes) { + TypeReference tr = attr.Type.Annotation<TypeReference>(); + if (tr != null && tr.Namespace == "System.Runtime.CompilerServices" && tr.Name == "CompilerGeneratedAttribute") { + attr.Remove(); + } + } + if (section.Attributes.Count == 0) + section.Remove(); + } + } + #endregion + + #region Automatic Events + static readonly Accessor automaticEventPatternV4 = new Accessor { + Attributes = { new Repeat(new AnyNode()) }, + Body = new BlockStatement { + new VariableDeclarationStatement { Type = new AnyNode("type"), Variables = { new AnyNode() } }, + new VariableDeclarationStatement { Type = new Backreference("type"), Variables = { new AnyNode() } }, + new VariableDeclarationStatement { Type = new Backreference("type"), Variables = { new AnyNode() } }, + new AssignmentExpression { + Left = new NamedNode("var1", new IdentifierExpression(Pattern.AnyString)), + Operator = AssignmentOperatorType.Assign, + Right = new NamedNode( + "field", + new MemberReferenceExpression { + Target = new Choice { new ThisReferenceExpression(), new TypeReferenceExpression { Type = new AnyNode() } }, + MemberName = Pattern.AnyString + }) + }, + new DoWhileStatement { + EmbeddedStatement = new BlockStatement { + new AssignmentExpression(new NamedNode("var2", new IdentifierExpression(Pattern.AnyString)), new IdentifierExpressionBackreference("var1")), + new AssignmentExpression { + Left = new NamedNode("var3", new IdentifierExpression(Pattern.AnyString)), + Operator = AssignmentOperatorType.Assign, + Right = new AnyNode("delegateCombine").ToExpression().Invoke( + new IdentifierExpressionBackreference("var2"), + new IdentifierExpression("value") + ).CastTo(new Backreference("type")) + }, + new AssignmentExpression { + Left = new IdentifierExpressionBackreference("var1"), + Right = new TypePattern(typeof(System.Threading.Interlocked)).ToType().Invoke( + "CompareExchange", + new AstType[] { new Backreference("type") }, // type argument + new Expression[] { // arguments + new DirectionExpression { FieldDirection = FieldDirection.Ref, Expression = new Backreference("field") }, + new IdentifierExpressionBackreference("var3"), + new IdentifierExpressionBackreference("var2") + } + )} + }, + Condition = new BinaryOperatorExpression { + Left = new IdentifierExpressionBackreference("var1"), + Operator = BinaryOperatorType.InEquality, + Right = new IdentifierExpressionBackreference("var2") + }} + }}; + + bool CheckAutomaticEventV4Match(Match m, CustomEventDeclaration ev, bool isAddAccessor) + { + if (!m.Success) + return false; + if (m.Get<MemberReferenceExpression>("field").Single().MemberName != ev.Name) + return false; // field name must match event name + if (!ev.ReturnType.IsMatch(m.Get("type").Single())) + return false; // variable types must match event type + var combineMethod = m.Get<AstNode>("delegateCombine").Single().Parent.Annotation<MethodReference>(); + if (combineMethod == null || combineMethod.Name != (isAddAccessor ? "Combine" : "Remove")) + return false; + return combineMethod.DeclaringType.FullName == "System.Delegate"; + } + + EventDeclaration TransformAutomaticEvents(CustomEventDeclaration ev) + { + Match m1 = automaticEventPatternV4.Match(ev.AddAccessor); + if (!CheckAutomaticEventV4Match(m1, ev, true)) + return null; + Match m2 = automaticEventPatternV4.Match(ev.RemoveAccessor); + if (!CheckAutomaticEventV4Match(m2, ev, false)) + return null; + EventDeclaration ed = new EventDeclaration(); + ev.Attributes.MoveTo(ed.Attributes); + foreach (var attr in ev.AddAccessor.Attributes) { + attr.AttributeTarget = "method"; + ed.Attributes.Add(attr.Detach()); + } + ed.ReturnType = ev.ReturnType.Detach(); + ed.Modifiers = ev.Modifiers; + ed.Variables.Add(new VariableInitializer(ev.Name)); + ed.CopyAnnotationsFrom(ev); + + EventDefinition eventDef = ev.Annotation<EventDefinition>(); + if (eventDef != null) { + FieldDefinition field = eventDef.DeclaringType.Fields.FirstOrDefault(f => f.Name == ev.Name); + if (field != null) { + ed.AddAnnotation(field); + AstBuilder.ConvertAttributes(ed, field, "field"); + } + } + + ev.ReplaceWith(ed); + return ed; + } + #endregion + + #region Destructor + static readonly MethodDeclaration destructorPattern = new MethodDeclaration { + Attributes = { new Repeat(new AnyNode()) }, + Modifiers = Modifiers.Any, + ReturnType = new PrimitiveType("void"), + Name = "Finalize", + Body = new BlockStatement { + new TryCatchStatement { + TryBlock = new AnyNode("body"), + FinallyBlock = new BlockStatement { + new BaseReferenceExpression().Invoke("Finalize") + } + } + } + }; + + DestructorDeclaration TransformDestructor(MethodDeclaration methodDef) + { + Match m = destructorPattern.Match(methodDef); + if (m.Success) { + DestructorDeclaration dd = new DestructorDeclaration(); + methodDef.Attributes.MoveTo(dd.Attributes); + dd.Modifiers = methodDef.Modifiers & ~(Modifiers.Protected | Modifiers.Override); + dd.Body = m.Get<BlockStatement>("body").Single().Detach(); + dd.Name = AstBuilder.CleanName(context.CurrentType.Name); + methodDef.ReplaceWith(dd); + return dd; + } + return null; + } + #endregion + + #region Try-Catch-Finally + static readonly TryCatchStatement tryCatchFinallyPattern = new TryCatchStatement { + TryBlock = new BlockStatement { + new TryCatchStatement { + TryBlock = new AnyNode(), + CatchClauses = { new Repeat(new AnyNode()) } + } + }, + FinallyBlock = new AnyNode() + }; + + /// <summary> + /// Simplify nested 'try { try {} catch {} } finally {}'. + /// This transformation must run after the using/lock tranformations. + /// </summary> + TryCatchStatement TransformTryCatchFinally(TryCatchStatement tryFinally) + { + if (tryCatchFinallyPattern.IsMatch(tryFinally)) { + TryCatchStatement tryCatch = (TryCatchStatement)tryFinally.TryBlock.Statements.Single(); + tryFinally.TryBlock = tryCatch.TryBlock.Detach(); + tryCatch.CatchClauses.MoveTo(tryFinally.CatchClauses); + } + // Since the tryFinally instance is not changed, we can continue in the visitor as usual, so return null + return null; + } + #endregion + + #region Simplify cascading if-else-if statements + static readonly IfElseStatement cascadingIfElsePattern = new IfElseStatement + { + Condition = new AnyNode(), + TrueStatement = new AnyNode(), + FalseStatement = new BlockStatement { + Statements = { + new NamedNode( + "nestedIfStatement", + new IfElseStatement { + Condition = new AnyNode(), + TrueStatement = new AnyNode(), + FalseStatement = new OptionalNode(new AnyNode()) + } + ) + } + } + }; + + AstNode SimplifyCascadingIfElseStatements(IfElseStatement node) + { + Match m = cascadingIfElsePattern.Match(node); + if (m.Success) { + IfElseStatement elseIf = m.Get<IfElseStatement>("nestedIfStatement").Single(); + node.FalseStatement = elseIf.Detach(); + } + + return null; + } + #endregion + } +} diff --git a/ICSharpCode.Decompiler/Ast/Transforms/PushNegation.cs b/ICSharpCode.Decompiler/Ast/Transforms/PushNegation.cs new file mode 100644 index 00000000..193c5e69 --- /dev/null +++ b/ICSharpCode.Decompiler/Ast/Transforms/PushNegation.cs @@ -0,0 +1,164 @@ +// Copyright (c) 2011 AlphaSierraPapa for the SharpDevelop Team +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this +// software and associated documentation files (the "Software"), to deal in the Software +// without restriction, including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons +// to whom the Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR +// PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE +// FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +// OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +using System; +using System.Collections.Generic; +using System.Linq; +using ICSharpCode.NRefactory.CSharp; +using ICSharpCode.NRefactory.PatternMatching; + +namespace ICSharpCode.Decompiler.Ast.Transforms +{ + public class PushNegation: DepthFirstAstVisitor<object, object>, IAstTransform + { + sealed class LiftedOperator { } + /// <summary> + /// Annotation for lifted operators that cannot be transformed by PushNegation + /// </summary> + public static readonly object LiftedOperatorAnnotation = new LiftedOperator(); + + public override object VisitUnaryOperatorExpression(UnaryOperatorExpression unary, object data) + { + // lifted operators can't be transformed + if (unary.Annotation<LiftedOperator>() != null || unary.Expression.Annotation<LiftedOperator>() != null) + return base.VisitUnaryOperatorExpression(unary, data); + + // Remove double negation + // !!a + if (unary.Operator == UnaryOperatorType.Not && + unary.Expression is UnaryOperatorExpression && + (unary.Expression as UnaryOperatorExpression).Operator == UnaryOperatorType.Not) + { + AstNode newNode = (unary.Expression as UnaryOperatorExpression).Expression; + unary.ReplaceWith(newNode); + return newNode.AcceptVisitor(this, data); + } + + // Push through binary operation + // !((a) op (b)) + BinaryOperatorExpression binaryOp = unary.Expression as BinaryOperatorExpression; + if (unary.Operator == UnaryOperatorType.Not && binaryOp != null) { + bool successful = true; + switch (binaryOp.Operator) { + case BinaryOperatorType.Equality: + binaryOp.Operator = BinaryOperatorType.InEquality; + break; + case BinaryOperatorType.InEquality: + binaryOp.Operator = BinaryOperatorType.Equality; + break; + case BinaryOperatorType.GreaterThan: // TODO: these are invalid for floats (stupid NaN) + binaryOp.Operator = BinaryOperatorType.LessThanOrEqual; + break; + case BinaryOperatorType.GreaterThanOrEqual: + binaryOp.Operator = BinaryOperatorType.LessThan; + break; + case BinaryOperatorType.LessThanOrEqual: + binaryOp.Operator = BinaryOperatorType.GreaterThan; + break; + case BinaryOperatorType.LessThan: + binaryOp.Operator = BinaryOperatorType.GreaterThanOrEqual; + break; + default: + successful = false; + break; + } + if (successful) { + unary.ReplaceWith(binaryOp); + return binaryOp.AcceptVisitor(this, data); + } + + successful = true; + switch (binaryOp.Operator) { + case BinaryOperatorType.ConditionalAnd: + binaryOp.Operator = BinaryOperatorType.ConditionalOr; + break; + case BinaryOperatorType.ConditionalOr: + binaryOp.Operator = BinaryOperatorType.ConditionalAnd; + break; + default: + successful = false; + break; + } + if (successful) { + binaryOp.Left.ReplaceWith(e => new UnaryOperatorExpression(UnaryOperatorType.Not, e)); + binaryOp.Right.ReplaceWith(e => new UnaryOperatorExpression(UnaryOperatorType.Not, e)); + unary.ReplaceWith(binaryOp); + return binaryOp.AcceptVisitor(this, data); + } + } + return base.VisitUnaryOperatorExpression(unary, data); + } + + readonly static AstNode asCastIsNullPattern = new BinaryOperatorExpression( + new AnyNode("expr").ToExpression().CastAs(new AnyNode("type")), + BinaryOperatorType.Equality, + new NullReferenceExpression() + ); + + readonly static AstNode asCastIsNotNullPattern = new BinaryOperatorExpression( + new AnyNode("expr").ToExpression().CastAs(new AnyNode("type")), + BinaryOperatorType.InEquality, + new NullReferenceExpression() + ); + + public override object VisitBinaryOperatorExpression(BinaryOperatorExpression binaryOperatorExpression, object data) + { + // lifted operators can't be transformed + if (binaryOperatorExpression.Annotation<LiftedOperator>() != null) + return base.VisitBinaryOperatorExpression(binaryOperatorExpression, data); + + BinaryOperatorType op = binaryOperatorExpression.Operator; + bool? rightOperand = null; + if (binaryOperatorExpression.Right is PrimitiveExpression) + rightOperand = ((PrimitiveExpression)binaryOperatorExpression.Right).Value as bool?; + if (op == BinaryOperatorType.Equality && rightOperand == true || op == BinaryOperatorType.InEquality && rightOperand == false) { + // 'b == true' or 'b != false' is useless + binaryOperatorExpression.Left.AcceptVisitor(this, data); + binaryOperatorExpression.ReplaceWith(binaryOperatorExpression.Left); + return null; + } else if (op == BinaryOperatorType.Equality && rightOperand == false || op == BinaryOperatorType.InEquality && rightOperand == true) { + // 'b == false' or 'b != true' is a negation: + Expression left = binaryOperatorExpression.Left; + left.Remove(); + UnaryOperatorExpression uoe = new UnaryOperatorExpression(UnaryOperatorType.Not, left); + binaryOperatorExpression.ReplaceWith(uoe); + return uoe.AcceptVisitor(this, data); + } else { + bool negate = false; + Match m = asCastIsNotNullPattern.Match(binaryOperatorExpression); + if (!m.Success) { + m = asCastIsNullPattern.Match(binaryOperatorExpression); + negate = true; + } + if (m.Success) { + Expression expr = m.Get<Expression>("expr").Single().Detach().IsType(m.Get<AstType>("type").Single().Detach()); + if (negate) + expr = new UnaryOperatorExpression(UnaryOperatorType.Not, expr); + binaryOperatorExpression.ReplaceWith(expr); + return expr.AcceptVisitor(this, data); + } else { + return base.VisitBinaryOperatorExpression(binaryOperatorExpression, data); + } + } + } + void IAstTransform.Run(AstNode node) + { + node.AcceptVisitor(this, null); + } + } +} diff --git a/ICSharpCode.Decompiler/Ast/Transforms/ReplaceMethodCallsWithOperators.cs b/ICSharpCode.Decompiler/Ast/Transforms/ReplaceMethodCallsWithOperators.cs new file mode 100644 index 00000000..6a3f8f97 --- /dev/null +++ b/ICSharpCode.Decompiler/Ast/Transforms/ReplaceMethodCallsWithOperators.cs @@ -0,0 +1,356 @@ +// Copyright (c) 2011 AlphaSierraPapa for the SharpDevelop Team +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this +// software and associated documentation files (the "Software"), to deal in the Software +// without restriction, including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons +// to whom the Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR +// PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE +// FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +// OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using ICSharpCode.NRefactory.PatternMatching; +using Mono.Cecil; +using Ast = ICSharpCode.NRefactory.CSharp; +using ICSharpCode.NRefactory.CSharp; + +namespace ICSharpCode.Decompiler.Ast.Transforms +{ + /// <summary> + /// Replaces method calls with the appropriate operator expressions. + /// Also simplifies "x = x op y" into "x op= y" where possible. + /// </summary> + public class ReplaceMethodCallsWithOperators : DepthFirstAstVisitor<object, object>, IAstTransform + { + static readonly MemberReferenceExpression typeHandleOnTypeOfPattern = new MemberReferenceExpression { + Target = new Choice { + new TypeOfExpression(new AnyNode()), + new UndocumentedExpression { UndocumentedExpressionType = UndocumentedExpressionType.RefType, Arguments = { new AnyNode() } } + }, + MemberName = "TypeHandle" + }; + + DecompilerContext context; + + public ReplaceMethodCallsWithOperators(DecompilerContext context) + { + this.context = context; + } + + public override object VisitInvocationExpression(InvocationExpression invocationExpression, object data) + { + base.VisitInvocationExpression(invocationExpression, data); + ProcessInvocationExpression(invocationExpression); + return null; + } + + internal static void ProcessInvocationExpression(InvocationExpression invocationExpression) + { + MethodReference methodRef = invocationExpression.Annotation<MethodReference>(); + if (methodRef == null) + return; + var arguments = invocationExpression.Arguments.ToArray(); + + // Reduce "String.Concat(a, b)" to "a + b" + if (methodRef.Name == "Concat" && methodRef.DeclaringType.FullName == "System.String" && arguments.Length >= 2) + { + invocationExpression.Arguments.Clear(); // detach arguments from invocationExpression + Expression expr = arguments[0]; + for (int i = 1; i < arguments.Length; i++) { + expr = new BinaryOperatorExpression(expr, BinaryOperatorType.Add, arguments[i]); + } + invocationExpression.ReplaceWith(expr); + return; + } + + switch (methodRef.FullName) { + case "System.Type System.Type::GetTypeFromHandle(System.RuntimeTypeHandle)": + if (arguments.Length == 1) { + if (typeHandleOnTypeOfPattern.IsMatch(arguments[0])) { + invocationExpression.ReplaceWith(((MemberReferenceExpression)arguments[0]).Target); + return; + } + } + break; + case "System.Reflection.FieldInfo System.Reflection.FieldInfo::GetFieldFromHandle(System.RuntimeFieldHandle)": + if (arguments.Length == 1) { + MemberReferenceExpression mre = arguments[0] as MemberReferenceExpression; + if (mre != null && mre.MemberName == "FieldHandle" && mre.Target.Annotation<LdTokenAnnotation>() != null) { + invocationExpression.ReplaceWith(mre.Target); + return; + } + } + break; + case "System.Reflection.FieldInfo System.Reflection.FieldInfo::GetFieldFromHandle(System.RuntimeFieldHandle,System.RuntimeTypeHandle)": + if (arguments.Length == 2) { + MemberReferenceExpression mre1 = arguments[0] as MemberReferenceExpression; + MemberReferenceExpression mre2 = arguments[1] as MemberReferenceExpression; + if (mre1 != null && mre1.MemberName == "FieldHandle" && mre1.Target.Annotation<LdTokenAnnotation>() != null) { + if (mre2 != null && mre2.MemberName == "TypeHandle" && mre2.Target is TypeOfExpression) { + Expression oldArg = ((InvocationExpression)mre1.Target).Arguments.Single(); + FieldReference field = oldArg.Annotation<FieldReference>(); + if (field != null) { + AstType declaringType = ((TypeOfExpression)mre2.Target).Type.Detach(); + oldArg.ReplaceWith(declaringType.Member(field.Name).WithAnnotation(field)); + invocationExpression.ReplaceWith(mre1.Target); + return; + } + } + } + } + break; + } + + BinaryOperatorType? bop = GetBinaryOperatorTypeFromMetadataName(methodRef.Name); + if (bop != null && arguments.Length == 2) { + invocationExpression.Arguments.Clear(); // detach arguments from invocationExpression + invocationExpression.ReplaceWith( + new BinaryOperatorExpression(arguments[0], bop.Value, arguments[1]).WithAnnotation(methodRef) + ); + return; + } + UnaryOperatorType? uop = GetUnaryOperatorTypeFromMetadataName(methodRef.Name); + if (uop != null && arguments.Length == 1) { + arguments[0].Remove(); // detach argument + invocationExpression.ReplaceWith( + new UnaryOperatorExpression(uop.Value, arguments[0]).WithAnnotation(methodRef) + ); + return; + } + if (methodRef.Name == "op_Explicit" && arguments.Length == 1) { + arguments[0].Remove(); // detach argument + invocationExpression.ReplaceWith( + arguments[0].CastTo(AstBuilder.ConvertType(methodRef.ReturnType, methodRef.MethodReturnType)) + .WithAnnotation(methodRef) + ); + return; + } + if (methodRef.Name == "op_Implicit" && arguments.Length == 1) { + invocationExpression.ReplaceWith(arguments[0]); + return; + } + if (methodRef.Name == "op_True" && arguments.Length == 1 && invocationExpression.Role == Roles.Condition) { + invocationExpression.ReplaceWith(arguments[0]); + return; + } + + return; + } + + static BinaryOperatorType? GetBinaryOperatorTypeFromMetadataName(string name) + { + switch (name) { + case "op_Addition": + return BinaryOperatorType.Add; + case "op_Subtraction": + return BinaryOperatorType.Subtract; + case "op_Multiply": + return BinaryOperatorType.Multiply; + case "op_Division": + return BinaryOperatorType.Divide; + case "op_Modulus": + return BinaryOperatorType.Modulus; + case "op_BitwiseAnd": + return BinaryOperatorType.BitwiseAnd; + case "op_BitwiseOr": + return BinaryOperatorType.BitwiseOr; + case "op_ExclusiveOr": + return BinaryOperatorType.ExclusiveOr; + case "op_LeftShift": + return BinaryOperatorType.ShiftLeft; + case "op_RightShift": + return BinaryOperatorType.ShiftRight; + case "op_Equality": + return BinaryOperatorType.Equality; + case "op_Inequality": + return BinaryOperatorType.InEquality; + case "op_LessThan": + return BinaryOperatorType.LessThan; + case "op_LessThanOrEqual": + return BinaryOperatorType.LessThanOrEqual; + case "op_GreaterThan": + return BinaryOperatorType.GreaterThan; + case "op_GreaterThanOrEqual": + return BinaryOperatorType.GreaterThanOrEqual; + default: + return null; + } + } + + static UnaryOperatorType? GetUnaryOperatorTypeFromMetadataName(string name) + { + switch (name) { + case "op_LogicalNot": + return UnaryOperatorType.Not; + case "op_OnesComplement": + return UnaryOperatorType.BitNot; + case "op_UnaryNegation": + return UnaryOperatorType.Minus; + case "op_UnaryPlus": + return UnaryOperatorType.Plus; + case "op_Increment": + return UnaryOperatorType.Increment; + case "op_Decrement": + return UnaryOperatorType.Decrement; + default: + return null; + } + } + + /// <summary> + /// This annotation is used to convert a compound assignment "a += 2;" or increment operator "a++;" + /// back to the original "a = a + 2;". This is sometimes necessary when the checked/unchecked semantics + /// cannot be guaranteed otherwise (see CheckedUnchecked.ForWithCheckedInitializerAndUncheckedIterator test) + /// </summary> + public class RestoreOriginalAssignOperatorAnnotation + { + readonly BinaryOperatorExpression binaryOperatorExpression; + + public RestoreOriginalAssignOperatorAnnotation(BinaryOperatorExpression binaryOperatorExpression) + { + this.binaryOperatorExpression = binaryOperatorExpression; + } + + public AssignmentExpression Restore(Expression expression) + { + expression.RemoveAnnotations<RestoreOriginalAssignOperatorAnnotation>(); + AssignmentExpression assign = expression as AssignmentExpression; + if (assign == null) { + UnaryOperatorExpression uoe = (UnaryOperatorExpression)expression; + assign = new AssignmentExpression(uoe.Expression.Detach(), new PrimitiveExpression(1)); + } else { + assign.Operator = AssignmentOperatorType.Assign; + } + binaryOperatorExpression.Right = assign.Right.Detach(); + assign.Right = binaryOperatorExpression; + return assign; + } + } + + public override object VisitAssignmentExpression(AssignmentExpression assignment, object data) + { + base.VisitAssignmentExpression(assignment, data); + // Combine "x = x op y" into "x op= y" + BinaryOperatorExpression binary = assignment.Right as BinaryOperatorExpression; + if (binary != null && assignment.Operator == AssignmentOperatorType.Assign) { + if (CanConvertToCompoundAssignment(assignment.Left) && assignment.Left.IsMatch(binary.Left)) { + assignment.Operator = GetAssignmentOperatorForBinaryOperator(binary.Operator); + if (assignment.Operator != AssignmentOperatorType.Assign) { + // If we found a shorter operator, get rid of the BinaryOperatorExpression: + assignment.CopyAnnotationsFrom(binary); + assignment.Right = binary.Right; + assignment.AddAnnotation(new RestoreOriginalAssignOperatorAnnotation(binary)); + } + } + } + if (context.Settings.IntroduceIncrementAndDecrement && (assignment.Operator == AssignmentOperatorType.Add || assignment.Operator == AssignmentOperatorType.Subtract)) { + // detect increment/decrement + if (assignment.Right.IsMatch(new PrimitiveExpression(1))) { + // only if it's not a custom operator + if (assignment.Annotation<MethodReference>() == null) { + UnaryOperatorType type; + // When the parent is an expression statement, pre- or post-increment doesn't matter; + // so we can pick post-increment which is more commonly used (for (int i = 0; i < x; i++)) + if (assignment.Parent is ExpressionStatement) + type = (assignment.Operator == AssignmentOperatorType.Add) ? UnaryOperatorType.PostIncrement : UnaryOperatorType.PostDecrement; + else + type = (assignment.Operator == AssignmentOperatorType.Add) ? UnaryOperatorType.Increment : UnaryOperatorType.Decrement; + assignment.ReplaceWith(new UnaryOperatorExpression(type, assignment.Left.Detach()).CopyAnnotationsFrom(assignment)); + } + } + } + return null; + } + + public static AssignmentOperatorType GetAssignmentOperatorForBinaryOperator(BinaryOperatorType bop) + { + switch (bop) { + case BinaryOperatorType.Add: + return AssignmentOperatorType.Add; + case BinaryOperatorType.Subtract: + return AssignmentOperatorType.Subtract; + case BinaryOperatorType.Multiply: + return AssignmentOperatorType.Multiply; + case BinaryOperatorType.Divide: + return AssignmentOperatorType.Divide; + case BinaryOperatorType.Modulus: + return AssignmentOperatorType.Modulus; + case BinaryOperatorType.ShiftLeft: + return AssignmentOperatorType.ShiftLeft; + case BinaryOperatorType.ShiftRight: + return AssignmentOperatorType.ShiftRight; + case BinaryOperatorType.BitwiseAnd: + return AssignmentOperatorType.BitwiseAnd; + case BinaryOperatorType.BitwiseOr: + return AssignmentOperatorType.BitwiseOr; + case BinaryOperatorType.ExclusiveOr: + return AssignmentOperatorType.ExclusiveOr; + default: + return AssignmentOperatorType.Assign; + } + } + + static bool CanConvertToCompoundAssignment(Expression left) + { + MemberReferenceExpression mre = left as MemberReferenceExpression; + if (mre != null) + return IsWithoutSideEffects(mre.Target); + IndexerExpression ie = left as IndexerExpression; + if (ie != null) + return IsWithoutSideEffects(ie.Target) && ie.Arguments.All(IsWithoutSideEffects); + UnaryOperatorExpression uoe = left as UnaryOperatorExpression; + if (uoe != null && uoe.Operator == UnaryOperatorType.Dereference) + return IsWithoutSideEffects(uoe.Expression); + return IsWithoutSideEffects(left); + } + + static bool IsWithoutSideEffects(Expression left) + { + return left is ThisReferenceExpression || left is IdentifierExpression || left is TypeReferenceExpression || left is BaseReferenceExpression; + } + + static readonly Expression getMethodOrConstructorFromHandlePattern = + new TypePattern(typeof(MethodBase)).ToType().Invoke( + "GetMethodFromHandle", + new NamedNode("ldtokenNode", new LdTokenPattern("method")).ToExpression().Member("MethodHandle"), + new OptionalNode(new TypeOfExpression(new AnyNode("declaringType")).Member("TypeHandle")) + ).CastTo(new Choice { + new TypePattern(typeof(MethodInfo)), + new TypePattern(typeof(ConstructorInfo)) + }); + + public override object VisitCastExpression(CastExpression castExpression, object data) + { + base.VisitCastExpression(castExpression, data); + // Handle methodof + Match m = getMethodOrConstructorFromHandlePattern.Match(castExpression); + if (m.Success) { + MethodReference method = m.Get<AstNode>("method").Single().Annotation<MethodReference>(); + if (m.Has("declaringType")) { + Expression newNode = m.Get<AstType>("declaringType").Single().Detach().Member(method.Name); + newNode = newNode.Invoke(method.Parameters.Select(p => new TypeReferenceExpression(AstBuilder.ConvertType(p.ParameterType, p)))); + newNode.AddAnnotation(method); + m.Get<AstNode>("method").Single().ReplaceWith(newNode); + } + castExpression.ReplaceWith(m.Get<AstNode>("ldtokenNode").Single()); + } + return null; + } + + void IAstTransform.Run(AstNode node) + { + node.AcceptVisitor(this, null); + } + } +} diff --git a/ICSharpCode.Decompiler/Ast/Transforms/TransformationPipeline.cs b/ICSharpCode.Decompiler/Ast/Transforms/TransformationPipeline.cs new file mode 100644 index 00000000..3091a109 --- /dev/null +++ b/ICSharpCode.Decompiler/Ast/Transforms/TransformationPipeline.cs @@ -0,0 +1,65 @@ +// Copyright (c) 2011 AlphaSierraPapa for the SharpDevelop Team +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this +// software and associated documentation files (the "Software"), to deal in the Software +// without restriction, including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons +// to whom the Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR +// PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE +// FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +// OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +using System; +using System.Threading; +using ICSharpCode.NRefactory.CSharp; + +namespace ICSharpCode.Decompiler.Ast.Transforms +{ + public interface IAstTransform + { + void Run(AstNode compilationUnit); + } + + public static class TransformationPipeline + { + public static IAstTransform[] CreatePipeline(DecompilerContext context) + { + return new IAstTransform[] { + new PushNegation(), + new DelegateConstruction(context), + new PatternStatementTransform(context), + new ReplaceMethodCallsWithOperators(context), + new IntroduceUnsafeModifier(), + new AddCheckedBlocks(), + new DeclareVariables(context), // should run after most transforms that modify statements + new ConvertConstructorCallIntoInitializer(), // must run after DeclareVariables + new DecimalConstantTransform(), + new IntroduceUsingDeclarations(context), + new IntroduceExtensionMethods(context), // must run after IntroduceUsingDeclarations + new IntroduceQueryExpressions(context), // must run after IntroduceExtensionMethods + new CombineQueryExpressions(context), + new FlattenSwitchBlocks(), + }; + } + + public static void RunTransformationsUntil(AstNode node, Predicate<IAstTransform> abortCondition, DecompilerContext context) + { + if (node == null) + return; + + foreach (var transform in CreatePipeline(context)) { + context.CancellationToken.ThrowIfCancellationRequested(); + if (abortCondition != null && abortCondition(transform)) + return; + transform.Run(node); + } + } + } +} |