using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;
using CTA.Rules.Actions.ActionHelpers;
using CTA.Rules.Config;
using CTA.Rules.Models;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Editing;
using CSharpExtensions = Microsoft.CodeAnalysis.CSharp.CSharpExtensions;

namespace CTA.Rules.Actions.Csharp
{
    /// <summary>
    /// List of actions that can run on Class Declarations.
    /// Each Get*Action method returns a function that receives a <see cref="SyntaxGenerator"/> and a
    /// <see cref="ClassDeclarationSyntax"/> and returns the (possibly) rewritten class node.
    /// </summary>
    public class ClassActions
    {
        /// <summary>
        /// Builds an action that removes a base type from the class's base list.
        /// </summary>
        /// <param name="baseClass">Base type to remove; compared with each base type's trimmed source text.</param>
        /// <returns>
        /// An action returning the class without that base type. The base list is dropped entirely when no
        /// other base types remain; the class is returned unchanged when nothing matches.
        /// </returns>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetRemoveBaseClassAction(string baseClass)
        {
            ClassDeclarationSyntax RemoveBaseClass(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                if (node.BaseList == null)
                {
                    return node;
                }

                var currentBaseTypes = node.BaseList.Types;
                var remainingBaseTypes = new SeparatedSyntaxList<BaseTypeSyntax>();
                foreach (var baseTypeSyntax in currentBaseTypes)
                {
                    if (!baseTypeSyntax.GetText().ToString().Trim().Equals(baseClass))
                    {
                        // SeparatedSyntaxList is immutable: Add returns a new list, so the result must be kept.
                        remainingBaseTypes = remainingBaseTypes.Add(baseTypeSyntax);
                    }
                }

                if (remainingBaseTypes.Count == currentBaseTypes.Count)
                {
                    // Nothing matched; leave the base list (and its separators/trivia) untouched.
                    return node;
                }

                return remainingBaseTypes.Any()
                    ? node.WithBaseList(node.BaseList.WithTypes(remainingBaseTypes))
                    : node.WithBaseList(null);
            }

            return RemoveBaseClass;
        }

        /// <summary>
        /// Builds an action that adds a base type to the class.
        /// </summary>
        /// <param name="baseClass">Name of the base type to add (parsed with <c>SyntaxFactory.ParseName</c> / <c>ParseTypeName</c>).</param>
        /// <returns>
        /// An action that uses <see cref="SyntaxGenerator.AddBaseType"/> when a generator is supplied, and falls back to
        /// appending a <see cref="SimpleBaseTypeSyntax"/> to the base list otherwise.
        /// </returns>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetAddBaseClassAction(string baseClass)
        {
            ClassDeclarationSyntax AddBaseClass(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                if (syntaxGenerator != null)
                {
                    return (ClassDeclarationSyntax)syntaxGenerator.AddBaseType(node, SyntaxFactory.ParseName(baseClass));
                }

                var baseType = SyntaxFactory.SimpleBaseType(SyntaxFactory.ParseTypeName(baseClass));
                return node.AddBaseListTypes(baseType);
            }

            return AddBaseClass;
        }

        /// <summary>
        /// Builds an action that renames the class.
        /// </summary>
        /// <param name="className">New class name.</param>
        /// <returns>The same action as <see cref="GetRenameClassAction"/>.</returns>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetChangeNameAction(string className)
        {
            // Even though this method is a duplicate of GetRenameClassAction, keep it for backwards compatibility
            return GetRenameClassAction(className);
        }

        /// <summary>
        /// Builds an action that removes every attribute whose name text equals <paramref name="attributeName"/>.
        /// </summary>
        /// <param name="attributeName">Attribute name exactly as written in source (e.g. "Serializable").</param>
        /// <returns>
        /// An action that removes matching attributes. Other attributes sharing the same <c>[ ... ]</c> list are kept;
        /// a list is dropped only when all of its attributes matched.
        /// </returns>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetRemoveAttributeAction(string attributeName)
        {
            ClassDeclarationSyntax RemoveAttribute(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                var updatedLists = new SyntaxList<AttributeListSyntax>();
                foreach (var attributeList in node.AttributeLists)
                {
                    var keptAttributes = attributeList.Attributes
                        .Where(attribute => attribute.Name.ToString() != attributeName)
                        .ToList();

                    if (keptAttributes.Count == attributeList.Attributes.Count)
                    {
                        updatedLists = updatedLists.Add(attributeList);
                    }
                    else if (keptAttributes.Count > 0)
                    {
                        updatedLists = updatedLists.Add(attributeList.WithAttributes(SyntaxFactory.SeparatedList(keptAttributes)));
                    }
                    // Otherwise every attribute in this list matched, so the whole list is dropped.
                }

                // The class's leading trivia (indentation, doc comments) lives on its first token; if the first
                // attribute list was dropped, re-apply that trivia to whatever token now comes first.
                return node.WithAttributeLists(updatedLists).WithLeadingTrivia(node.GetLeadingTrivia());
            }

            return RemoveAttribute;
        }

        /// <summary>
        /// Builds an action that appends a new <c>[attribute]</c> list to the class.
        /// </summary>
        /// <param name="attribute">Attribute name, parsed with <c>SyntaxFactory.ParseName</c>.</param>
        /// <returns>An action that appends the attribute in its own attribute list.</returns>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetAddAttributeAction(string attribute)
        {
            ClassDeclarationSyntax AddAttribute(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                var newAttributeList = SyntaxFactory.AttributeList(
                    SyntaxFactory.SingletonSeparatedList(
                        SyntaxFactory.Attribute(SyntaxFactory.ParseName(attribute))));
                return node.WithAttributeLists(node.AttributeLists.Add(newAttributeList));
            }

            return AddAttribute;
        }

        /// <summary>
        /// Builds an action that adds a comment to the class via <see cref="CommentHelper.AddCSharpComment"/>.
        /// </summary>
        /// <param name="comment">Comment text.</param>
        /// <param name="dontUseCTAPrefix">Forwarded unchanged to <c>CommentHelper.AddCSharpComment</c>.</param>
        /// <returns>An action returning the commented class.</returns>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetAddCommentAction(string comment, string dontUseCTAPrefix = null)
        {
            ClassDeclarationSyntax AddComment(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                return (ClassDeclarationSyntax)CommentHelper.AddCSharpComment(node, comment, dontUseCTAPrefix);
            }

            return AddComment;
        }

        /// <summary>
        /// Builds an action that appends a member (typically a method) parsed from source text.
        /// </summary>
        /// <param name="expression">Source text of a member declaration.</param>
        /// <returns>An action that appends the member, or returns the class unchanged if the text is not a member declaration.</returns>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetAddMethodAction(string expression)
        {
            ClassDeclarationSyntax AddMethod(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                // ParseMemberDeclaration returns null when the text is not a member declaration.
                var member = SyntaxFactory.ParseMemberDeclaration(expression);
                return member == null ? node : node.WithMembers(node.Members.Add(member));
            }

            return AddMethod;
        }

        /// <summary>
        /// Builds an action that removes the first directly-declared method named <paramref name="methodName"/>.
        /// </summary>
        /// <param name="methodName">Method identifier.</param>
        /// <returns>An action removing that method (without keeping its trivia); other overloads are left in place.</returns>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetRemoveMethodAction(string methodName)
        {
            //TODO what if there is operator overloading
            ClassDeclarationSyntax RemoveMethod(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                var methodToRemove = node.Members
                    .OfType<MethodDeclarationSyntax>()
                    .FirstOrDefault(m => m.Identifier.ToString() == methodName);

                return methodToRemove == null
                    ? node
                    : node.RemoveNode(methodToRemove, SyntaxRemoveOptions.KeepNoTrivia);
            }

            return RemoveMethod;
        }

        /// <summary>
        /// Builds an action that renames the class, keeping the identifier's trivia.
        /// </summary>
        /// <param name="newClassName">New class name.</param>
        /// <returns>
        /// An action that renames the class and any directly-declared constructors/finalizers that carried the
        /// old name (C# requires them to match the class name).
        /// </returns>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetRenameClassAction(string newClassName)
        {
            ClassDeclarationSyntax RenameClass(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                var oldName = node.Identifier.ValueText;

                SyntaxToken Renamed(SyntaxToken token) =>
                    SyntaxFactory.Identifier(token.LeadingTrivia, newClassName, token.TrailingTrivia);

                var classNamedMembers = node.Members.Where(m =>
                    (m is ConstructorDeclarationSyntax ctor && ctor.Identifier.ValueText == oldName) ||
                    (m is DestructorDeclarationSyntax dtor && dtor.Identifier.ValueText == oldName));

                node = node.ReplaceNodes(classNamedMembers, (original, rewritten) =>
                {
                    if (rewritten is ConstructorDeclarationSyntax constructor)
                    {
                        return constructor.WithIdentifier(Renamed(constructor.Identifier));
                    }

                    if (rewritten is DestructorDeclarationSyntax destructor)
                    {
                        return destructor.WithIdentifier(Renamed(destructor.Identifier));
                    }

                    return rewritten;
                });

                return node.WithIdentifier(Renamed(node.Identifier));
            }

            return RenameClass;
        }

        /// <summary>
        /// Builds an action that replaces the modifiers of the first directly-declared method named
        /// <paramref name="methodName"/>.
        /// </summary>
        /// <param name="methodName">Method identifier.</param>
        /// <param name="modifiers">Space- and/or comma-separated modifiers, e.g. "public static" or "public, static".</param>
        /// <returns>
        /// An action that applies the modifiers only when at least one is given and all appear in
        /// <c>Constants.SupportedMethodModifiers</c>; otherwise the class is returned unchanged.
        /// </returns>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetReplaceMethodModifiersAction(string methodName, string modifiers)
        {
            ClassDeclarationSyntax ReplaceMethodModifiers(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                var replaceMethod = node.Members
                    .OfType<MethodDeclarationSyntax>()
                    .FirstOrDefault(m => m.Identifier.ToString() == methodName);
                if (replaceMethod == null)
                {
                    return node;
                }

                // RemoveEmptyEntries so "public, static" does not yield an empty (unsupported) entry.
                var modifierNames = modifiers.Split(new[] { ' ', ',' }, StringSplitOptions.RemoveEmptyEntries);
                if (modifierNames.Length == 0 || !modifierNames.All(m => Constants.SupportedMethodModifiers.Contains(m)))
                {
                    return node;
                }

                // One token per modifier, each followed by a space so the last one does not run into the
                // return type (e.g. "staticvoid"). Separators such as commas are not emitted.
                var newModifiers = modifierNames
                    .Select(m => SyntaxFactory.ParseToken(m).WithTrailingTrivia(SyntaxFactory.Space))
                    .ToList();

                // Keep the indentation/comments that preceded the old first modifier (or the return type when
                // there were no modifiers), so the declaration stays where it was.
                var method = replaceMethod;
                SyntaxTriviaList anchorTrivia;
                if (method.Modifiers.Any())
                {
                    anchorTrivia = method.Modifiers[0].LeadingTrivia;
                }
                else
                {
                    anchorTrivia = method.ReturnType.GetLeadingTrivia();
                    method = method.WithReturnType(method.ReturnType.WithoutLeadingTrivia());
                }

                newModifiers[0] = newModifiers[0].WithLeadingTrivia(anchorTrivia);

                var newMethod = method.WithModifiers(SyntaxFactory.TokenList(newModifiers));
                return node.WithMembers(node.Members.Replace(replaceMethod, newMethod));
            }

            return ReplaceMethodModifiers;
        }

        /// <summary>
        /// Builds an action that inserts a member parsed from source text at the top of the class body.
        /// </summary>
        /// <param name="expression">Source text of a member declaration.</param>
        /// <returns>An action inserting the member first, or returning the class unchanged if nothing was parsed.</returns>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetAddExpressionAction(string expression)
        {
            ClassDeclarationSyntax AddExpression(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                // ParseMemberDeclaration returns null when the text is not a member declaration.
                var parsedExpression = SyntaxFactory.ParseMemberDeclaration(expression);
                if (parsedExpression == null || parsedExpression.FullSpan.IsEmpty)
                {
                    return node;
                }

                return node.WithMembers(node.Members.Insert(0, parsedExpression));
            }

            return AddExpression;
        }

        /// <summary>
        /// Builds an action that removes an argument from the first constructor's initializer
        /// (<c>: base(...)</c> / <c>: this(...)</c>).
        /// </summary>
        /// <param name="baseClass">Argument text to remove; compared with each argument's trimmed source text.</param>
        /// <returns>
        /// An action that removes matching arguments, dropping the initializer when none remain. Classes without a
        /// constructor, or whose first constructor has no initializer, are returned unchanged.
        /// </returns>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetRemoveConstructorInitializerAction(string baseClass)
        {
            ClassDeclarationSyntax RemoveConstructorInitializer(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                var constructor = node.ChildNodes().OfType<ConstructorDeclarationSyntax>().FirstOrDefault();
                var initializer = constructor?.Initializer;
                if (initializer == null)
                {
                    return node;
                }

                var remainingArguments = new SeparatedSyntaxList<ArgumentSyntax>();
                foreach (var argument in initializer.ArgumentList.Arguments)
                {
                    if (!argument.GetText().ToString().Trim().Equals(baseClass))
                    {
                        remainingArguments = remainingArguments.Add(argument);
                    }
                }

                // Update the existing initializer rather than building a new one, so a "this(...)" initializer
                // stays "this(...)" and its trivia is kept.
                var updatedConstructor = remainingArguments.Any()
                    ? constructor.WithInitializer(
                        initializer.WithArgumentList(initializer.ArgumentList.WithArguments(remainingArguments)))
                    : constructor.WithInitializer(null);

                return node.ReplaceNode(constructor, updatedConstructor);
            }

            return RemoveConstructorInitializer;
        }

        /// <summary>
        /// Builds an action that appends a statement to the body of the first directly-declared constructor.
        /// </summary>
        /// <param name="expression">Statement source text.</param>
        /// <returns>An action appending the statement, or returning the class unchanged if there is no constructor or nothing parsed.</returns>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetAppendConstructorExpressionAction(string expression)
        {
            ClassDeclarationSyntax AppendConstructorExpression(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                var constructor = node.Members.OfType<ConstructorDeclarationSyntax>().FirstOrDefault();
                if (constructor == null)
                {
                    return node;
                }

                var statement = SyntaxFactory.ParseStatement(expression);
                if (statement.FullSpan.IsEmpty)
                {
                    return node;
                }

                return node.ReplaceNode(constructor, constructor.AddBodyStatements(statement));
            }

            return AppendConstructorExpression;
        }

        /// <summary>
        /// Builds an action that appends a public constructor to the class.
        /// </summary>
        /// <param name="types">Comma-separated parameter types (whitespace around entries is ignored).</param>
        /// <param name="identifiers">Comma-separated parameter names (whitespace around entries is ignored).</param>
        /// <returns>
        /// An action that adds a public, empty-bodied constructor. Parameters are added only when both lists are
        /// non-blank and contain the same number of entries.
        /// </returns>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetCreateConstructorAction(string types, string identifiers)
        {
            ClassDeclarationSyntax CreateConstructor(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                var constructorName = node.Identifier.ValueText;
                if (string.IsNullOrWhiteSpace(constructorName))
                {
                    return node;
                }

                var constructorNode = SyntaxFactory.ConstructorDeclaration(constructorName)
                    .AddBodyStatements()
                    .AddModifiers(SyntaxFactory.Token(SyntaxKind.PublicKeyword));

                // Add constructor parameters if provided
                if (!string.IsNullOrWhiteSpace(identifiers) && !string.IsNullOrWhiteSpace(types))
                {
                    var identifierNames = SplitCommaSeparatedList(identifiers);
                    var typeNames = SplitCommaSeparatedList(types);
                    if (identifierNames.Length == typeNames.Length)
                    {
                        var parameters = identifierNames
                            .Zip(typeNames, (identifier, type) =>
                                SyntaxFactory.Parameter(SyntaxFactory.Identifier(identifier))
                                    .WithType(SyntaxFactory.ParseTypeName(type)))
                            .ToArray();
                        constructorNode = constructorNode.AddParameterListParameters(parameters);
                    }
                }

                return node.AddMembers(constructorNode);
            }

            return CreateConstructor;
        }

        /// <summary>
        /// Builds an action that renames a method via <see cref="MethodDeclarationActions"/>.
        /// </summary>
        /// <param name="existingMethodName">Current method name.</param>
        /// <param name="newMethodName">New method name.</param>
        /// <returns>An action that does nothing unless exactly one method (nested types included) has that name.</returns>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetChangeMethodNameAction(string existingMethodName, string newMethodName)
        {
            ClassDeclarationSyntax ChangeMethodName(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                // if we have more than one method with same name return without making changes
                return ReplaceUniqueMethod(node, existingMethodName, methodNode =>
                    new MethodDeclarationActions().GetChangeMethodNameAction(newMethodName)(syntaxGenerator, methodNode));
            }

            return ChangeMethodName;
        }

        /// <summary>
        /// Builds an action that delegates to <c>MethodDeclarationActions.GetChangeMethodToReturnTaskTypeAction</c>.
        /// </summary>
        /// <param name="methodName">Method name.</param>
        /// <returns>An action that does nothing unless exactly one method (nested types included) has that name.</returns>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetChangeMethodToReturnTaskTypeAction(string methodName)
        {
            ClassDeclarationSyntax ChangeMethodToReturnTaskType(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                // if we have more than one method with same name return without making changes
                return ReplaceUniqueMethod(node, methodName, methodNode =>
                    new MethodDeclarationActions().GetChangeMethodToReturnTaskTypeAction(methodName)(syntaxGenerator, methodNode));
            }

            return ChangeMethodToReturnTaskType;
        }

        /// <summary>
        /// Builds an action that delegates to <c>MethodDeclarationActions.GetRemoveMethodParametersAction</c>.
        /// </summary>
        /// <param name="methodName">Method name.</param>
        /// <returns>An action that does nothing unless exactly one method (nested types included) has that name.</returns>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetRemoveMethodParametersAction(string methodName)
        {
            ClassDeclarationSyntax RemoveMethodParameters(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                // if we have more than one method with same name return without making changes
                return ReplaceUniqueMethod(node, methodName, methodNode =>
                    new MethodDeclarationActions().GetRemoveMethodParametersAction()(syntaxGenerator, methodNode));
            }

            return RemoveMethodParameters;
        }

        /// <summary>
        /// Builds an action that delegates to <c>MethodDeclarationActions.GetCommentMethodAction</c>.
        /// </summary>
        /// <param name="methodName">Method name.</param>
        /// <param name="comment">Forwarded unchanged.</param>
        /// <param name="dontUseCTAPrefix">Forwarded unchanged.</param>
        /// <returns>An action that does nothing unless exactly one method (nested types included) has that name.</returns>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetCommentMethodAction(string methodName, string comment = null, string dontUseCTAPrefix = null)
        {
            ClassDeclarationSyntax CommentMethod(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                // if we have more than one method with same name return without making changes
                return ReplaceUniqueMethod(node, methodName, methodNode =>
                    new MethodDeclarationActions().GetCommentMethodAction(comment, dontUseCTAPrefix)(syntaxGenerator, methodNode));
            }

            return CommentMethod;
        }

        /// <summary>
        /// Builds an action that delegates to <c>MethodDeclarationActions.GetAddCommentAction</c>.
        /// </summary>
        /// <param name="methodName">Method name.</param>
        /// <param name="comment">Comment text; blank comments leave the class unchanged.</param>
        /// <param name="dontUseCTAPrefix">Forwarded unchanged.</param>
        /// <returns>An action that does nothing unless exactly one method (nested types included) has that name.</returns>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetAddCommentsToMethodAction(string methodName, string comment, string dontUseCTAPrefix = null)
        {
            ClassDeclarationSyntax AddCommentsToMethod(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                if (string.IsNullOrWhiteSpace(comment))
                {
                    return node;
                }

                // if we have more than one method with same name return without making changes
                return ReplaceUniqueMethod(node, methodName, methodNode =>
                    new MethodDeclarationActions().GetAddCommentAction(comment, dontUseCTAPrefix)(syntaxGenerator, methodNode));
            }

            return AddCommentsToMethod;
        }

        /// <summary>
        /// Builds an action that delegates to <c>MethodDeclarationActions.GetAddExpressionToMethodAction</c>.
        /// </summary>
        /// <param name="methodName">Method name.</param>
        /// <param name="expression">Forwarded unchanged.</param>
        /// <returns>An action that does nothing unless exactly one method (nested types included) has that name.</returns>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetAddExpressionToMethodAction(string methodName, string expression)
        {
            ClassDeclarationSyntax AddExpressionToMethod(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                // if we have more than one method with same name return without making changes
                return ReplaceUniqueMethod(node, methodName, methodNode =>
                    new MethodDeclarationActions().GetAddExpressionToMethodAction(expression)(syntaxGenerator, methodNode));
            }

            return AddExpressionToMethod;
        }

        /// <summary>
        /// Builds an action that delegates to <c>MethodDeclarationActions.GetAddParametersToMethodAction</c>.
        /// </summary>
        /// <param name="methodName">Method name.</param>
        /// <param name="types">Forwarded unchanged.</param>
        /// <param name="identifiers">Forwarded unchanged.</param>
        /// <returns>An action that does nothing unless exactly one method (nested types included) has that name.</returns>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetAddParametersToMethodAction(string methodName, string types, string identifiers)
        {
            ClassDeclarationSyntax AddParametersToMethod(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                // if we have more than one method with same name return without making changes
                return ReplaceUniqueMethod(node, methodName, methodNode =>
                    new MethodDeclarationActions().GetAddParametersToMethodAction(types, identifiers)(syntaxGenerator, methodNode));
            }

            return AddParametersToMethod;
        }

        /// <summary>
        /// Builds an action that replaces the body of every directly-declared public method with
        /// <paramref name="expression"/> and sets the return type to <c>Constants.ActionResult</c>
        /// (or <c>Constants.TaskActionResult</c> for async methods). The rewritten methods are whitespace-normalized.
        /// </summary>
        /// <param name="expression">Single statement used as the new method body.</param>
        /// <returns>The rewriting action.</returns>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetReplaceMvcControllerMethodsBodyAction(string expression)
        {
            ClassDeclarationSyntax ReplaceMethodBodies(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                return ReplacePublicMethodBodies(
                    node,
                    expression,
                    (method, isAsync) => HasVoidModifier(method)
                        ? string.Empty
                        : (isAsync ? Constants.TaskActionResult : Constants.ActionResult),
                    normalizeWhitespace: true);
            }

            return ReplaceMethodBodies;
        }

        /// <summary>
        /// Builds an action that replaces the body of every directly-declared public method with
        /// <paramref name="expression"/> and sets the return type to <c>Constants.IHttpActionResult</c>
        /// (or <c>Constants.TaskIHttpActionResult</c> for async methods). The rewritten methods are whitespace-normalized.
        /// </summary>
        /// <param name="expression">Single statement used as the new method body.</param>
        /// <returns>The rewriting action.</returns>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetReplaceWebApiControllerMethodsBodyAction(string expression)
        {
            ClassDeclarationSyntax ReplaceMethodBodies(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                //IHttpActionResult is a catch all return type
                return ReplacePublicMethodBodies(
                    node,
                    expression,
                    (method, isAsync) => isAsync ? Constants.TaskIHttpActionResult : Constants.IHttpActionResult,
                    normalizeWhitespace: true);
            }

            return ReplaceMethodBodies;
        }

        /// <summary>
        /// Builds an action that replaces the body of every directly-declared public method with
        /// <paramref name="expression"/> and sets the return type to <c>Constants.IActionResult</c>
        /// (or <c>Constants.TaskIActionResult</c> for async methods). Existing whitespace is kept (no normalization).
        /// </summary>
        /// <param name="expression">Single statement used as the new method body.</param>
        /// <returns>The rewriting action.</returns>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetReplaceCoreControllerMethodsBodyAction(string expression)
        {
            ClassDeclarationSyntax ReplaceMethodBodies(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                return ReplacePublicMethodBodies(
                    node,
                    expression,
                    (method, isAsync) => HasVoidModifier(method)
                        ? string.Empty
                        : (isAsync ? Constants.TaskIActionResult : Constants.IActionResult),
                    normalizeWhitespace: false);
            }

            return ReplaceMethodBodies;
        }

        /// <summary>
        /// Shared implementation of the Replace*ControllerMethodsBody actions: rewrites every directly-declared
        /// public method so its body is the single statement built from <paramref name="expression"/>.
        /// </summary>
        /// <param name="node">Class to rewrite.</param>
        /// <param name="expression">Statement source text used as the new body.</param>
        /// <param name="selectReturnType">
        /// Given the original method and whether it is async, returns the new return type text, or an empty string
        /// to keep the existing return type.
        /// </param>
        /// <param name="normalizeWhitespace">Whether to run <c>NormalizeWhitespace</c> on each rewritten method.</param>
        /// <returns>The rewritten class.</returns>
        private ClassDeclarationSyntax ReplacePublicMethodBodies(
            ClassDeclarationSyntax node,
            string expression,
            Func<MethodDeclarationSyntax, bool, string> selectReturnType,
            bool normalizeWhitespace)
        {
            // Track methods by "name(parameters)" rather than by reference: each ReplaceNode call yields a new
            // tree, so node instances captured up front would no longer be found.
            var publicMethodIds = node.Members
                .OfType<MethodDeclarationSyntax>()
                .Where(m => m.Modifiers.Any(mod => mod.Text == Constants.Public))
                .Select(GetMethodId)
                .ToList();

            foreach (var methodId in publicMethodIds)
            {
                var originalMethod = node.Members
                    .OfType<MethodDeclarationSyntax>()
                    .FirstOrDefault(m => GetMethodId(m) == methodId);
                if (originalMethod == null)
                {
                    continue;
                }

                var isAsync = originalMethod.Modifiers.Any(mod => mod.Text == Constants.AsyncModifier);
                var leadingTrivia = originalMethod.GetLeadingTrivia();
                var updatedMethod = originalMethod;

                SyntaxTriviaList bodyTrailingTrivia;
                if (updatedMethod.Body != null)
                {
                    bodyTrailingTrivia = updatedMethod.Body.GetTrailingTrivia();
                }
                else
                {
                    // "=> expr;" (or a body-less declaration) ends in a semicolon that cannot coexist with a block
                    // body; drop both and carry the semicolon's trailing trivia over to the new block.
                    bodyTrailingTrivia = updatedMethod.SemicolonToken.TrailingTrivia;
                    updatedMethod = updatedMethod
                        .WithExpressionBody(null)
                        .WithSemicolonToken(default(SyntaxToken));
                }

                var newBody = SyntaxFactory
                    .Block(SyntaxFactory.ParseStatement(BuildControllerMethodStatement(expression, isAsync)))
                    .WithTrailingTrivia(bodyTrailingTrivia);
                updatedMethod = updatedMethod.WithBody(newBody);

                var returnType = selectReturnType(originalMethod, isAsync);
                if (!string.IsNullOrEmpty(returnType))
                {
                    // Copy the old return type's trivia so the space separating it from the method name survives.
                    updatedMethod = updatedMethod.WithReturnType(
                        SyntaxFactory.ParseTypeName(returnType).WithTriviaFrom(originalMethod.ReturnType));
                }

                if (normalizeWhitespace)
                {
                    updatedMethod = updatedMethod.NormalizeWhitespace();
                }

                node = node.ReplaceNode(originalMethod, updatedMethod.WithLeadingTrivia(leadingTrivia));
            }

            return node;
        }

        /// <summary>
        /// For async methods whose statement contains "MonolithService.CreateRequest" (per <c>Constants</c>), inserts
        /// <c>Constants.Await</c> before the first MonolithService occurrence and appends <c>Constants.AsyncWord</c>
        /// after the first CreateRequest occurrence; otherwise returns <paramref name="expression"/> unchanged.
        /// </summary>
        private static string BuildControllerMethodStatement(string expression, bool isAsync)
        {
            if (!isAsync || !expression.Contains(Constants.MonolithService + "." + Constants.CreateRequest))
            {
                return expression;
            }

            // Ordinal search to agree with the ordinal Contains check above.
            var withAwait = expression.Insert(
                expression.IndexOf(Constants.MonolithService, StringComparison.Ordinal),
                Constants.Await + " ");
            var createRequestEnd =
                withAwait.IndexOf(Constants.CreateRequest, StringComparison.Ordinal) + Constants.CreateRequest.Length;
            return withAwait.Insert(createRequestEnd, Constants.AsyncWord);
        }

        /// <summary>
        /// Returns true when one of the method's modifier tokens has the text <c>Constants.VoidModifier</c>.
        /// </summary>
        /// <remarks>
        /// NOTE(review): <c>void</c> is a return type in C#, not a modifier, so this check likely never matches
        /// (assuming Constants.VoidModifier is "void" — confirm). Kept as-is to preserve the current output,
        /// where void methods also receive the ActionResult-style return type.
        /// </remarks>
        private static bool HasVoidModifier(MethodDeclarationSyntax method)
        {
            return method.Modifiers.Any(mod => mod.Text == Constants.VoidModifier);
        }

        /// <summary>
        /// Splits a comma-separated list, trimming whitespace around each entry and dropping empty entries.
        /// </summary>
        private static string[] SplitCommaSeparatedList(string value)
        {
            return value
                .Split(',')
                .Select(entry => entry.Trim())
                .Where(entry => entry.Length > 0)
                .ToArray();
        }

        /// <summary>
        /// Applies <paramref name="transform"/> to the single method named <paramref name="methodName"/> found by
        /// <see cref="GetMethodNode"/>; returns the class unchanged when there is no such method or it is ambiguous.
        /// </summary>
        private ClassDeclarationSyntax ReplaceUniqueMethod(
            ClassDeclarationSyntax node,
            string methodName,
            Func<MethodDeclarationSyntax, SyntaxNode> transform)
        {
            var methodNode = GetMethodNode(node, methodName);
            return methodNode == null ? node : node.ReplaceNode(methodNode, transform(methodNode));
        }

        /// <summary>
        /// Builds a key of the form "Name(parameter list text)" used to re-find a method across tree rewrites.
        /// </summary>
        private string GetMethodId(MethodDeclarationSyntax method)
        {
            return $"{method.Identifier}{method.ParameterList}";
        }

        /// <summary>
        /// Returns the only method named <paramref name="methodName"/> anywhere under the class (nested types
        /// included, since <c>DescendantNodes</c> is used), or null when there are none or more than one.
        /// </summary>
        private MethodDeclarationSyntax GetMethodNode(ClassDeclarationSyntax node, string methodName)
        {
            // Two matches are enough to know the name is ambiguous; materialize once instead of enumerating twice.
            var matches = node.DescendantNodes()
                .OfType<MethodDeclarationSyntax>()
                .Where(method => method.Identifier.Text == methodName)
                .Take(2)
                .ToList();

            return matches.Count == 1 ? matches[0] : null;
        }
    }
}