using System;
using System.Collections.Generic;
using System.Linq;
using CTA.Rules.Config;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.VisualBasic;
using Microsoft.CodeAnalysis.VisualBasic.Syntax;
using Microsoft.CodeAnalysis.Editing;
using CTA.Rules.Actions.ActionHelpers;
namespace CTA.Rules.Actions.VisualBasic
{
///
/// List of actions that can run on Class Blocks
///
public class TypeBlockActions
{
public Func GetRemoveBaseClassAction(string baseClass)
{
TypeBlockSyntax RemoveBaseClass(SyntaxGenerator syntaxGenerator, TypeBlockSyntax node)
{
var currentBaseTypes = node.Inherits.FirstOrDefault()?.Types ?? new SeparatedSyntaxList();
SeparatedSyntaxList newBaseTypes = new SeparatedSyntaxList();
foreach (var baseTypeSyntax in currentBaseTypes)
{
if (!baseTypeSyntax.GetText().ToString().Trim().Equals(baseClass))
{
newBaseTypes.Add(baseTypeSyntax);
}
}
node = node.WithInherits(new SyntaxList
{
SyntaxFactory.InheritsStatement().WithTypes(newBaseTypes)
});
return node;
}
return RemoveBaseClass;
}
public Func GetAddBaseClassAction(string baseClass)
{
TypeBlockSyntax AddBaseClass(SyntaxGenerator syntaxGenerator, TypeBlockSyntax node)
{
if (syntaxGenerator != null)
{
node = (TypeBlockSyntax)syntaxGenerator.AddBaseType(node, SyntaxFactory.ParseName(baseClass));
}
else
{
var baseType = SyntaxFactory.InheritsStatement(SyntaxFactory.ParseTypeName(baseClass));
node = node.WithInherits(new SyntaxList(baseType));
}
return node;
}
return AddBaseClass;
}
public Func GetChangeNameAction(string className)
{
TypeBlockSyntax ChangeName(SyntaxGenerator syntaxGenerator, TypeBlockSyntax node)
{
node = node.WithBlockStatement(node.BlockStatement.WithIdentifier(SyntaxFactory.Identifier(className)))
.NormalizeWhitespace();
return node;
}
return ChangeName;
}
public Func GetRemoveAttributeAction(string attributeName)
{
TypeBlockSyntax RemoveAttribute(SyntaxGenerator syntaxGenerator, TypeBlockSyntax node)
{
var attributeLists = node.BlockStatement.AttributeLists;
AttributeListSyntax attributeToRemove = null;
foreach (var attributeList in attributeLists)
{
foreach (var attribute in attributeList.Attributes)
{
if (attribute.Name.ToString() == attributeName)
{
attributeToRemove = attributeList;
break;
}
}
}
if (attributeToRemove != null)
{
attributeLists = attributeLists.Remove(attributeToRemove);
}
node = node.WithBlockStatement(node.BlockStatement.WithAttributeLists(attributeLists))
.NormalizeWhitespace();
return node;
}
return RemoveAttribute;
}
public Func GetAddAttributeAction(string attribute)
{
TypeBlockSyntax AddAttribute(SyntaxGenerator syntaxGenerator, TypeBlockSyntax node)
{
var attributeLists = node.BlockStatement.AttributeLists;
attributeLists = attributeLists.Add(
SyntaxFactory.AttributeList(
SyntaxFactory.SingletonSeparatedList(
SyntaxFactory.Attribute(SyntaxFactory.ParseName(attribute)))));
node = node.WithBlockStatement(node.BlockStatement.WithAttributeLists(attributeLists))
.NormalizeWhitespace();
return node;
}
return AddAttribute;
}
public Func GetAddCommentAction(string comment, string dontUseCTAPrefix = null)
{
TypeBlockSyntax AddComment(SyntaxGenerator syntaxGenerator, TypeBlockSyntax node)
{
return (TypeBlockSyntax)CommentHelper.AddVBComment(node, comment, dontUseCTAPrefix);
}
return AddComment;
}
public Func GetAddMethodAction(string expression)
{
TypeBlockSyntax AddMethod(SyntaxGenerator syntaxGenerator, TypeBlockSyntax node)
{
var methodBlockSyntax = SyntaxFactory.ParseSyntaxTree(expression).GetRoot().DescendantNodes()
.OfType().FirstOrDefault();
if (methodBlockSyntax != null)
{
node = node.AddMembers(methodBlockSyntax);
}
return node.NormalizeWhitespace();
}
return AddMethod;
}
public Func GetRemoveMethodAction(string methodName)
{
//TODO what if there is operator overloading
TypeBlockSyntax RemoveMethod(SyntaxGenerator syntaxGenerator, TypeBlockSyntax node)
{
var allMembers = node.Members.ToList();
var allMethods = allMembers.OfType();
if (allMethods.Any())
{
var removeMethod = allMethods.FirstOrDefault(m => m.Identifier.ToString() == methodName);
if (removeMethod != null)
{
node = node.RemoveNode(removeMethod, SyntaxRemoveOptions.KeepNoTrivia).NormalizeWhitespace();
}
}
return node;
}
return RemoveMethod;
}
public Func GetRenameClassAction(string newClassName)
{
TypeBlockSyntax RenameClass(SyntaxGenerator syntaxGenerator, TypeBlockSyntax node)
{
node = node.WithBlockStatement(node.BlockStatement.WithIdentifier(SyntaxFactory.Identifier(newClassName))).NormalizeWhitespace();
return node;
}
return RenameClass;
}
public Func GetReplaceMethodModifiersAction(string methodName, string modifiers)
{
TypeBlockSyntax ReplaceMethodModifiers(SyntaxGenerator syntaxGenerator, TypeBlockSyntax node)
{
var allMethods = node.Members.OfType();
if (allMethods.Any())
{
var replaceMethod = allMethods.FirstOrDefault(m => m.Identifier.ToString() == methodName);
if (replaceMethod != null)
{
var allModifiers = modifiers.Split(new char[] { ' ', ',' });
if (allModifiers.All(m => Constants.SupportedVbMethodModifiers.Contains(m)))
{
SyntaxTokenList tokenList = new SyntaxTokenList();
foreach (string m in allModifiers)
{
if (m == "Async")
{
// for some reason syntax factory can't parse that async is a keyword
tokenList = tokenList.Add(SyntaxFactory.Token(SyntaxKind.AsyncKeyword));
}
else
{
tokenList = tokenList.Add(SyntaxFactory.ParseToken(m));
}
}
var newMethod = replaceMethod.WithModifiers(tokenList);
node = node.WithMembers(node.Members.Replace(replaceMethod, newMethod)).NormalizeWhitespace();
}
}
}
return node;
}
return ReplaceMethodModifiers;
}
public Func GetAddExpressionAction(string expression)
{
TypeBlockSyntax AddExpression(SyntaxGenerator syntaxGenerator, TypeBlockSyntax node)
{
var parsedExpression = SyntaxFactory.ParseExecutableStatement(expression);
if (!parsedExpression.FullSpan.IsEmpty)
{
var nodeDeclarations = node.Members;
nodeDeclarations = nodeDeclarations.Insert(0, parsedExpression);
node = node.WithMembers(nodeDeclarations).NormalizeWhitespace();
}
return node;
}
return AddExpression;
}
public Func GetRemoveConstructorInitializerAction(string initializerArgument)
{
TypeBlockSyntax RemoveConstructorInitializer(SyntaxGenerator syntaxGenerator, TypeBlockSyntax node)
{
var constructor = node.ChildNodes().FirstOrDefault(c => c.IsKind(SyntaxKind.ConstructorBlock));
if (constructor != null)
{
var constructorNode = (ConstructorBlockSyntax)constructor;
var newArguments = new SeparatedSyntaxList();
// base initializers should be the first statement
var firstStatement = constructorNode.Statements.FirstOrDefault();
if (firstStatement != null &&
firstStatement.DescendantNodes().Any(s => s.IsKind(SyntaxKind.MyBaseExpression)))
{
var arguments = firstStatement.DescendantNodes().OfType()
.FirstOrDefault();
if (arguments != null)
{
foreach (var arg in arguments.Arguments)
{
if (!arg.GetText().ToString().Trim().Equals(initializerArgument))
{
newArguments = newArguments.Add(arg);
}
}
if (newArguments.Any())
{
node = node.ReplaceNode(arguments, SyntaxFactory.ArgumentList(newArguments))
.NormalizeWhitespace();
}
}
}
}
return node;
}
return RemoveConstructorInitializer;
}
public Func GetAppendConstructorExpressionAction(string expression)
{
TypeBlockSyntax AppendConstructorExpression(SyntaxGenerator syntaxGenerator, TypeBlockSyntax node)
{
var constructor = node.Members.FirstOrDefault(c => c.IsKind(SyntaxKind.ConstructorBlock));
if (constructor != null)
{
ConstructorBlockSyntax constructorNode = (ConstructorBlockSyntax)constructor;
StatementSyntax statementExpression = SyntaxFactory.ParseExecutableStatement(expression);
if (!statementExpression.FullSpan.IsEmpty)
{
constructorNode = constructorNode.AddStatements(statementExpression);
node = node.ReplaceNode(constructor, constructorNode).NormalizeWhitespace();
}
}
return node;
}
return AppendConstructorExpression;
}
public Func GetCreateConstructorAction(string types, string identifiers)
{
TypeBlockSyntax CreateConstructor(SyntaxGenerator syntaxGenerator, TypeBlockSyntax node)
{
// constructors in vb are just named new
var constructorStatementNode = SyntaxFactory.SubNewStatement()
.AddModifiers(SyntaxFactory.Token(SyntaxKind.PublicKeyword));
// Add constructor parameters if provided
if (!string.IsNullOrWhiteSpace(identifiers) && !string.IsNullOrWhiteSpace(types))
{
var identifiersArray = identifiers.Split(',', StringSplitOptions.RemoveEmptyEntries);
var typesArray = types.Split(',', StringSplitOptions.RemoveEmptyEntries);
if (identifiersArray.Length == typesArray.Length)
{
List parameters = new List();
for (var i = 0; i < identifiersArray.Length; i++)
{
parameters.Add(SyntaxFactory
.Parameter(SyntaxFactory.ModifiedIdentifier(identifiersArray[i]))
.WithAsClause(SyntaxFactory.SimpleAsClause(SyntaxFactory.ParseTypeName(typesArray[i]))));
}
constructorStatementNode = constructorStatementNode.AddParameterListParameters(parameters.ToArray());
}
}
var constructorBlock = SyntaxFactory.ConstructorBlock(constructorStatementNode);
node = node.AddMembers(constructorBlock).NormalizeWhitespace();
return node;
}
return CreateConstructor;
}
public Func GetChangeMethodNameAction(string existingMethodName, string newMethodName)
{
TypeBlockSyntax ChangeMethodName(SyntaxGenerator syntaxGenerator, TypeBlockSyntax node)
{
// if we have more than one method with same name return without making changes
var methodNode = GetMethodNode(node, existingMethodName);
if (methodNode != null)
{
var methodActions = new MethodBlockActions();
var changeMethodNameFunc = methodActions.GetChangeMethodNameAction(newMethodName);
var newMethodNode = changeMethodNameFunc(syntaxGenerator, methodNode);
node = node.ReplaceNode(methodNode, newMethodNode);
}
return node;
}
return ChangeMethodName;
}
public Func GetChangeMethodToReturnTaskTypeAction(string methodName)
{
TypeBlockSyntax ChangeMethodToReturnTaskType(SyntaxGenerator syntaxGenerator, TypeBlockSyntax node)
{
// if we have more than one method with same name return without making changes
var methodNode = GetMethodNode(node, methodName);
if (methodNode != null)
{
var methodActions = new MethodBlockActions();
var changeMethodToReturnTaskTypeActionFunc = methodActions.GetChangeMethodToReturnTaskTypeAction();
var newMethodNode = changeMethodToReturnTaskTypeActionFunc(syntaxGenerator, methodNode);
node = node.ReplaceNode(methodNode, newMethodNode);
}
return node;
}
return ChangeMethodToReturnTaskType;
}
public Func GetRemoveMethodParametersAction(string methodName)
{
TypeBlockSyntax RemoveMethodParameters(SyntaxGenerator syntaxGenerator, TypeBlockSyntax node)
{
// if we have more than one method with same name return without making changes
var methodNode = GetMethodNode(node, methodName);
if (methodNode != null)
{
var parameters = methodNode.SubOrFunctionStatement.ParameterList.Parameters;
MethodBlockActions methodActions = new MethodBlockActions();
var removeMethodParametersActionFunc = methodActions.GetRemoveMethodParametersAction();
var newMethodNode = removeMethodParametersActionFunc(syntaxGenerator, methodNode);
node = node.ReplaceNode(methodNode, newMethodNode);
}
return node.NormalizeWhitespace();
}
return RemoveMethodParameters;
}
public Func GetCommentMethodAction(string methodName, string comment = null, string dontUseCTAPrefix = null)
{
TypeBlockSyntax CommentMethod(SyntaxGenerator syntaxGenerator, TypeBlockSyntax node)
{
// if we have more than one method with same name return without making changes
var methodNode = GetMethodNode(node, methodName);
if (methodNode != null)
{
var methodActions = new MethodBlockActions();
var commentMethodAction = methodActions.GetCommentMethodAction(comment, dontUseCTAPrefix);
var newMethodNode = commentMethodAction(syntaxGenerator, methodNode);
var methodStatementComment =
SyntaxFactory.CommentTrivia($"' {newMethodNode.SubOrFunctionStatement.ToFullString()}");
var methodBodyComment = newMethodNode.EndSubOrFunctionStatement.GetLeadingTrivia();
var methodEndStatementComment =
SyntaxFactory.CommentTrivia($"' {newMethodNode.EndSubOrFunctionStatement.ToString()}");
var trivia = new SyntaxTriviaList();
trivia = trivia.Add(methodStatementComment);
trivia = trivia.AddRange(methodBodyComment);
trivia = trivia.Add(methodEndStatementComment);
node = node.RemoveNode(methodNode, SyntaxRemoveOptions.KeepNoTrivia);
node = node.WithEndBlockStatement(node.EndBlockStatement.WithLeadingTrivia(trivia));
}
return node.NormalizeWhitespace();
}
return CommentMethod;
}
public Func GetAddCommentsToMethodAction(string methodName, string comment, string dontUseCTAPrefix = null)
{
TypeBlockSyntax AddCommentsToMethod(SyntaxGenerator syntaxGenerator, TypeBlockSyntax node)
{
// if we have more than one method with same name return without making changes
var methodNode = GetMethodNode(node, methodName);
if (methodNode != null)
{
if (!string.IsNullOrWhiteSpace(comment))
{
var methodActions = new MethodBlockActions();
var addCommentActionFunc = methodActions.GetAddCommentAction(comment, dontUseCTAPrefix);
var newMethodNode = addCommentActionFunc(syntaxGenerator, methodNode);
node = node.ReplaceNode(methodNode, newMethodNode);
}
}
return node;
}
return AddCommentsToMethod;
}
public Func GetAddExpressionToMethodAction(string methodName, string expression)
{
TypeBlockSyntax AddExpressionToMethod(SyntaxGenerator syntaxGenerator, TypeBlockSyntax node)
{
// if we have more than one method with same name return without making changes
var methodNode = GetMethodNode(node, methodName);
if (methodNode != null)
{
var methodActions = new MethodBlockActions();
var addExpressionToMethodAction = methodActions.GetAddExpressionToMethodAction(expression);
var newMethodNode = addExpressionToMethodAction(syntaxGenerator, methodNode);
node = node.ReplaceNode(methodNode, newMethodNode);
}
return node;
}
return AddExpressionToMethod;
}
public Func GetAddParametersToMethodAction(string methodName, string types, string identifiers)
{
TypeBlockSyntax AddParametersToMethod(SyntaxGenerator syntaxGenerator, TypeBlockSyntax node)
{
// if we have more than one method with same name return without making changes
var methodNode = GetMethodNode(node, methodName);
if (methodNode != null)
{
var methodActions = new MethodBlockActions();
var addParametersToMethodAction = methodActions.GetAddParametersToMethodAction(types, identifiers);
var newMethodNode = addParametersToMethodAction(syntaxGenerator, methodNode);
node = node.ReplaceNode(methodNode, newMethodNode).NormalizeWhitespace();
}
return node;
}
return AddParametersToMethod;
}
public Func GetReplaceMvcControllerMethodsBodyAction(string expression)
{
TypeBlockSyntax ReplaceMvcControllerMethodsBodyFunc(SyntaxGenerator syntaxGenerator, TypeBlockSyntax node)
{
return node;
}
return ReplaceMvcControllerMethodsBodyFunc;
}
public Func GetReplaceWebApiControllerMethodsBodyAction(string expression)
{
TypeBlockSyntax ReplaceMethodBodyFunc(SyntaxGenerator syntaxGenerator, TypeBlockSyntax node)
{
return AddCommentToPublicMethods(node, expression);
}
return ReplaceMethodBodyFunc;
}
public Func GetReplaceCoreControllerMethodsBodyAction(string expression)
{
TypeBlockSyntax ReplaceCoreControllerMethodsBody(SyntaxGenerator syntaxGenerator, TypeBlockSyntax node)
{
return AddCommentToPublicMethods(node, expression);
}
return ReplaceCoreControllerMethodsBody;
}
private TypeBlockSyntax AddCommentToPublicMethods(TypeBlockSyntax node, string expression)
{
var comment = string.Format(Constants.VbCommentFormat, $"Replace method body with {expression}");
var allMembers = node.Members.ToList();
var allMethods = allMembers.OfType()
.Where(m => m.SubOrFunctionStatement.Modifiers.Any(mod => mod.IsKind(SyntaxKind.PublicKeyword)))
.Select(mb => GetMethodId(mb.SubOrFunctionStatement)).ToList();
foreach (var method in allMethods)
{
var currentMethodStatement = node.DescendantNodes().OfType()
.FirstOrDefault(m => GetMethodId(m) == method);
var originalMethod = currentMethodStatement;
if (currentMethodStatement != null)
{
var trivia = currentMethodStatement.GetLeadingTrivia();
trivia = trivia.Add(SyntaxFactory.SyntaxTrivia(SyntaxKind.CommentTrivia, comment));
currentMethodStatement = currentMethodStatement.WithLeadingTrivia(trivia).NormalizeWhitespace();
node = node.ReplaceNode(originalMethod, currentMethodStatement);
}
}
return node;
}
private string GetMethodId(MethodStatementSyntax method)
{
return $"{method.Identifier}{method.ParameterList}";
}
private MethodBlockSyntax GetMethodNode(TypeBlockSyntax node, string methodName)
{
var methodNodeList = node.DescendantNodes().OfType()
.Where(method => method.SubOrFunctionStatement.Identifier.Text == methodName);
if (methodNodeList != null && methodNodeList.Count() > 1)
{
return null;
}
return methodNodeList.FirstOrDefault();
}
}
}