using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;
using CTA.Rules.Actions.ActionHelpers;
using CTA.Rules.Config;
using CTA.Rules.Models;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Editing;
using CSharpExtensions = Microsoft.CodeAnalysis.CSharp.CSharpExtensions;
namespace CTA.Rules.Actions.Csharp
{
    /// <summary>
    /// List of actions that can run on Class Declarations.
    /// Every Get*Action method returns a delegate taking a <see cref="SyntaxGenerator"/> and a
    /// <see cref="ClassDeclarationSyntax"/> and returning the rewritten class declaration. Roslyn syntax
    /// nodes are immutable, so each action returns a new node (or the input unchanged) rather than mutating it.
    /// </summary>
    public class ClassActions
    {
        /// <summary>
        /// Creates an action that removes <paramref name="baseClass"/> from the class's base list while keeping
        /// every other base type. The base list is dropped entirely when no base type remains.
        /// </summary>
        /// <param name="baseClass">Base type text to remove; compared exactly against each base type's source text.</param>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetRemoveBaseClassAction(string baseClass)
        {
            ClassDeclarationSyntax RemoveBaseClass(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                if (node.BaseList == null)
                {
                    return node;
                }

                // SeparatedSyntaxList is immutable: RemoveAt returns a new list, so the result must be reassigned.
                // Walking backwards keeps the indices still to be visited stable; RemoveAt also drops the adjacent
                // comma while preserving the trivia of the separators that remain.
                var remainingBaseTypes = node.BaseList.Types;
                for (int i = remainingBaseTypes.Count - 1; i >= 0; i--)
                {
                    if (remainingBaseTypes[i].ToString() == baseClass)
                    {
                        remainingBaseTypes = remainingBaseTypes.RemoveAt(i);
                    }
                }

                return remainingBaseTypes.Any()
                    ? node.WithBaseList(node.BaseList.WithTypes(remainingBaseTypes))
                    : node.WithBaseList(null);
            }
            return RemoveBaseClass;
        }

        /// <summary>
        /// Creates an action that appends <paramref name="baseClass"/> to the class's base list, using the
        /// <see cref="SyntaxGenerator"/> when one is supplied.
        /// </summary>
        /// <param name="baseClass">Name of the base type to add.</param>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetAddBaseClassAction(string baseClass)
        {
            ClassDeclarationSyntax AddBaseClass(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                if (syntaxGenerator != null)
                {
                    return (ClassDeclarationSyntax)syntaxGenerator.AddBaseType(node, SyntaxFactory.ParseName(baseClass));
                }

                var baseType = SyntaxFactory.SimpleBaseType(SyntaxFactory.ParseTypeName(baseClass));
                return node.AddBaseListTypes(baseType);
            }
            return AddBaseClass;
        }

        /// <summary>
        /// Creates an action that renames the class to <paramref name="className"/>.
        /// </summary>
        /// <param name="className">New class name.</param>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetChangeNameAction(string className)
        {
            // Even though this method is a duplicate of GetRenameClassAction, keep it for backwards compatibility
            return GetRenameClassAction(className);
        }

        /// <summary>
        /// Creates an action that removes every attribute named <paramref name="attributeName"/> from the class.
        /// Other attributes sharing the same <c>[...]</c> list are kept; a list left empty is removed.
        /// </summary>
        /// <param name="attributeName">Attribute name, compared exactly against the attribute's name text.</param>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetRemoveAttributeAction(string attributeName)
        {
            ClassDeclarationSyntax RemoveAttribute(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                var keptAttributeLists = new SyntaxList<AttributeListSyntax>();
                foreach (var attributeList in node.AttributeLists)
                {
                    // Reverse walk for the same reason as in RemoveBaseClass: RemoveAt reindexes the list.
                    var keptAttributes = attributeList.Attributes;
                    for (int i = keptAttributes.Count - 1; i >= 0; i--)
                    {
                        if (keptAttributes[i].Name.ToString() == attributeName)
                        {
                            keptAttributes = keptAttributes.RemoveAt(i);
                        }
                    }

                    if (keptAttributes.Count == attributeList.Attributes.Count)
                    {
                        keptAttributeLists = keptAttributeLists.Add(attributeList);
                    }
                    else if (keptAttributes.Any())
                    {
                        keptAttributeLists = keptAttributeLists.Add(attributeList.WithAttributes(keptAttributes));
                    }
                    // Otherwise every attribute in this list matched, so the whole list is dropped.
                }
                return node.WithAttributeLists(keptAttributeLists);
            }
            return RemoveAttribute;
        }

        /// <summary>
        /// Creates an action that appends a new attribute list containing <paramref name="attribute"/> to the class.
        /// </summary>
        /// <param name="attribute">Attribute name (parsed as a name, without brackets).</param>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetAddAttributeAction(string attribute)
        {
            ClassDeclarationSyntax AddAttribute(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                var newAttributeList = SyntaxFactory.AttributeList(
                    SyntaxFactory.SingletonSeparatedList(
                        SyntaxFactory.Attribute(SyntaxFactory.ParseName(attribute))));
                return node.WithAttributeLists(node.AttributeLists.Add(newAttributeList));
            }
            return AddAttribute;
        }

        /// <summary>
        /// Creates an action that adds <paramref name="comment"/> to the class via <c>CommentHelper.AddCSharpComment</c>.
        /// </summary>
        /// <param name="comment">Comment text.</param>
        /// <param name="dontUseCTAPrefix">Passed through unchanged to <c>CommentHelper.AddCSharpComment</c>.</param>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetAddCommentAction(string comment, string dontUseCTAPrefix = null)
        {
            ClassDeclarationSyntax AddComment(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                return (ClassDeclarationSyntax)CommentHelper.AddCSharpComment(node, comment, dontUseCTAPrefix);
            }
            return AddComment;
        }

        /// <summary>
        /// Creates an action that parses <paramref name="expression"/> as a member declaration and appends it to
        /// the class. The class is returned unchanged when the text does not parse as a member.
        /// </summary>
        /// <param name="expression">Source text of the member to add.</param>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetAddMethodAction(string expression)
        {
            ClassDeclarationSyntax AddMethod(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                // ParseMemberDeclaration returns null when nothing resembling a member is found.
                var newMember = SyntaxFactory.ParseMemberDeclaration(expression);
                if (newMember == null)
                {
                    return node;
                }
                return node.WithMembers(node.Members.Add(newMember));
            }
            return AddMethod;
        }

        /// <summary>
        /// Creates an action that removes the first direct method member named <paramref name="methodName"/>.
        /// </summary>
        /// <param name="methodName">Name of the method to remove.</param>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetRemoveMethodAction(string methodName)
        {
            //TODO what if there is operator overloading (only the first method with this name is removed)
            ClassDeclarationSyntax RemoveMethod(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                var methodToRemove = node.Members
                    .OfType<MethodDeclarationSyntax>()
                    .FirstOrDefault(m => m.Identifier.ToString() == methodName);
                if (methodToRemove == null)
                {
                    return node;
                }
                return node.RemoveNode(methodToRemove, SyntaxRemoveOptions.KeepNoTrivia);
            }
            return RemoveMethod;
        }

        /// <summary>
        /// Creates an action that renames the class to <paramref name="newClassName"/>, keeping the identifier's trivia.
        /// </summary>
        /// <param name="newClassName">New class name.</param>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetRenameClassAction(string newClassName)
        {
            ClassDeclarationSyntax RenameClass(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                var newIdentifier = SyntaxFactory.Identifier(node.Identifier.LeadingTrivia, newClassName, node.Identifier.TrailingTrivia);
                return node.WithIdentifier(newIdentifier);
            }
            return RenameClass;
        }

        /// <summary>
        /// Creates an action that replaces the modifiers of the first direct method member named
        /// <paramref name="methodName"/>. Nothing changes unless every word in <paramref name="modifiers"/> is
        /// listed in <c>Constants.SupportedMethodModifiers</c>.
        /// </summary>
        /// <param name="methodName">Name of the method to modify.</param>
        /// <param name="modifiers">Space- and/or comma-separated modifiers, e.g. "public static" or "public, static".</param>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetReplaceMethodModifiersAction(string methodName, string modifiers)
        {
            ClassDeclarationSyntax ReplaceMethodModifiers(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                var replaceMethod = node.Members
                    .OfType<MethodDeclarationSyntax>()
                    .FirstOrDefault(m => m.Identifier.ToString() == methodName);
                if (replaceMethod == null || string.IsNullOrWhiteSpace(modifiers))
                {
                    return node;
                }

                // Empty entries ("public, static" splits into "public", "", "static") are separators, not modifiers.
                var modifierWords = modifiers.Split(new[] { ' ', ',' }, StringSplitOptions.RemoveEmptyEntries);
                if (modifierWords.Length == 0 || !modifierWords.All(word => Constants.SupportedMethodModifiers.Contains(word)))
                {
                    return node;
                }

                // One token per modifier word, each followed by a space so it is not glued to the next token.
                var newModifiers = modifierWords
                    .Select(word => SyntaxFactory.ParseToken(word).WithTrailingTrivia(SyntaxFactory.Space))
                    .ToArray();
                if (replaceMethod.Modifiers.Any())
                {
                    // Keep the indentation/comments that preceded the original first modifier.
                    newModifiers[0] = newModifiers[0].WithLeadingTrivia(replaceMethod.Modifiers.First().LeadingTrivia);
                }

                var newMethod = replaceMethod.WithModifiers(SyntaxFactory.TokenList(newModifiers));
                return node.WithMembers(node.Members.Replace(replaceMethod, newMethod));
            }
            return ReplaceMethodModifiers;
        }

        /// <summary>
        /// Creates an action that parses <paramref name="expression"/> as a member declaration and inserts it as
        /// the first member of the class. The class is returned unchanged when nothing parses.
        /// </summary>
        /// <param name="expression">Source text of the member to insert.</param>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetAddExpressionAction(string expression)
        {
            ClassDeclarationSyntax AddExpression(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                // ParseMemberDeclaration returns null when nothing resembling a member is found.
                var parsedExpression = SyntaxFactory.ParseMemberDeclaration(expression);
                if (parsedExpression == null || parsedExpression.FullSpan.IsEmpty)
                {
                    return node;
                }
                return node.WithMembers(node.Members.Insert(0, parsedExpression));
            }
            return AddExpression;
        }

        /// <summary>
        /// Creates an action that removes the argument whose text equals <paramref name="baseClass"/> from the
        /// initializer (<c>: base(...)</c> / <c>: this(...)</c>) of the class's first constructor. The initializer
        /// is removed when no argument remains. Nothing changes when that constructor has no initializer.
        /// </summary>
        /// <param name="baseClass">Argument text to remove.</param>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetRemoveConstructorInitializerAction(string baseClass)
        {
            ClassDeclarationSyntax RemoveConstructorInitializer(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                var constructorNode = node.Members.OfType<ConstructorDeclarationSyntax>().FirstOrDefault();
                var initializer = constructorNode?.Initializer;
                if (initializer == null)
                {
                    return node;
                }

                var remainingArguments = initializer.ArgumentList.Arguments;
                for (int i = remainingArguments.Count - 1; i >= 0; i--)
                {
                    if (remainingArguments[i].ToString() == baseClass)
                    {
                        remainingArguments = remainingArguments.RemoveAt(i);
                    }
                }

                // Edit the existing initializer so its kind (base vs this) and trivia are preserved.
                var newConstructor = remainingArguments.Any()
                    ? constructorNode.WithInitializer(initializer.WithArgumentList(initializer.ArgumentList.WithArguments(remainingArguments)))
                    : constructorNode.WithInitializer(null);
                return node.ReplaceNode(constructorNode, newConstructor);
            }
            return RemoveConstructorInitializer;
        }

        /// <summary>
        /// Creates an action that parses <paramref name="expression"/> as a statement and appends it to the body of
        /// the class's first constructor.
        /// </summary>
        /// <param name="expression">Statement source text.</param>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetAppendConstructorExpressionAction(string expression)
        {
            ClassDeclarationSyntax AppendConstructorExpression(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                var constructorNode = node.Members.OfType<ConstructorDeclarationSyntax>().FirstOrDefault();
                if (constructorNode == null)
                {
                    return node;
                }

                var statement = SyntaxFactory.ParseStatement(expression);
                if (statement.FullSpan.IsEmpty)
                {
                    return node;
                }
                return node.ReplaceNode(constructorNode, constructorNode.AddBodyStatements(statement));
            }
            return AppendConstructorExpression;
        }

        /// <summary>
        /// Creates an action that appends a public constructor with an empty body to the class. Parameters are
        /// added only when <paramref name="types"/> and <paramref name="identifiers"/> contain the same number of
        /// comma-separated entries.
        /// </summary>
        /// <param name="types">Comma-separated parameter types, e.g. "int, string".</param>
        /// <param name="identifiers">Comma-separated parameter names, e.g. "a, b".</param>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetCreateConstructorAction(string types, string identifiers)
        {
            ClassDeclarationSyntax CreateConstructor(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                if (string.IsNullOrWhiteSpace(node.Identifier.ValueText))
                {
                    return node;
                }

                // Reuse the class's identifier token so verbatim names (@name) stay valid.
                var constructorNode = SyntaxFactory.ConstructorDeclaration(node.Identifier.WithoutTrivia())
                    .AddBodyStatements()
                    .AddModifiers(SyntaxFactory.Token(SyntaxKind.PublicKeyword));

                // Add constructor parameters if provided
                if (!string.IsNullOrWhiteSpace(identifiers) && !string.IsNullOrWhiteSpace(types))
                {
                    // Trim entries so "int, string" yields "string" rather than " string".
                    var identifiersArray = identifiers.Split(',').Select(s => s.Trim()).Where(s => s.Length > 0).ToArray();
                    var typesArray = types.Split(',').Select(s => s.Trim()).Where(s => s.Length > 0).ToArray();
                    if (identifiersArray.Length == typesArray.Length)
                    {
                        var parameters = identifiersArray
                            .Zip(typesArray, (identifier, type) =>
                                SyntaxFactory.Parameter(SyntaxFactory.Identifier(identifier)).WithType(SyntaxFactory.ParseTypeName(type)))
                            .ToArray();
                        constructorNode = constructorNode.AddParameterListParameters(parameters);
                    }
                }
                return node.AddMembers(constructorNode);
            }
            return CreateConstructor;
        }

        /// <summary>
        /// Creates an action that renames the method <paramref name="existingMethodName"/> to
        /// <paramref name="newMethodName"/> via <c>MethodDeclarationActions</c>.
        /// </summary>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetChangeMethodNameAction(string existingMethodName, string newMethodName)
        {
            ClassDeclarationSyntax ChangeMethodName(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                // if we have more than one method with same name return without making changes
                return ReplaceUniqueMethod(node, existingMethodName, methodNode =>
                    new MethodDeclarationActions().GetChangeMethodNameAction(newMethodName)(syntaxGenerator, methodNode));
            }
            return ChangeMethodName;
        }

        /// <summary>
        /// Creates an action that applies <c>MethodDeclarationActions.GetChangeMethodToReturnTaskTypeAction</c> to
        /// the method <paramref name="methodName"/>.
        /// </summary>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetChangeMethodToReturnTaskTypeAction(string methodName)
        {
            ClassDeclarationSyntax ChangeMethodToReturnTaskType(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                // if we have more than one method with same name return without making changes
                return ReplaceUniqueMethod(node, methodName, methodNode =>
                    new MethodDeclarationActions().GetChangeMethodToReturnTaskTypeAction(methodName)(syntaxGenerator, methodNode));
            }
            return ChangeMethodToReturnTaskType;
        }

        /// <summary>
        /// Creates an action that applies <c>MethodDeclarationActions.GetRemoveMethodParametersAction</c> to the
        /// method <paramref name="methodName"/>.
        /// </summary>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetRemoveMethodParametersAction(string methodName)
        {
            ClassDeclarationSyntax RemoveMethodParameters(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                // if we have more than one method with same name return without making changes
                return ReplaceUniqueMethod(node, methodName, methodNode =>
                    new MethodDeclarationActions().GetRemoveMethodParametersAction()(syntaxGenerator, methodNode));
            }
            return RemoveMethodParameters;
        }

        /// <summary>
        /// Creates an action that applies <c>MethodDeclarationActions.GetCommentMethodAction</c> to the method
        /// <paramref name="methodName"/>.
        /// </summary>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetCommentMethodAction(string methodName, string comment = null, string dontUseCTAPrefix = null)
        {
            ClassDeclarationSyntax CommentMethod(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                // if we have more than one method with same name return without making changes
                return ReplaceUniqueMethod(node, methodName, methodNode =>
                    new MethodDeclarationActions().GetCommentMethodAction(comment, dontUseCTAPrefix)(syntaxGenerator, methodNode));
            }
            return CommentMethod;
        }

        /// <summary>
        /// Creates an action that adds <paramref name="comment"/> to the method <paramref name="methodName"/> via
        /// <c>MethodDeclarationActions.GetAddCommentAction</c>. Blank comments are ignored.
        /// </summary>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetAddCommentsToMethodAction(string methodName, string comment, string dontUseCTAPrefix = null)
        {
            ClassDeclarationSyntax AddCommentsToMethod(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                if (string.IsNullOrWhiteSpace(comment))
                {
                    return node;
                }
                // if we have more than one method with same name return without making changes
                return ReplaceUniqueMethod(node, methodName, methodNode =>
                    new MethodDeclarationActions().GetAddCommentAction(comment, dontUseCTAPrefix)(syntaxGenerator, methodNode));
            }
            return AddCommentsToMethod;
        }

        /// <summary>
        /// Creates an action that applies <c>MethodDeclarationActions.GetAddExpressionToMethodAction</c> to the
        /// method <paramref name="methodName"/>.
        /// </summary>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetAddExpressionToMethodAction(string methodName, string expression)
        {
            ClassDeclarationSyntax AddExpressionToMethod(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                // if we have more than one method with same name return without making changes
                return ReplaceUniqueMethod(node, methodName, methodNode =>
                    new MethodDeclarationActions().GetAddExpressionToMethodAction(expression)(syntaxGenerator, methodNode));
            }
            return AddExpressionToMethod;
        }

        /// <summary>
        /// Creates an action that applies <c>MethodDeclarationActions.GetAddParametersToMethodAction</c> to the
        /// method <paramref name="methodName"/>.
        /// </summary>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetAddParametersToMethodAction(string methodName, string types, string identifiers)
        {
            ClassDeclarationSyntax AddParametersToMethod(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                // if we have more than one method with same name return without making changes
                return ReplaceUniqueMethod(node, methodName, methodNode =>
                    new MethodDeclarationActions().GetAddParametersToMethodAction(types, identifiers)(syntaxGenerator, methodNode));
            }
            return AddParametersToMethod;
        }

        /// <summary>
        /// Creates an action that replaces the body of every public method with <paramref name="expression"/>
        /// and sets the return type to <c>Constants.ActionResult</c> (or <c>Constants.TaskActionResult</c> for
        /// async methods). The result is whitespace-normalized.
        /// </summary>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetReplaceMvcControllerMethodsBodyAction(string expression)
        {
            ClassDeclarationSyntax ReplaceMvcControllerMethodsBody(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                return ReplacePublicMethodBodies(node, expression, Constants.ActionResult, Constants.TaskActionResult,
                    keepVoidReturnType: true, normalizeWhitespace: true);
            }
            return ReplaceMvcControllerMethodsBody;
        }

        /// <summary>
        /// Creates an action that replaces the body of every public method with <paramref name="expression"/>
        /// and always sets the return type to <c>Constants.IHttpActionResult</c> (or
        /// <c>Constants.TaskIHttpActionResult</c> for async methods). The result is whitespace-normalized.
        /// </summary>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetReplaceWebApiControllerMethodsBodyAction(string expression)
        {
            ClassDeclarationSyntax ReplaceWebApiControllerMethodsBody(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                //IHttpActionResult is a catch all return type
                return ReplacePublicMethodBodies(node, expression, Constants.IHttpActionResult, Constants.TaskIHttpActionResult,
                    keepVoidReturnType: false, normalizeWhitespace: true);
            }
            return ReplaceWebApiControllerMethodsBody;
        }

        /// <summary>
        /// Creates an action that replaces the body of every public method with <paramref name="expression"/>
        /// and sets the return type to <c>Constants.IActionResult</c> (or <c>Constants.TaskIActionResult</c> for
        /// async methods). Whitespace is not normalized.
        /// </summary>
        public Func<SyntaxGenerator, ClassDeclarationSyntax, ClassDeclarationSyntax> GetReplaceCoreControllerMethodsBodyAction(string expression)
        {
            ClassDeclarationSyntax ReplaceCoreControllerMethodsBody(SyntaxGenerator syntaxGenerator, ClassDeclarationSyntax node)
            {
                return ReplacePublicMethodBodies(node, expression, Constants.IActionResult, Constants.TaskIActionResult,
                    keepVoidReturnType: true, normalizeWhitespace: false);
            }
            return ReplaceCoreControllerMethodsBody;
        }

        /// <summary>
        /// Shared implementation of the Replace*ControllerMethodsBody actions: gives every direct public method a
        /// block body containing <paramref name="expression"/> (with the monolith call awaited for async methods)
        /// and rewrites its return type.
        /// </summary>
        /// <param name="node">Class to rewrite.</param>
        /// <param name="expression">Statement text for the new body.</param>
        /// <param name="syncReturnType">Return type for non-async methods.</param>
        /// <param name="asyncReturnType">Return type for methods with the async modifier.</param>
        /// <param name="keepVoidReturnType">When true, methods flagged as void keep their return type.</param>
        /// <param name="normalizeWhitespace">When true, each rewritten method is passed through NormalizeWhitespace.</param>
        private ClassDeclarationSyntax ReplacePublicMethodBodies(
            ClassDeclarationSyntax node,
            string expression,
            string syncReturnType,
            string asyncReturnType,
            bool keepVoidReturnType,
            bool normalizeWhitespace)
        {
            // Each ReplaceNode produces a new tree, so methods are captured by id up front and re-located every pass.
            var publicMethodIds = node.Members
                .OfType<MethodDeclarationSyntax>()
                .Where(m => m.Modifiers.Any(mod => mod.Text == Constants.Public))
                .Select(GetMethodId)
                .ToList();

            foreach (var methodId in publicMethodIds)
            {
                var originalMethod = node.Members.OfType<MethodDeclarationSyntax>().FirstOrDefault(m => GetMethodId(m) == methodId);
                if (originalMethod == null)
                {
                    continue;
                }

                bool isAsync = originalMethod.Modifiers.Any(mod => mod.Text == Constants.AsyncModifier);
                // NOTE(review): `void` is a return type, not a modifier, so this only matches if a modifier token's
                // text equals Constants.VoidModifier; confirm whether void methods are meant to keep their return type.
                bool isVoid = originalMethod.Modifiers.Any(mod => mod.Text == Constants.VoidModifier);
                var leadingTrivia = originalMethod.GetLeadingTrivia();

                var newBody = SyntaxFactory.Block(SyntaxFactory.ParseStatement(AddAwaitToMonolithCall(expression, isAsync)));
                var newMethod = originalMethod;
                if (originalMethod.ExpressionBody != null)
                {
                    // `=> expr;` methods: drop the arrow clause and its `;` so the new block is the only body.
                    newBody = newBody.WithTrailingTrivia(originalMethod.SemicolonToken.TrailingTrivia);
                    newMethod = newMethod.WithExpressionBody(null).WithSemicolonToken(default(SyntaxToken));
                }
                else if (originalMethod.Body != null)
                {
                    newBody = newBody.WithTriviaFrom(originalMethod.Body);
                }
                newMethod = newMethod.WithBody(newBody);

                string returnType = keepVoidReturnType && isVoid
                    ? string.Empty
                    : (isAsync ? asyncReturnType : syncReturnType);
                if (!string.IsNullOrEmpty(returnType))
                {
                    // Carry over the old return type's trivia so the space before the method name is kept.
                    newMethod = newMethod.WithReturnType(SyntaxFactory.ParseTypeName(returnType).WithTriviaFrom(originalMethod.ReturnType));
                }

                if (normalizeWhitespace)
                {
                    newMethod = newMethod.NormalizeWhitespace();
                }
                node = node.ReplaceNode(originalMethod, newMethod.WithLeadingTrivia(leadingTrivia));
            }
            return node;
        }

        /// <summary>
        /// For async methods, prefixes the <c>MonolithService.CreateRequest</c> call in <paramref name="expression"/>
        /// with the await keyword and appends the async suffix to the method name; otherwise returns the input.
        /// </summary>
        private static string AddAwaitToMonolithCall(string expression, bool isAsync)
        {
            var monolithCall = Constants.MonolithService + "." + Constants.CreateRequest;
            var callIndex = expression.IndexOf(monolithCall, StringComparison.Ordinal);
            if (!isAsync || callIndex < 0)
            {
                return expression;
            }

            // Insert the suffix first so callIndex is still valid for the prefix.
            return expression
                .Insert(callIndex + monolithCall.Length, Constants.AsyncWord)
                .Insert(callIndex, Constants.Await + " ");
        }

        /// <summary>
        /// Identifies a method by its name plus parameter-list text, which distinguishes overloads.
        /// </summary>
        private string GetMethodId(MethodDeclarationSyntax method) => $"{method.Identifier}{method.ParameterList}";

        /// <summary>
        /// Applies <paramref name="transform"/> to the single method named <paramref name="methodName"/> and swaps
        /// the result into <paramref name="node"/>. Returns <paramref name="node"/> unchanged when no such method
        /// exists or the name is ambiguous.
        /// </summary>
        private ClassDeclarationSyntax ReplaceUniqueMethod(ClassDeclarationSyntax node, string methodName, Func<MethodDeclarationSyntax, SyntaxNode> transform)
        {
            var methodNode = GetMethodNode(node, methodName);
            if (methodNode == null)
            {
                return node;
            }
            return node.ReplaceNode(methodNode, transform(methodNode));
        }

        /// <summary>
        /// Returns the only method named <paramref name="methodName"/> among all descendants of
        /// <paramref name="node"/> (nested types included), or null when there are none or more than one.
        /// </summary>
        private MethodDeclarationSyntax GetMethodNode(ClassDeclarationSyntax node, string methodName)
        {
            // Two matches are enough to know the name is ambiguous; no need to scan further.
            var matches = node.DescendantNodes()
                .OfType<MethodDeclarationSyntax>()
                .Where(method => method.Identifier.Text == methodName)
                .Take(2)
                .ToList();
            return matches.Count == 1 ? matches[0] : null;
        }
    }
}