From ee29047cd103e040992600327c5c9985e1f1a244 Mon Sep 17 00:00:00 2001 From: UNV Date: Thu, 6 Nov 2025 01:09:53 +0300 Subject: [PATCH 1/2] Replacing deprecated AllIcons. Refactoring. --- .../console/pydev/PyCodeCompletionImages.java | 21 +- .../com/jetbrains/pyqt/CompileQrcAction.java | 2 +- .../python/impl/PyIconDescriptorUpdater.java | 81 +- .../codeInsight/PyLineMarkerProvider.java | 437 +- .../PyDictKeyNamesCompletionContributor.java | 360 +- .../PyParameterCompletionContributor.java | 61 +- .../PyUserSkeletonsLineMarkerProvider.java | 84 +- .../doctest/PyDocstringLanguageInjector.java | 150 +- .../impl/packaging/PyPackageManagerUI.java | 601 +-- .../com/jetbrains/python/impl/psi/PyUtil.java | 3755 +++++++++-------- .../python/impl/psi/impl/PyBoundFunction.java | 9 +- .../python/impl/psi/impl/PyFileImpl.java | 1428 ++++--- .../python/impl/psi/impl/PyFunctionImpl.java | 1504 +++---- .../impl/PyStringLiteralExpressionImpl.java | 905 ++-- .../impl/references/PyQualifiedReference.java | 247 +- .../psi/impl/references/PyReferenceImpl.java | 1429 +++---- .../impl/psi/types/PyStructuralType.java | 97 +- .../classes/ui/PyMemberSelectionTable.java | 65 +- .../impl/structureView/PyFieldsFilter.java | 53 +- .../PyInheritedMembersFilter.java | 76 +- .../python/codeInsight/PyCustomMember.java | 461 +- .../java/com/jetbrains/python/psi/PyFile.java | 178 +- 22 files changed, 6086 insertions(+), 5918 deletions(-) diff --git a/python-debugger/src/main/java/com/jetbrains/python/console/pydev/PyCodeCompletionImages.java b/python-debugger/src/main/java/com/jetbrains/python/console/pydev/PyCodeCompletionImages.java index f18f1b90..ebd11f2b 100644 --- a/python-debugger/src/main/java/com/jetbrains/python/console/pydev/PyCodeCompletionImages.java +++ b/python-debugger/src/main/java/com/jetbrains/python/console/pydev/PyCodeCompletionImages.java @@ -1,27 +1,20 @@ package com.jetbrains.python.console.pydev; -import consulo.application.AllIcons; +import consulo.platform.base.icon.PlatformIconGroup; import consulo.ui.image.Image; import jakarta.annotation.Nullable; public class PyCodeCompletionImages { - /** * Returns an image for the given type - * @param type - * @return */ @Nullable - public static Image getImageForType(int type){ - switch (type) { - case IToken.TYPE_CLASS: - return AllIcons.Nodes.Class; - case IToken.TYPE_FUNCTION: - return AllIcons.Nodes.Method; - default: - return null; - } + public static Image getImageForType(int type) { + return switch (type) { + case IToken.TYPE_CLASS -> PlatformIconGroup.nodesClass(); + case IToken.TYPE_FUNCTION -> PlatformIconGroup.nodesMethod(); + default -> null; + }; } - } diff --git a/python-impl/src/main/java/com/jetbrains/pyqt/CompileQrcAction.java b/python-impl/src/main/java/com/jetbrains/pyqt/CompileQrcAction.java index 882310be..eac3acf3 100644 --- a/python-impl/src/main/java/com/jetbrains/pyqt/CompileQrcAction.java +++ b/python-impl/src/main/java/com/jetbrains/pyqt/CompileQrcAction.java @@ -44,7 +44,7 @@ public class CompileQrcAction extends AnAction { public void actionPerformed(AnActionEvent e) { Project project = e.getData(Project.KEY); VirtualFile[] vFiles = e.getRequiredData(VirtualFile.KEY_OF_ARRAY); - Module module = e.getData(Module.KEY); + Module module = e.getRequiredData(Module.KEY); String path = QtFileType.findQtTool(module, "pyrcc4"); if (path == null) { path = QtFileType.findQtTool(module, "pyside-rcc"); diff --git a/python-impl/src/main/java/com/jetbrains/python/impl/PyIconDescriptorUpdater.java b/python-impl/src/main/java/com/jetbrains/python/impl/PyIconDescriptorUpdater.java index 0d2a1c29..14d4127c 100644 --- a/python-impl/src/main/java/com/jetbrains/python/impl/PyIconDescriptorUpdater.java +++ b/python-impl/src/main/java/com/jetbrains/python/impl/PyIconDescriptorUpdater.java @@ -13,69 +13,70 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.jetbrains.python.impl; import com.jetbrains.python.PyNames; import com.jetbrains.python.psi.Property; import com.jetbrains.python.psi.PyClass; import com.jetbrains.python.psi.PyFunction; +import consulo.annotation.access.RequiredReadAction; import consulo.annotation.component.ExtensionImpl; -import consulo.application.AllIcons; import consulo.language.icon.IconDescriptor; import consulo.language.icon.IconDescriptorUpdater; import consulo.language.psi.PsiDirectory; import consulo.language.psi.PsiElement; import consulo.module.content.ProjectRootManager; +import consulo.platform.base.icon.PlatformIconGroup; +import consulo.python.impl.icon.PythonImplIconGroup; import consulo.ui.image.Image; -import consulo.util.lang.Comparing; import consulo.virtualFileSystem.VirtualFile; - import jakarta.annotation.Nonnull; +import java.util.Objects; + /** * @author yole */ @ExtensionImpl public class PyIconDescriptorUpdater implements IconDescriptorUpdater { - @Override - public void updateIcon(@Nonnull IconDescriptor iconDescriptor, @Nonnull PsiElement element, int i) { - if (element instanceof PsiDirectory) { - final PsiDirectory directory = (PsiDirectory)element; - if (directory.findFile(PyNames.INIT_DOT_PY) != null) { - final VirtualFile vFile = directory.getVirtualFile(); - final VirtualFile root = ProjectRootManager.getInstance(directory.getProject()).getFileIndex().getSourceRootForFile(vFile); - if (!Comparing.equal(root, vFile)) { - iconDescriptor.setMainIcon(AllIcons.Nodes.Package); - } - } - } - else if (element instanceof PyClass) { - iconDescriptor.setMainIcon(AllIcons.Nodes.Class); - } - else if (element instanceof PyFunction) { - Image icon = null; - final Property property = ((PyFunction)element).getProperty(); - if (property != null) { - if (property.getGetter().valueOrNull() == this) { - icon = PythonIcons.Python.PropertyGetter; - } - else if (property.getSetter().valueOrNull() == this) { - icon = PythonIcons.Python.PropertySetter; + @Override + @RequiredReadAction + public void updateIcon(@Nonnull IconDescriptor iconDescriptor, @Nonnull PsiElement element, int i) { + if (element instanceof PsiDirectory directory) { + if (directory.findFile(PyNames.INIT_DOT_PY) != null) { + VirtualFile vFile = directory.getVirtualFile(); + VirtualFile root = ProjectRootManager.getInstance(directory.getProject()).getFileIndex().getSourceRootForFile(vFile); + if (!Objects.equals(root, vFile)) { + iconDescriptor.setMainIcon(PlatformIconGroup.nodesPackage()); + } + } } - else if (property.getDeleter().valueOrNull() == this) { - icon = PythonIcons.Python.PropertyDeleter; + else if (element instanceof PyClass) { + iconDescriptor.setMainIcon(PlatformIconGroup.nodesClass()); } - else { - icon = AllIcons.Nodes.Property; + else if (element instanceof PyFunction function) { + Image icon = null; + Property property = function.getProperty(); + if (property != null) { + if (property.getGetter().valueOrNull() == this) { + icon = PythonImplIconGroup.pythonPropertygetter(); + } + else if (property.getSetter().valueOrNull() == this) { + icon = PythonImplIconGroup.pythonPropertysetter(); + } + else if (property.getDeleter().valueOrNull() == this) { + icon = PythonImplIconGroup.pythonPropertydeleter(); + } + else { + icon = PlatformIconGroup.nodesProperty(); + } + } + if (icon != null) { + iconDescriptor.setMainIcon(icon); + } + else { + iconDescriptor.setMainIcon(PlatformIconGroup.nodesMethod()); + } } - } - if (icon != null) { - iconDescriptor.setMainIcon(icon); - } - else { - iconDescriptor.setMainIcon(AllIcons.Nodes.Method); - } } - } } diff --git a/python-impl/src/main/java/com/jetbrains/python/impl/codeInsight/PyLineMarkerProvider.java b/python-impl/src/main/java/com/jetbrains/python/impl/codeInsight/PyLineMarkerProvider.java index 11678a16..bf7946b8 100644 --- a/python-impl/src/main/java/com/jetbrains/python/impl/codeInsight/PyLineMarkerProvider.java +++ b/python-impl/src/main/java/com/jetbrains/python/impl/codeInsight/PyLineMarkerProvider.java @@ -26,8 +26,8 @@ import com.jetbrains.python.impl.psi.search.PyOverridingMethodsSearch; import com.jetbrains.python.impl.psi.search.PySuperMethodsSearch; import com.jetbrains.python.psi.types.TypeEvalContext; +import consulo.annotation.access.RequiredReadAction; import consulo.annotation.component.ExtensionImpl; -import consulo.application.AllIcons; import consulo.application.util.query.CollectionQuery; import consulo.application.util.query.Query; import consulo.language.Language; @@ -38,10 +38,12 @@ import consulo.language.editor.gutter.LineMarkerProvider; import consulo.language.psi.PsiElement; import consulo.language.psi.util.PsiTreeUtil; +import consulo.platform.base.icon.PlatformIconGroup; import consulo.util.collection.MultiMap; import jakarta.annotation.Nonnull; import jakarta.annotation.Nullable; + import java.util.*; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; @@ -52,238 +54,261 @@ @ExtensionImpl public class PyLineMarkerProvider implements LineMarkerProvider, PyLineSeparatorUtil.Provider { - @Nonnull - @Override - public Language getLanguage() { - return PythonLanguage.INSTANCE; - } + @Nonnull + @Override + public Language getLanguage() { + return PythonLanguage.INSTANCE; + } - private static class TooltipProvider implements Function { - private final String myText; + private static class TooltipProvider implements Function { + private final String myText; - private TooltipProvider(String text) { - myText = text; - } + private TooltipProvider(String text) { + myText = text; + } - public String apply(PsiElement psiElement) { - return myText; + @Override + public String apply(PsiElement psiElement) { + return myText; + } } - } - private static final Function ourSubclassTooltipProvider = pyClass -> { - final StringBuilder builder = new StringBuilder("Is subclassed by:"); - final AtomicInteger count = new AtomicInteger(); - PyClassInheritorsSearch.search(pyClass, true).forEach(pyClass1 -> { - if (count.incrementAndGet() >= 10) { - builder.setLength(0); - builder.append("Has subclasses"); - return false; - } - builder.append("
  ").append(pyClass1.getName()); - return true; - }); - return builder.toString(); - }; + private static final Function ourSubclassTooltipProvider = pyClass -> { + StringBuilder builder = new StringBuilder("Is subclassed by:"); + AtomicInteger count = new AtomicInteger(); + PyClassInheritorsSearch.search(pyClass, true).forEach(pyClass1 -> { + if (count.incrementAndGet() >= 10) { + builder.setLength(0); + builder.append("Has subclasses"); + return false; + } + builder.append("
  ").append(pyClass1.getName()); + return true; + }); + return builder.toString(); + }; - private static final Function ourOverridingMethodTooltipProvider = pyFunction -> { - final StringBuilder builder = new StringBuilder("Is overridden in:"); - final AtomicInteger count = new AtomicInteger(); - PyClassInheritorsSearch.search(pyFunction.getContainingClass(), true).forEach(pyClass -> { - if (count.incrementAndGet() >= 10) { - builder.setLength(0); - builder.append("Has overridden methods"); - return false; - } - if (pyClass.findMethodByName(pyFunction.getName(), false, null) != null) { - builder.append("
  ").append(pyClass.getName()); - } - return true; - }); - return builder.toString(); - }; + private static final Function ourOverridingMethodTooltipProvider = pyFunction -> { + StringBuilder builder = new StringBuilder("Is overridden in:"); + AtomicInteger count = new AtomicInteger(); + PyClassInheritorsSearch.search(pyFunction.getContainingClass(), true).forEach(pyClass -> { + if (count.incrementAndGet() >= 10) { + builder.setLength(0); + builder.append("Has overridden methods"); + return false; + } + if (pyClass.findMethodByName(pyFunction.getName(), false, null) != null) { + builder.append("
  ").append(pyClass.getName()); + } + return true; + }); + return builder.toString(); + }; - private static final PyLineMarkerNavigator ourSuperMethodNavigator = new PyLineMarkerNavigator() { - protected String getTitle(final PsiElement elt) { - return "Choose Super Method of " + ((PyFunction)elt.getParent()).getName(); - } + private static final PyLineMarkerNavigator ourSuperMethodNavigator = new PyLineMarkerNavigator<>() { + @Override + @RequiredReadAction + protected String getTitle(PsiElement elt) { + return "Choose Super Method of " + ((PyFunction) elt.getParent()).getName(); + } - @Nullable - protected Query search(final PsiElement elt, @Nonnull final TypeEvalContext context) { - if (!(elt.getParent() instanceof PyFunction)) { - return null; - } - return PySuperMethodsSearch.search((PyFunction)elt.getParent(), context); - } - }; + @Nullable + @Override + protected Query search(PsiElement elt, @Nonnull TypeEvalContext context) { + return elt.getParent() instanceof PyFunction function ? PySuperMethodsSearch.search(function, context) : null; + } + }; - private static final PyLineMarkerNavigator ourSuperAttributeNavigator = new PyLineMarkerNavigator() { - protected String getTitle(final PsiElement elt) { - return "Choose Super Attribute of " + ((PyTargetExpression)elt).getName(); - } + private static final PyLineMarkerNavigator ourSuperAttributeNavigator = new PyLineMarkerNavigator<>() { + @Override + protected String getTitle(PsiElement elt) { + return "Choose Super Attribute of " + ((PyTargetExpression) elt).getName(); + } - @Nullable - protected Query search(final PsiElement elt, @Nonnull final TypeEvalContext context) { - List result = new ArrayList(); - PyClass containingClass = PsiTreeUtil.getParentOfType(elt, PyClass.class); - if (containingClass != null && elt instanceof PyTargetExpression) { - for (PyClass ancestor : containingClass.getAncestorClasses(context)) { - final PyTargetExpression attribute = ancestor.findClassAttribute(((PyTargetExpression)elt).getReferencedName(), false, context); - if (attribute != null) { - result.add(attribute); - } + @Nonnull + @Override + @RequiredReadAction + protected Query search(PsiElement elt, @Nonnull TypeEvalContext context) { + List result = new ArrayList<>(); + PyClass containingClass = PsiTreeUtil.getParentOfType(elt, PyClass.class); + if (containingClass != null && elt instanceof PyTargetExpression targetExpr) { + for (PyClass ancestor : containingClass.getAncestorClasses(context)) { + PyTargetExpression attribute = ancestor.findClassAttribute(targetExpr.getReferencedName(), false, context); + if (attribute != null) { + result.add(attribute); + } + } + } + return new CollectionQuery<>(result); } - } - return new CollectionQuery(result); - } - }; + }; - private static final PyLineMarkerNavigator ourSubclassNavigator = new PyLineMarkerNavigator() { - protected String getTitle(final PyClass elt) { - return "Choose Subclass of " + elt.getName(); - } + private static final PyLineMarkerNavigator ourSubclassNavigator = new PyLineMarkerNavigator<>() { + @Override + @RequiredReadAction + protected String getTitle(PyClass elt) { + return "Choose Subclass of " + elt.getName(); + } - protected Query search(final PyClass elt, @Nonnull TypeEvalContext context) { - return PyClassInheritorsSearch.search(elt, true); - } - }; + @Override + protected Query search(PyClass elt, @Nonnull TypeEvalContext context) { + return PyClassInheritorsSearch.search(elt, true); + } + }; - private static final PyLineMarkerNavigator ourOverridingMethodNavigator = new PyLineMarkerNavigator() { - protected String getTitle(final PyFunction elt) { - return "Choose Overriding Method of " + elt.getName(); - } + private static final PyLineMarkerNavigator ourOverridingMethodNavigator = new PyLineMarkerNavigator<>() { + @Override + @RequiredReadAction + protected String getTitle(PyFunction elt) { + return "Choose Overriding Method of " + elt.getName(); + } - protected Query search(final PyFunction elt, @Nonnull TypeEvalContext context) { - return PyOverridingMethodsSearch.search(elt, true); - } - }; + @Override + protected Query search(PyFunction elt, @Nonnull TypeEvalContext context) { + return PyOverridingMethodsSearch.search(elt, true); + } + }; - public LineMarkerInfo getLineMarkerInfo(@Nonnull final PsiElement element) { - final ASTNode node = element.getNode(); - if (node != null && node.getElementType() == PyTokenTypes.IDENTIFIER && element.getParent() instanceof PyFunction) { - final PyFunction function = (PyFunction)element.getParent(); - return getMethodMarker(element, function); - } - if (element instanceof PyTargetExpression && PyUtil.isClassAttribute(element)) { - return getAttributeMarker((PyTargetExpression)element); - } - if (DaemonCodeAnalyzerSettings.getInstance().SHOW_METHOD_SEPARATORS && isSeparatorAllowed(element)) { - return PyLineSeparatorUtil.addLineSeparatorIfNeeded(this, element); + @Override + @RequiredReadAction + public LineMarkerInfo getLineMarkerInfo(@Nonnull PsiElement element) { + ASTNode node = element.getNode(); + if (node != null && node.getElementType() == PyTokenTypes.IDENTIFIER && element.getParent() instanceof PyFunction function) { + return getMethodMarker(element, function); + } + if (element instanceof PyTargetExpression targetExpr && PyUtil.isClassAttribute(element)) { + return getAttributeMarker(targetExpr); + } + if (DaemonCodeAnalyzerSettings.getInstance().SHOW_METHOD_SEPARATORS && isSeparatorAllowed(element)) { + return PyLineSeparatorUtil.addLineSeparatorIfNeeded(this, element); + } + return null; } - return null; - } - - public boolean isSeparatorAllowed(PsiElement element) { - return element instanceof PyFunction || element instanceof PyClass; - } - @Nullable - private static LineMarkerInfo getMethodMarker(final PsiElement element, final PyFunction function) { - if (PyNames.INIT.equals(function.getName())) { - return null; + @Override + public boolean isSeparatorAllowed(PsiElement element) { + return element instanceof PyFunction || element instanceof PyClass; } - final TypeEvalContext context = - TypeEvalContext.codeAnalysis(element.getProject(), (function != null ? function.getContainingFile() : null)); - final PsiElement superMethod = PySuperMethodsSearch.search(function, context).findFirst(); - if (superMethod != null) { - PyClass superClass = null; - if (superMethod instanceof PyFunction) { - superClass = ((PyFunction)superMethod).getContainingClass(); - } - // TODO: show "implementing" instead of "overriding" icon for Python implementations of Java interface methods - return new LineMarkerInfo(element, - element.getTextRange().getStartOffset(), - AllIcons.Gutter.OverridingMethod, - Pass.LINE_MARKERS, - superClass == null ? null : new - TooltipProvider("Overrides method in " + superClass.getName()), - ourSuperMethodNavigator); - } - return null; - } - @Nullable - private static LineMarkerInfo getAttributeMarker(PyTargetExpression element) { - final String name = element.getReferencedName(); - if (name == null) { - return null; - } - PyClass containingClass = PsiTreeUtil.getParentOfType(element, PyClass.class); - if (containingClass == null) { - return null; - } - for (PyClass ancestor : containingClass.getAncestorClasses(TypeEvalContext.codeAnalysis(element.getProject(), - element.getContainingFile()))) { - final PyTargetExpression ancestorAttr = ancestor.findClassAttribute(name, false, null); - if (ancestorAttr != null) { - return new LineMarkerInfo(element, - element.getTextRange().getStartOffset(), - AllIcons.Gutter.OverridingMethod, - Pass.LINE_MARKERS, - new TooltipProvider("Overrides attribute " + - "in " + ancestor.getName()), - ourSuperAttributeNavigator); - } + @Nullable + @RequiredReadAction + private static LineMarkerInfo getMethodMarker(PsiElement element, PyFunction function) { + if (PyNames.INIT.equals(function.getName())) { + return null; + } + TypeEvalContext context = TypeEvalContext.codeAnalysis(element.getProject(), function.getContainingFile()); + PsiElement superMethod = PySuperMethodsSearch.search(function, context).findFirst(); + if (superMethod != null) { + PyClass superClass = null; + if (superMethod instanceof PyFunction superFunction) { + superClass = superFunction.getContainingClass(); + } + // TODO: show "implementing" instead of "overriding" icon for Python implementations of Java interface methods + return new LineMarkerInfo<>( + element, + element.getTextRange().getStartOffset(), + PlatformIconGroup.gutterOverridingmethod(), + Pass.LINE_MARKERS, + superClass == null ? null : new TooltipProvider("Overrides method in " + superClass.getName()), + ourSuperMethodNavigator + ); + } + return null; } - return null; - } - public void collectSlowLineMarkers(@Nonnull final List elements, @Nonnull final Collection result) { - Set functions = new HashSet(); - for (PsiElement element : elements) { - if (element instanceof PyClass) { - collectInheritingClasses((PyClass)element, result); - } - else if (element instanceof PyFunction) { - functions.add((PyFunction)element); - } + @Nullable + @RequiredReadAction + private static LineMarkerInfo getAttributeMarker(PyTargetExpression element) { + String name = element.getReferencedName(); + if (name == null) { + return null; + } + PyClass containingClass = PsiTreeUtil.getParentOfType(element, PyClass.class); + if (containingClass == null) { + return null; + } + List ancestors = + containingClass.getAncestorClasses(TypeEvalContext.codeAnalysis(element.getProject(), element.getContainingFile())); + for (PyClass ancestor : ancestors) { + PyTargetExpression ancestorAttr = ancestor.findClassAttribute(name, false, null); + if (ancestorAttr != null) { + return new LineMarkerInfo<>( + element, + element.getTextRange().getStartOffset(), + PlatformIconGroup.gutterOverridingmethod(), + Pass.LINE_MARKERS, + new TooltipProvider("Overrides attribute in " + ancestor.getName()), + ourSuperAttributeNavigator + ); + } + } + return null; } - collectOverridingMethods(functions, result); - } - private static void collectInheritingClasses(final PyClass element, final Collection result) { - if (PyClassInheritorsSearch.search(element, false).findFirst() != null) { - result.add(new LineMarkerInfo(element, - element.getTextOffset(), - AllIcons.Gutter.OverridenMethod, - Pass.LINE_MARKERS, - ourSubclassTooltipProvider, - ourSubclassNavigator)); + @Override + @RequiredReadAction + public void collectSlowLineMarkers(@Nonnull List elements, @Nonnull Collection result) { + Set functions = new HashSet<>(); + for (PsiElement element : elements) { + if (element instanceof PyClass pyClass) { + collectInheritingClasses(pyClass, result); + } + else if (element instanceof PyFunction function) { + functions.add(function); + } + } + collectOverridingMethods(functions, result); } - } - private static void collectOverridingMethods(final Set functions, final Collection result) { - Set classes = new HashSet(); - final MultiMap candidates = new MultiMap(); - for (PyFunction function : functions) { - PyClass pyClass = function.getContainingClass(); - if (pyClass != null && function.getName() != null) { - classes.add(pyClass); - candidates.putValue(pyClass, function); - } - } - final Set overridden = new HashSet(); - for (final PyClass pyClass : classes) { - PyClassInheritorsSearch.search(pyClass, true).forEach(inheritor -> { - for (Iterator it = candidates.get(pyClass).iterator(); it.hasNext(); ) { - PyFunction func = it.next(); - if (inheritor.findMethodByName(func.getName(), false, null) != null) { - overridden.add(func); - it.remove(); - } + private static void collectInheritingClasses(PyClass element, Collection result) { + if (PyClassInheritorsSearch.search(element, false).findFirst() != null) { + result.add(new LineMarkerInfo<>( + element, + element.getTextOffset(), + PlatformIconGroup.gutterOverridenmethod(), + Pass.LINE_MARKERS, + ourSubclassTooltipProvider, + ourSubclassNavigator + )); } - return !candidates.isEmpty(); - }); - if (candidates.isEmpty()) { - break; - } } - for (PyFunction func : overridden) { - result.add(new LineMarkerInfo(func, - func.getTextOffset(), - AllIcons.Gutter.OverridenMethod, - Pass.LINE_MARKERS, - ourOverridingMethodTooltipProvider, - ourOverridingMethodNavigator)); + + @RequiredReadAction + private static void collectOverridingMethods(Set functions, Collection result) { + Set classes = new HashSet<>(); + MultiMap candidates = new MultiMap<>(); + for (PyFunction function : functions) { + PyClass pyClass = function.getContainingClass(); + if (pyClass != null && function.getName() != null) { + classes.add(pyClass); + candidates.putValue(pyClass, function); + } + } + Set overridden = new HashSet<>(); + for (PyClass pyClass : classes) { + PyClassInheritorsSearch.search(pyClass, true).forEach(inheritor -> { + for (Iterator it = candidates.get(pyClass).iterator(); it.hasNext(); ) { + PyFunction func = it.next(); + if (inheritor.findMethodByName(func.getName(), false, null) != null) { + overridden.add(func); + it.remove(); + } + } + return !candidates.isEmpty(); + }); + if (candidates.isEmpty()) { + break; + } + } + for (PyFunction func : overridden) { + result.add(new LineMarkerInfo<>( + func, + func.getTextOffset(), + PlatformIconGroup.gutterOverridenmethod(), + Pass.LINE_MARKERS, + ourOverridingMethodTooltipProvider, + ourOverridingMethodNavigator + )); + } } - } } diff --git a/python-impl/src/main/java/com/jetbrains/python/impl/codeInsight/completion/PyDictKeyNamesCompletionContributor.java b/python-impl/src/main/java/com/jetbrains/python/impl/codeInsight/completion/PyDictKeyNamesCompletionContributor.java index 6cfa3dbd..5b4034c6 100644 --- a/python-impl/src/main/java/com/jetbrains/python/impl/codeInsight/completion/PyDictKeyNamesCompletionContributor.java +++ b/python-impl/src/main/java/com/jetbrains/python/impl/codeInsight/completion/PyDictKeyNamesCompletionContributor.java @@ -21,230 +21,224 @@ import com.jetbrains.python.psi.*; import com.jetbrains.python.psi.types.PyType; import com.jetbrains.python.psi.types.TypeEvalContext; +import consulo.annotation.access.RequiredReadAction; import consulo.annotation.component.ExtensionImpl; -import consulo.application.AllIcons; import consulo.document.Document; import consulo.document.util.TextRange; import consulo.language.Language; import consulo.language.ast.ASTNode; -import consulo.language.editor.completion.*; -import consulo.language.editor.completion.lookup.InsertHandler; -import consulo.language.editor.completion.lookup.InsertionContext; -import consulo.language.editor.completion.lookup.LookupElement; +import consulo.language.editor.completion.CompletionContributor; +import consulo.language.editor.completion.CompletionResultSet; +import consulo.language.editor.completion.CompletionType; import consulo.language.editor.completion.lookup.LookupElementBuilder; import consulo.language.psi.PsiElement; import consulo.language.psi.PsiFile; import consulo.language.psi.PsiReference; import consulo.language.psi.util.PsiTreeUtil; -import consulo.language.util.ProcessingContext; +import consulo.platform.base.icon.PlatformIconGroup; import consulo.util.lang.StringUtil; - import jakarta.annotation.Nonnull; import static consulo.language.pattern.PlatformPatterns.psiElement; /** - * User: catherine - *

* Complete known keys for dictionaries + * + * @author catherine */ @ExtensionImpl public class PyDictKeyNamesCompletionContributor extends CompletionContributor { - public PyDictKeyNamesCompletionContributor() { - extend(CompletionType.BASIC, psiElement().inside(PySubscriptionExpression.class), new CompletionProvider() { - @Override - public void addCompletions(@Nonnull final CompletionParameters parameters, - final ProcessingContext context, - @Nonnull final CompletionResultSet result) { - final PsiElement original = parameters.getOriginalPosition(); - final int offset = parameters.getOffset(); - if (original == null) { - return; - } - final CompletionResultSet dictCompletion = createResult(original, result, offset); + public PyDictKeyNamesCompletionContributor() { + extend( + CompletionType.BASIC, + psiElement().inside(PySubscriptionExpression.class), + (parameters, context, result) -> { + PsiElement original = parameters.getOriginalPosition(); + int offset = parameters.getOffset(); + if (original == null) { + return; + } + CompletionResultSet dictCompletion = createResult(original, result, offset); - PySubscriptionExpression subscription = PsiTreeUtil.getParentOfType(original, PySubscriptionExpression.class); - if (subscription == null) { - return; - } - PsiElement operand = subscription.getOperand(); - PsiReference reference = operand.getReference(); - if (reference != null) { - PsiElement resolvedElement = reference.resolve(); - if (resolvedElement instanceof PyTargetExpression) { - PyDictLiteralExpression dict = PsiTreeUtil.getNextSiblingOfType(resolvedElement, PyDictLiteralExpression.class); - if (dict != null) { - addDictLiteralKeys(dict, dictCompletion); - PsiFile file = parameters.getOriginalFile(); - addAdditionalKeys(file, operand, dictCompletion); + PySubscriptionExpression subscription = PsiTreeUtil.getParentOfType(original, PySubscriptionExpression.class); + if (subscription == null) { + return; + } + PsiElement operand = subscription.getOperand(); + PsiReference reference = operand.getReference(); + if (reference != null && reference.resolve() instanceof PyTargetExpression targetExpr) { + PyDictLiteralExpression dict = PsiTreeUtil.getNextSiblingOfType(targetExpr, PyDictLiteralExpression.class); + if (dict != null) { + addDictLiteralKeys(dict, dictCompletion); + PsiFile file = parameters.getOriginalFile(); + addAdditionalKeys(file, operand, dictCompletion); + } + PyCallExpression dictConstructor = PsiTreeUtil.getNextSiblingOfType(targetExpr, PyCallExpression.class); + if (dictConstructor != null) { + addDictConstructorKeys(dictConstructor, dictCompletion); + PsiFile file = parameters.getOriginalFile(); + addAdditionalKeys(file, operand, dictCompletion); + } + } } - PyCallExpression dictConstructor = PsiTreeUtil.getNextSiblingOfType(resolvedElement, PyCallExpression.class); - if (dictConstructor != null) { - addDictConstructorKeys(dictConstructor, dictCompletion); - PsiFile file = parameters.getOriginalFile(); - addAdditionalKeys(file, operand, dictCompletion); + ); + } + + /** + * create completion result with prefix matcher if needed + * + * @param original is original element + * @param result is initial completion result + * @param offset + * @return + */ + @RequiredReadAction + private static CompletionResultSet createResult(@Nonnull PsiElement original, @Nonnull CompletionResultSet result, int offset) { + PyStringLiteralExpression prevElement = PsiTreeUtil.getPrevSiblingOfType(original, PyStringLiteralExpression.class); + if (prevElement != null) { + ASTNode prevNode = prevElement.getNode(); + if (prevNode != null) { + if (prevNode.getElementType() != PyTokenTypes.LBRACKET) { + return result.withPrefixMatcher(findPrefix(prevElement, offset)); + } } - } } - } - }); - } - - /** - * create completion result with prefix matcher if needed - * - * @param original is original element - * @param result is initial completion result - * @param offset - * @return - */ - private static CompletionResultSet createResult(@Nonnull final PsiElement original, - @Nonnull final CompletionResultSet result, - final int offset) { - PyStringLiteralExpression prevElement = PsiTreeUtil.getPrevSiblingOfType(original, PyStringLiteralExpression.class); - if (prevElement != null) { - ASTNode prevNode = prevElement.getNode(); - if (prevNode != null) { - if (prevNode.getElementType() != PyTokenTypes.LBRACKET) { - return result.withPrefixMatcher(findPrefix(prevElement, offset)); + if (original.getParent() instanceof PyStringLiteralExpression stringLiteralExpr) { + return result.withPrefixMatcher(findPrefix(stringLiteralExpr, offset)); } - } - } - final PsiElement parentElement = original.getParent(); - if (parentElement != null) { - if (parentElement instanceof PyStringLiteralExpression) { - return result.withPrefixMatcher(findPrefix((PyElement)parentElement, offset)); - } - } - final PyNumericLiteralExpression number = - PsiTreeUtil.findElementOfClassAtOffset(original.getContainingFile(), offset - 1, PyNumericLiteralExpression.class, false); - if (number != null) { - return result.withPrefixMatcher(findPrefix(number, offset)); + PyNumericLiteralExpression number = + PsiTreeUtil.findElementOfClassAtOffset(original.getContainingFile(), offset - 1, PyNumericLiteralExpression.class, false); + if (number != null) { + return result.withPrefixMatcher(findPrefix(number, offset)); + } + return result; } - return result; - } - - /** - * finds prefix. For *'str'* returns just *'str*. - * - * @param element to find prefix of - * @return prefix - */ - private static String findPrefix(final PyElement element, final int offset) { - return TextRange.create(element.getTextRange().getStartOffset(), offset).substring(element.getContainingFile().getText()); - } - /** - * add keys to completion result from dict constructor - */ - private static void addDictConstructorKeys(final PyCallExpression dictConstructor, final CompletionResultSet result) { - final PyExpression callee = dictConstructor.getCallee(); - if (callee == null) { - return; + /** + * finds prefix. For *'str'* returns just *'str*. + * + * @param element to find prefix of + * @return prefix + */ + @RequiredReadAction + private static String findPrefix(PyElement element, int offset) { + return TextRange.create(element.getTextRange().getStartOffset(), offset).substring(element.getContainingFile().getText()); } - final String name = callee.getText(); - if ("dict".equals(name)) { - final TypeEvalContext context = TypeEvalContext.codeCompletion(callee.getProject(), callee.getContainingFile()); - final PyType type = context.getType(dictConstructor); - if (type != null && type.isBuiltin()) { - final PyArgumentList list = dictConstructor.getArgumentList(); - if (list == null) { - return; + + /** + * add keys to completion result from dict constructor + */ + @RequiredReadAction + private static void addDictConstructorKeys(PyCallExpression dictConstructor, CompletionResultSet result) { + PyExpression callee = dictConstructor.getCallee(); + if (callee == null) { + return; } - final PyExpression[] argumentList = list.getArguments(); - for (final PyExpression argument : argumentList) { - if (argument instanceof PyKeywordArgument) { - result.addElement(createElement("'" + ((PyKeywordArgument)argument).getKeyword() + "'")); - } + String name = callee.getText(); + if ("dict".equals(name)) { + TypeEvalContext context = TypeEvalContext.codeCompletion(callee.getProject(), callee.getContainingFile()); + PyType type = context.getType(dictConstructor); + if (type != null && type.isBuiltin()) { + PyArgumentList list = dictConstructor.getArgumentList(); + if (list == null) { + return; + } + PyExpression[] argumentList = list.getArguments(); + for (PyExpression argument : argumentList) { + if (argument instanceof PyKeywordArgument) { + result.addElement(createElement("'" + ((PyKeywordArgument) argument).getKeyword() + "'")); + } + } + } } - } } - } - /** - * add keys from assignment statements - * For instance, dictionary['b']=b - * - * @param file to get additional keys - * @param operand is operand of origin element - * @param result is completion result set - */ - private static void addAdditionalKeys(final PsiFile file, final PsiElement operand, final CompletionResultSet result) { - PySubscriptionExpression[] subscriptionExpressions = PyUtil.getAllChildrenOfType(file, PySubscriptionExpression.class); - for (PySubscriptionExpression expr : subscriptionExpressions) { - if (expr.getOperand().getText().equals(operand.getText())) { - final PsiElement parent = expr.getParent(); - if (parent instanceof PyAssignmentStatement) { - if (expr.equals(((PyAssignmentStatement)parent).getLeftHandSideExpression())) { - PyExpression key = expr.getIndexExpression(); - if (key != null) { - boolean addHandler = PsiTreeUtil.findElementOfClassAtRange(file, - key.getTextRange().getStartOffset(), - key.getTextRange().getEndOffset(), - PyStringLiteralExpression.class) - != null; - result.addElement(createElement(key.getText(), addHandler)); + /** + * add keys from assignment statements + * For instance, dictionary['b']=b + * + * @param file to get additional keys + * @param operand is operand of origin element + * @param result is completion result set + */ + @RequiredReadAction + private static void addAdditionalKeys(PsiFile file, PsiElement operand, CompletionResultSet result) { + PySubscriptionExpression[] subscriptionExpressions = PyUtil.getAllChildrenOfType(file, PySubscriptionExpression.class); + for (PySubscriptionExpression expr : subscriptionExpressions) { + if (expr.getOperand().getText().equals(operand.getText()) + && expr.getParent() instanceof PyAssignmentStatement assignment + && expr.equals(assignment.getLeftHandSideExpression())) { + PyExpression key = expr.getIndexExpression(); + if (key != null) { + boolean addHandler = PsiTreeUtil.findElementOfClassAtRange( + file, + key.getTextRange().getStartOffset(), + key.getTextRange().getEndOffset(), + PyStringLiteralExpression.class + ) != null; + result.addElement(createElement(key.getText(), addHandler)); + } } - } } - } } - } - /** - * add keys from dict literal expression - */ - public static void addDictLiteralKeys(final PyDictLiteralExpression dict, final CompletionResultSet result) { - PyKeyValueExpression[] keyValues = dict.getElements(); - for (PyKeyValueExpression expression : keyValues) { - boolean addHandler = PsiTreeUtil.findElementOfClassAtRange(dict.getContainingFile(), - expression.getTextRange().getStartOffset(), - expression.getTextRange().getEndOffset(), - PyStringLiteralExpression.class) != null; - result.addElement(createElement(expression.getKey().getText(), addHandler)); + /** + * add keys from dict literal expression + */ + @RequiredReadAction + public static void addDictLiteralKeys(PyDictLiteralExpression dict, CompletionResultSet result) { + PyKeyValueExpression[] keyValues = dict.getElements(); + for (PyKeyValueExpression expression : keyValues) { + boolean addHandler = PsiTreeUtil.findElementOfClassAtRange( + dict.getContainingFile(), + expression.getTextRange().getStartOffset(), + expression.getTextRange().getEndOffset(), + PyStringLiteralExpression.class + ) != null; + result.addElement(createElement(expression.getKey().getText(), addHandler)); + } } - } - private static LookupElementBuilder createElement(final String key) { - return createElement(key, true); - } + private static LookupElementBuilder createElement(String key) { + return createElement(key, true); + } - private static LookupElementBuilder createElement(final String key, final boolean addHandler) { - LookupElementBuilder item; - item = LookupElementBuilder.create(key).withTypeText("dict key").withIcon(AllIcons.Nodes.Parameter); + private static LookupElementBuilder createElement(String key, boolean addHandler) { + LookupElementBuilder item; + item = LookupElementBuilder.create(key).withTypeText("dict key").withIcon(PlatformIconGroup.nodesParameter()); - if (addHandler) { - item = item.withInsertHandler(new InsertHandler() { - @Override - public void handleInsert(final InsertionContext context, final LookupElement item) { - final PyStringLiteralExpression str = - PsiTreeUtil.findElementOfClassAtOffset(context.getFile(), context.getStartOffset(), PyStringLiteralExpression.class, false); - if (str != null) { - final boolean isDictKeys = PsiTreeUtil.getParentOfType(str, PySubscriptionExpression.class) != null; - if (isDictKeys) { - final int off = context.getStartOffset() + str.getTextLength(); - final PsiElement element = context.getFile().findElementAt(off); - final boolean atRBrace = element == null || element.getNode().getElementType() == PyTokenTypes.RBRACKET; - final boolean badQuoting = (!StringUtil.startsWithChar(str.getText(), '\'') || !StringUtil.endsWithChar(str.getText(), - '\'')) && (!StringUtil - .startsWithChar(str.getText() - , '"') || !StringUtil.endsWithChar(str.getText(), '"')); - if (badQuoting || !atRBrace) { - final Document document = context.getEditor().getDocument(); - final int offset = context.getTailOffset(); - document.deleteString(offset - 1, offset); - } - } - } + if (addHandler) { + item = item.withInsertHandler((context, item1) -> { + PyStringLiteralExpression str = PsiTreeUtil.findElementOfClassAtOffset( + context.getFile(), + context.getStartOffset(), + PyStringLiteralExpression.class, + false + ); + if (str != null) { + boolean isDictKeys = PsiTreeUtil.getParentOfType(str, PySubscriptionExpression.class) != null; + if (isDictKeys) { + int off = context.getStartOffset() + str.getTextLength(); + PsiElement element = context.getFile().findElementAt(off); + boolean atRBrace = element == null || element.getNode().getElementType() == PyTokenTypes.RBRACKET; + String text = str.getText(); + boolean badQuoting = !(StringUtil.startsWithChar(text, '\'') && StringUtil.endsWithChar(text, '\'')) + && !(StringUtil.startsWithChar(text, '"') && StringUtil.endsWithChar(text, '"')); + if (badQuoting || !atRBrace) { + Document document = context.getEditor().getDocument(); + int offset = context.getTailOffset(); + document.deleteString(offset - 1, offset); + } + } + } + }); } - }); + return item; } - return item; - } - @Nonnull - @Override - public Language getLanguage() { - return PythonLanguage.INSTANCE; - } + @Nonnull + @Override + public Language getLanguage() { + return PythonLanguage.INSTANCE; + } } diff --git a/python-impl/src/main/java/com/jetbrains/python/impl/codeInsight/completion/PyParameterCompletionContributor.java b/python-impl/src/main/java/com/jetbrains/python/impl/codeInsight/completion/PyParameterCompletionContributor.java index 915d9be1..94b2fb2f 100644 --- a/python-impl/src/main/java/com/jetbrains/python/impl/codeInsight/completion/PyParameterCompletionContributor.java +++ b/python-impl/src/main/java/com/jetbrains/python/impl/codeInsight/completion/PyParameterCompletionContributor.java @@ -13,19 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.jetbrains.python.impl.codeInsight.completion; import com.jetbrains.python.PythonLanguage; import com.jetbrains.python.psi.PyParameterList; +import consulo.annotation.access.RequiredReadAction; import consulo.annotation.component.ExtensionImpl; -import consulo.application.AllIcons; import consulo.language.Language; import consulo.language.editor.completion.*; import consulo.language.editor.completion.lookup.LookupElementBuilder; import consulo.language.util.ProcessingContext; -import consulo.ui.image.Image; - +import consulo.platform.base.icon.PlatformIconGroup; import jakarta.annotation.Nonnull; import static consulo.language.pattern.PlatformPatterns.psiElement; @@ -35,33 +33,40 @@ */ @ExtensionImpl public class PyParameterCompletionContributor extends CompletionContributor { - public PyParameterCompletionContributor() { - extend(CompletionType.BASIC, - psiElement().inside(PyParameterList.class).afterLeaf("*"), - new ParameterCompletionProvider("args")); - extend(CompletionType.BASIC, - psiElement().inside(PyParameterList.class).afterLeaf("**"), - new ParameterCompletionProvider("kwargs")); - } + public PyParameterCompletionContributor() { + extend( + CompletionType.BASIC, + psiElement().inside(PyParameterList.class).afterLeaf("*"), + new ParameterCompletionProvider("args") + ); + extend( + CompletionType.BASIC, + psiElement().inside(PyParameterList.class).afterLeaf("**"), + new ParameterCompletionProvider("kwargs") + ); + } - @Nonnull - @Override - public Language getLanguage() { - return PythonLanguage.INSTANCE; - } + @Nonnull + @Override + public Language getLanguage() { + return PythonLanguage.INSTANCE; + } - private static class ParameterCompletionProvider implements CompletionProvider { - private String myName; + private static class ParameterCompletionProvider implements CompletionProvider { + private String myName; - private ParameterCompletionProvider(String name) { - myName = name; - } + private ParameterCompletionProvider(String name) { + myName = name; + } - @Override - public void addCompletions(@Nonnull CompletionParameters parameters, - ProcessingContext context, - @Nonnull CompletionResultSet result) { - result.addElement(LookupElementBuilder.create(myName).withIcon((Image)AllIcons.Nodes.Parameter)); + @Override + @RequiredReadAction + public void addCompletions( + @Nonnull CompletionParameters parameters, + ProcessingContext context, + @Nonnull CompletionResultSet result + ) { + result.addElement(LookupElementBuilder.create(myName).withIcon(PlatformIconGroup.nodesParameter())); + } } - } } diff --git a/python-impl/src/main/java/com/jetbrains/python/impl/codeInsight/userSkeletons/PyUserSkeletonsLineMarkerProvider.java b/python-impl/src/main/java/com/jetbrains/python/impl/codeInsight/userSkeletons/PyUserSkeletonsLineMarkerProvider.java index 80dcb438..a491ca24 100644 --- a/python-impl/src/main/java/com/jetbrains/python/impl/codeInsight/userSkeletons/PyUserSkeletonsLineMarkerProvider.java +++ b/python-impl/src/main/java/com/jetbrains/python/impl/codeInsight/userSkeletons/PyUserSkeletonsLineMarkerProvider.java @@ -19,20 +19,19 @@ import com.jetbrains.python.psi.PyElement; import com.jetbrains.python.psi.PyFunction; import com.jetbrains.python.psi.PyTargetExpression; +import consulo.annotation.access.RequiredReadAction; import consulo.annotation.component.ExtensionImpl; -import consulo.application.AllIcons; import consulo.codeEditor.markup.GutterIconRenderer; import consulo.language.Language; import consulo.language.editor.Pass; -import consulo.language.editor.gutter.GutterIconNavigationHandler; import consulo.language.editor.gutter.LineMarkerInfo; import consulo.language.editor.gutter.LineMarkerProvider; import consulo.language.psi.PsiElement; import consulo.language.psi.util.PsiNavigateUtil; - +import consulo.platform.base.icon.PlatformIconGroup; import jakarta.annotation.Nonnull; import jakarta.annotation.Nullable; -import java.awt.event.MouseEvent; + import java.util.Collection; import java.util.List; @@ -41,47 +40,48 @@ */ @ExtensionImpl public class PyUserSkeletonsLineMarkerProvider implements LineMarkerProvider { - @Nullable - @Override - public LineMarkerInfo getLineMarkerInfo(@Nonnull PsiElement element) { - return null; - } + @Nullable + @Override + @RequiredReadAction + public LineMarkerInfo getLineMarkerInfo(@Nonnull PsiElement element) { + return null; + } - @Override - public void collectSlowLineMarkers(@Nonnull List elements, @Nonnull Collection result) { - for (PsiElement element : elements) { - final PyElement skeleton = getUserSkeleton(element); - if (skeleton != null) { - result.add(new LineMarkerInfo(element, - element.getTextRange(), - AllIcons.Gutter.Unique, - Pass.LINE_MARKERS, - e -> "Has user skeleton", - new GutterIconNavigationHandler() { - @Override - public void navigate(MouseEvent e, PsiElement elt) { - final PyElement s = getUserSkeleton(elt); - if (s != null) { - PsiNavigateUtil.navigate(s); - } - } - }, - GutterIconRenderer.Alignment.RIGHT)); - } + @Override + @RequiredReadAction + public void collectSlowLineMarkers(@Nonnull List elements, @Nonnull Collection result) { + for (PsiElement element : elements) { + PyElement skeleton = getUserSkeleton(element); + if (skeleton != null) { + result.add(new LineMarkerInfo<>( + element, + element.getTextRange(), + PlatformIconGroup.gutterUnique(), + Pass.LINE_MARKERS, + e -> "Has user skeleton", + (e, elt) -> { + PyElement s = getUserSkeleton(elt); + if (s != null) { + PsiNavigateUtil.navigate(s); + } + }, + GutterIconRenderer.Alignment.RIGHT + )); + } + } } - } - @Nullable - private static PyElement getUserSkeleton(@Nonnull PsiElement element) { - if (element instanceof PyFunction || element instanceof PyTargetExpression) { - return PyUserSkeletonsUtil.getUserSkeleton((PyElement)element); + @Nullable + private static PyElement getUserSkeleton(@Nonnull PsiElement element) { + if (element instanceof PyFunction || element instanceof PyTargetExpression) { + return PyUserSkeletonsUtil.getUserSkeleton((PyElement) element); + } + return null; } - return null; - } - @Nonnull - @Override - public Language getLanguage() { - return PythonLanguage.INSTANCE; - } + @Nonnull + @Override + public Language getLanguage() { + return PythonLanguage.INSTANCE; + } } diff --git a/python-impl/src/main/java/com/jetbrains/python/impl/documentation/doctest/PyDocstringLanguageInjector.java b/python-impl/src/main/java/com/jetbrains/python/impl/documentation/doctest/PyDocstringLanguageInjector.java index 8826e922..82bb14d0 100644 --- a/python-impl/src/main/java/com/jetbrains/python/impl/documentation/doctest/PyDocstringLanguageInjector.java +++ b/python-impl/src/main/java/com/jetbrains/python/impl/documentation/doctest/PyDocstringLanguageInjector.java @@ -24,102 +24,106 @@ import consulo.language.inject.InjectedLanguagePlaces; import consulo.language.inject.LanguageInjector; import consulo.language.psi.PsiLanguageInjectionHost; -import consulo.language.util.ModuleUtilCore; import consulo.module.Module; import consulo.util.lang.Pair; import consulo.util.lang.StringUtil; import jakarta.annotation.Nonnull; + import java.util.List; /** - * User: ktisha + * @author ktisha */ @ExtensionImpl public class PyDocstringLanguageInjector implements LanguageInjector { - @Override - public void injectLanguages(@Nonnull final PsiLanguageInjectionHost host, - @Nonnull final InjectedLanguagePlaces injectionPlacesRegistrar) { - if (!(host instanceof PyStringLiteralExpression)) { - return; - } - final Module module = ModuleUtilCore.findModuleForPsiElement(host); - if (module == null || !PyDocumentationSettings.getInstance(module).isAnalyzeDoctest()) { - return; - } + @Override + public void injectLanguages( + @Nonnull PsiLanguageInjectionHost host, + @Nonnull InjectedLanguagePlaces injectionPlacesRegistrar + ) { + if (!(host instanceof PyStringLiteralExpression)) { + return; + } + final Module module = host.getModule(); + if (module == null || !PyDocumentationSettings.getInstance(module).isAnalyzeDoctest()) { + return; + } - if (DocStringUtil.getParentDefinitionDocString(host) == host) { - int start = 0; - int end = host.getTextLength() - 1; - final String text = host.getText(); + if (DocStringUtil.getParentDefinitionDocString(host) == host) { + int start = 0; + int end = host.getTextLength() - 1; + final String text = host.getText(); - final Pair quotes = PyStringLiteralUtil.getQuotes(text); - final List strings = StringUtil.split(text, "\n", false); + final Pair quotes = PyStringLiteralUtil.getQuotes(text); + final List strings = StringUtil.split(text, "\n", false); - boolean gotExample = false; + boolean gotExample = false; - int currentPosition = 0; - int maxPosition = text.length(); - boolean endsWithSlash = false; - for (String string : strings) { - final String trimmedString = string.trim(); - if (!trimmedString.startsWith(">>>") && !trimmedString.startsWith("...") && gotExample && start < end) { - gotExample = false; - if (!endsWithSlash) { - injectionPlacesRegistrar.addPlace(PyDocstringLanguageDialect.getInstance(), TextRange.create(start, end), null, null); - } - } - final String closingQuote = quotes == null ? text.substring(0, 1) : quotes.second; + int currentPosition = 0; + int maxPosition = text.length(); + boolean endsWithSlash = false; + for (String string : strings) { + final String trimmedString = string.trim(); + if (!trimmedString.startsWith(">>>") && !trimmedString.startsWith("...") && gotExample && start < end) { + gotExample = false; + if (!endsWithSlash) { + injectionPlacesRegistrar.addPlace(PyDocstringLanguageDialect.INSTANCE, TextRange.create(start, end), null, null); + } + } + final String closingQuote = quotes == null ? text.substring(0, 1) : quotes.second; - if (endsWithSlash && !trimmedString.endsWith("\\")) { - endsWithSlash = false; - injectionPlacesRegistrar.addPlace(PyDocstringLanguageDialect.getInstance(), - TextRange.create(start, getEndOffset(currentPosition, string, maxPosition, closingQuote)), - null, - null); - } + if (endsWithSlash && !trimmedString.endsWith("\\")) { + endsWithSlash = false; + injectionPlacesRegistrar.addPlace( + PyDocstringLanguageDialect.INSTANCE, + TextRange.create(start, getEndOffset(currentPosition, string, maxPosition, closingQuote)), + null, + null + ); + } - if (trimmedString.startsWith(">>>")) { - if (trimmedString.endsWith("\\")) { - endsWithSlash = true; - } + if (trimmedString.startsWith(">>>")) { + if (trimmedString.endsWith("\\")) { + endsWithSlash = true; + } - if (!gotExample) { - start = currentPosition; - } + if (!gotExample) { + start = currentPosition; + } - gotExample = true; - end = getEndOffset(currentPosition, string, maxPosition, closingQuote); - } - else if (trimmedString.startsWith("...") && gotExample) { - if (trimmedString.endsWith("\\")) { - endsWithSlash = true; - } + gotExample = true; + end = getEndOffset(currentPosition, string, maxPosition, closingQuote); + } + else if (trimmedString.startsWith("...") && gotExample) { + if (trimmedString.endsWith("\\")) { + endsWithSlash = true; + } - end = getEndOffset(currentPosition, string, maxPosition, closingQuote); + end = getEndOffset(currentPosition, string, maxPosition, closingQuote); + } + currentPosition += string.length(); + } + if (gotExample && start < end) { + injectionPlacesRegistrar.addPlace(PyDocstringLanguageDialect.INSTANCE, TextRange.create(start, end), null, null); + } } - currentPosition += string.length(); - } - if (gotExample && start < end) { - injectionPlacesRegistrar.addPlace(PyDocstringLanguageDialect.getInstance(), TextRange.create(start, end), null, null); - } } - } - private static int getEndOffset(int start, String s, int maxPosition, String closingQuote) { - int end; - int length = s.length(); - if (s.trim().endsWith(closingQuote)) { - length -= 3; - } - else if (start + length == maxPosition && (s.trim().endsWith("\"") || s.trim().endsWith("'"))) { - length -= 1; - } + private static int getEndOffset(int start, String s, int maxPosition, String closingQuote) { + int end; + int length = s.length(); + if (s.trim().endsWith(closingQuote)) { + length -= 3; + } + else if (start + length == maxPosition && (s.trim().endsWith("\"") || s.trim().endsWith("'"))) { + length -= 1; + } - end = start + length; - if (s.endsWith("\n")) { - end -= 1; + end = start + length; + if (s.endsWith("\n")) { + end -= 1; + } + return end; } - return end; - } } diff --git a/python-impl/src/main/java/com/jetbrains/python/impl/packaging/PyPackageManagerUI.java b/python-impl/src/main/java/com/jetbrains/python/impl/packaging/PyPackageManagerUI.java index 4802cf98..16d65516 100644 --- a/python-impl/src/main/java/com/jetbrains/python/impl/packaging/PyPackageManagerUI.java +++ b/python-impl/src/main/java/com/jetbrains/python/impl/packaging/PyPackageManagerUI.java @@ -20,15 +20,14 @@ import com.jetbrains.python.packaging.PyPackageManager; import com.jetbrains.python.packaging.PyPackageManagers; import com.jetbrains.python.packaging.PyRequirement; -import consulo.application.AllIcons; import consulo.application.Application; -import consulo.application.ApplicationManager; import consulo.application.progress.ProgressIndicator; import consulo.application.progress.ProgressManager; import consulo.application.progress.Task; import consulo.content.bundle.Sdk; import consulo.execution.RunCanceledByUserException; import consulo.logging.Logger; +import consulo.platform.base.icon.PlatformIconGroup; import consulo.process.ExecutionException; import consulo.project.Project; import consulo.project.ui.notification.Notification; @@ -37,348 +36,360 @@ import consulo.project.ui.notification.NotificationsManager; import consulo.project.ui.notification.event.NotificationListener; import consulo.repository.ui.PackageManagementService; +import consulo.ui.annotation.RequiredUIAccess; import consulo.ui.ex.awt.Messages; import consulo.util.lang.StringUtil; -import consulo.util.lang.ref.Ref; - +import consulo.util.lang.ref.SimpleReference; import jakarta.annotation.Nonnull; import jakarta.annotation.Nullable; -import javax.swing.event.HyperlinkEvent; + import java.util.*; /** * @author vlan */ public class PyPackageManagerUI { - @Nonnull - private static final Logger LOG = Logger.getInstance(PyPackageManagerUI.class); - - @Nullable - private Listener myListener; - @Nonnull - private Project myProject; - @Nonnull - private Sdk mySdk; - - public interface Listener { - void started(); - - void finished(List exceptions); - } - - public PyPackageManagerUI(@Nonnull Project project, @Nonnull Sdk sdk, @Nullable Listener listener) { - myProject = project; - mySdk = sdk; - myListener = listener; - } - - public void installManagement() { - ProgressManager.getInstance().run(new InstallManagementTask(myProject, mySdk, myListener)); - } - - public void install(@Nonnull final List requirements, @Nonnull final List extraArgs) { - ProgressManager.getInstance().run(new InstallTask(myProject, mySdk, requirements, extraArgs, myListener)); - } - - public void uninstall(@Nonnull final List packages) { - if (checkDependents(packages)) { - return; - } - ProgressManager.getInstance().run(new UninstallTask(myProject, mySdk, myListener, packages)); - } - - private boolean checkDependents(@Nonnull final List packages) { - try { - final Map> dependentPackages = collectDependents(packages, mySdk); - final int[] warning = {0}; - if (!dependentPackages.isEmpty()) { - Application application = ApplicationManager.getApplication(); - application.invokeAndWait(() -> { - if (dependentPackages.size() == 1) { - String message = "You are attempting to uninstall "; - List dep = new ArrayList<>(); - int size = 1; - for (Map.Entry> entry : dependentPackages.entrySet()) { - final Set value = entry.getValue(); - size = value.size(); - dep.add(entry.getKey() + " package which is required for " + StringUtil.join(value, ", ")); - } - message += StringUtil.join(dep, "\n"); - message += size == 1 ? " package" : " packages"; - message += "\n\nDo you want to proceed?"; - warning[0] = Messages.showYesNoDialog(message, "Warning", AllIcons.General.BalloonWarning); - } - else { - String message = "You are attempting to uninstall packages which are required for another packages.\n\n"; - List dep = new ArrayList<>(); - for (Map.Entry> entry : dependentPackages.entrySet()) { - dep.add(entry.getKey() + " -> " + StringUtil.join(entry.getValue(), ", ")); - } - message += StringUtil.join(dep, "\n"); - message += "\n\nDo you want to proceed?"; - warning[0] = Messages.showYesNoDialog(message, "Warning", AllIcons.General.BalloonWarning); - } - }, application.getCurrentModalityState()); - } - if (warning[0] != Messages.YES) { - return true; - } - } - catch (ExecutionException e) { - LOG.info("Error loading packages dependents: " + e.getMessage(), e); - } - return false; - } - - private static Map> collectDependents(@Nonnull final List packages, Sdk sdk) throws ExecutionException { - Map> dependentPackages = new HashMap<>(); - for (PyPackage pkg : packages) { - final Set dependents = PyPackageManager.getInstance(sdk).getDependents(pkg); - if (dependents != null && !dependents.isEmpty()) { - for (PyPackage dependent : dependents) { - if (!packages.contains(dependent)) { - dependentPackages.put(pkg.getName(), dependents); - } - } - } - } - return dependentPackages; - } - - private abstract static class PackagingTask extends Task.Backgroundable { - private static final NotificationGroup PACKAGING_GROUP_ID = NotificationGroup.balloonGroup("Python Packaging"); - @Nonnull - protected final Sdk mySdk; + private static final Logger LOG = Logger.getInstance(PyPackageManagerUI.class); + @Nullable - protected final Listener myListener; + private Listener myListener; + @Nonnull + private Project myProject; + @Nonnull + private Sdk mySdk; - public PackagingTask(@Nullable Project project, @Nonnull Sdk sdk, @Nonnull String title, @Nullable Listener listener) { - super(project, title); - mySdk = sdk; - myListener = listener; - } + public interface Listener { + void started(); - @Override - public void run(@Nonnull ProgressIndicator indicator) { - taskStarted(indicator); - taskFinished(runTask(indicator)); + void finished(List exceptions); } - @Nonnull - protected abstract List runTask(@Nonnull ProgressIndicator indicator); + public PyPackageManagerUI(@Nonnull Project project, @Nonnull Sdk sdk, @Nullable Listener listener) { + myProject = project; + mySdk = sdk; + myListener = listener; + } - @Nonnull - protected abstract String getSuccessTitle(); + public void installManagement() { + ProgressManager.getInstance().run(new InstallManagementTask(myProject, mySdk, myListener)); + } - @Nonnull - protected abstract String getSuccessDescription(); + public void install(@Nonnull List requirements, @Nonnull List extraArgs) { + ProgressManager.getInstance().run(new InstallTask(myProject, mySdk, requirements, extraArgs, myListener)); + } - @Nonnull - protected abstract String getFailureTitle(); - - protected void taskStarted(@Nonnull ProgressIndicator indicator) { - final PackagingNotification[] notifications = - NotificationsManager.getNotificationsManager().getNotificationsOfType(PackagingNotification.class, (Project)getProject()); - for (PackagingNotification notification : notifications) { - notification.expire(); - } - indicator.setText(getTitle() + "..."); - if (myListener != null) { - ApplicationManager.getApplication().invokeLater(() -> myListener.started()); - } + @RequiredUIAccess + public void uninstall(@Nonnull List packages) { + if (checkDependents(packages)) { + return; + } + ProgressManager.getInstance().run(new UninstallTask(myProject, mySdk, myListener, packages)); } - protected void taskFinished(@Nonnull final List exceptions) { - final Ref notificationRef = new Ref<>(null); - if (exceptions.isEmpty()) { - notificationRef.set(new PackagingNotification(PACKAGING_GROUP_ID, - getSuccessTitle(), - getSuccessDescription(), - NotificationType.INFORMATION, - null)); - } - else { - final PackageManagementService.ErrorDescription description = PyPackageManagementService.toErrorDescription(exceptions, mySdk); - if (description != null) { - final String firstLine = getTitle() + ": error occurred."; - final NotificationListener listener = new NotificationListener() { - @Override - public void hyperlinkUpdate(@Nonnull Notification notification, @Nonnull HyperlinkEvent event) { - assert myProject != null; - final String title = StringUtil.capitalizeWords(getFailureTitle(), true); - consulo.ide.impl.idea.webcore.packaging.PackagesNotificationPanel.showError(title, description); + @RequiredUIAccess + private boolean checkDependents(@Nonnull List packages) { + try { + Map> dependentPackages = collectDependents(packages, mySdk); + int[] warning = {0}; + if (!dependentPackages.isEmpty()) { + Application application = myProject.getApplication(); + application.invokeAndWait( + () -> { + if (dependentPackages.size() == 1) { + String message = "You are attempting to uninstall "; + List dep = new ArrayList<>(); + int size = 1; + for (Map.Entry> entry : dependentPackages.entrySet()) { + Set value = entry.getValue(); + size = value.size(); + dep.add(entry.getKey() + " package which is required for " + StringUtil.join(value, ", ")); + } + message += StringUtil.join(dep, "\n"); + message += size == 1 ? " package" : " packages"; + message += "\n\nDo you want to proceed?"; + warning[0] = Messages.showYesNoDialog(message, "Warning", PlatformIconGroup.generalBalloonwarning()); + } + else { + String message = "You are attempting to uninstall packages which are required for another packages.\n\n"; + List dep = new ArrayList<>(); + for (Map.Entry> entry : dependentPackages.entrySet()) { + dep.add(entry.getKey() + " -> " + StringUtil.join(entry.getValue(), ", ")); + } + message += StringUtil.join(dep, "\n"); + message += "\n\nDo you want to proceed?"; + warning[0] = Messages.showYesNoDialog(message, "Warning", PlatformIconGroup.generalBalloonwarning()); + } + }, + application.getCurrentModalityState() + ); + } + if (warning[0] != Messages.YES) { + return true; } - }; - notificationRef.set(new PackagingNotification(PACKAGING_GROUP_ID, - getFailureTitle(), - firstLine + " Details...", - consulo.project.ui.notification.NotificationType.ERROR, - listener)); - } - } - ApplicationManager.getApplication().invokeLater(() -> { - if (myListener != null) { - myListener.finished(exceptions); } - final Notification notification = notificationRef.get(); - if (notification != null) { - notification.notify((Project)myProject); + catch (ExecutionException e) { + LOG.info("Error loading packages dependents: " + e.getMessage(), e); } - }); + return false; } - private static class PackagingNotification extends Notification { - - public PackagingNotification(@Nonnull NotificationGroup groupDisplayId, - @Nonnull String title, - @Nonnull String content, - @Nonnull consulo.project.ui.notification.NotificationType type, - @Nullable NotificationListener listener) { - super(groupDisplayId, title, content, type, listener); - } + private static Map> collectDependents(@Nonnull List packages, Sdk sdk) throws ExecutionException { + Map> dependentPackages = new HashMap<>(); + for (PyPackage pkg : packages) { + Set dependents = PyPackageManager.getInstance(sdk).getDependents(pkg); + if (!dependents.isEmpty()) { + for (PyPackage dependent : dependents) { + if (!packages.contains(dependent)) { + dependentPackages.put(pkg.getName(), dependents); + } + } + } + } + return dependentPackages; } - } - private static class InstallTask extends PackagingTask { - @Nonnull - private final List myRequirements; - @Nonnull - private final List myExtraArgs; - - public InstallTask(@Nullable Project project, - @Nonnull Sdk sdk, - @Nonnull List requirements, - @Nonnull List extraArgs, - @Nullable Listener listener) { - super(project, sdk, "Installing packages", listener); - myRequirements = requirements; - myExtraArgs = extraArgs; - } + private abstract static class PackagingTask extends Task.Backgroundable { + private static final NotificationGroup PACKAGING_GROUP_ID = NotificationGroup.balloonGroup("Python Packaging"); - @Nonnull - @Override - protected List runTask(@Nonnull ProgressIndicator indicator) { - final List exceptions = new ArrayList<>(); - final int size = myRequirements.size(); - final PyPackageManager manager = PyPackageManagers.getInstance().forSdk(mySdk); - for (int i = 0; i < size; i++) { - final PyRequirement requirement = myRequirements.get(i); - indicator.setText(String.format("Installing package '%s'...", requirement)); - if (i == 0) { - indicator.setIndeterminate(true); + @Nonnull + protected final Sdk mySdk; + @Nullable + protected final Listener myListener; + + public PackagingTask(@Nullable Project project, @Nonnull Sdk sdk, @Nonnull String title, @Nullable Listener listener) { + super(project, title); + mySdk = sdk; + myListener = listener; } - else { - indicator.setIndeterminate(false); - indicator.setFraction((double)i / size); + + @Override + public void run(@Nonnull ProgressIndicator indicator) { + taskStarted(indicator); + taskFinished(runTask(indicator)); } - try { - manager.install(Collections.singletonList(requirement), myExtraArgs); + + @Nonnull + protected abstract List runTask(@Nonnull ProgressIndicator indicator); + + @Nonnull + protected abstract String getSuccessTitle(); + + @Nonnull + protected abstract String getSuccessDescription(); + + @Nonnull + protected abstract String getFailureTitle(); + + protected void taskStarted(@Nonnull ProgressIndicator indicator) { + Project project = (Project) getProject(); + PackagingNotification[] notifications = + NotificationsManager.getNotificationsManager().getNotificationsOfType(PackagingNotification.class, project); + for (PackagingNotification notification : notifications) { + notification.expire(); + } + indicator.setText(getTitle() + "..."); + if (myListener != null) { + project.getApplication().invokeLater(myListener::started); + } } - catch (RunCanceledByUserException e) { - exceptions.add(e); - break; + + protected void taskFinished(@Nonnull List exceptions) { + SimpleReference notificationRef = new SimpleReference<>(null); + if (exceptions.isEmpty()) { + notificationRef.set(new PackagingNotification( + PACKAGING_GROUP_ID, + getSuccessTitle(), + getSuccessDescription(), + NotificationType.INFORMATION, + null + )); + } + else { + PackageManagementService.ErrorDescription description = PyPackageManagementService.toErrorDescription(exceptions, mySdk); + if (description != null) { + String firstLine = getTitle() + ": error occurred."; + NotificationListener listener = (notification, event) -> { + assert myProject != null; + String title = StringUtil.capitalizeWords(getFailureTitle(), true); + consulo.ide.impl.idea.webcore.packaging.PackagesNotificationPanel.showError(title, description); + }; + notificationRef.set(new PackagingNotification( + PACKAGING_GROUP_ID, + getFailureTitle(), + firstLine + " Details...", + consulo.project.ui.notification.NotificationType.ERROR, + listener + )); + } + } + Project project = (Project) myProject; + project.getApplication().invokeLater(() -> { + if (myListener != null) { + myListener.finished(exceptions); + } + Notification notification = notificationRef.get(); + if (notification != null) { + notification.notify(project); + } + }); } - catch (ExecutionException e) { - exceptions.add(e); + + private static class PackagingNotification extends Notification { + + public PackagingNotification( + @Nonnull NotificationGroup groupDisplayId, + @Nonnull String title, + @Nonnull String content, + @Nonnull consulo.project.ui.notification.NotificationType type, + @Nullable NotificationListener listener + ) { + super(groupDisplayId, title, content, type, listener); + } } - } - manager.refresh(); - return exceptions; } - @Nonnull - @Override - protected String getSuccessTitle() { - return "Packages installed successfully"; - } + private static class InstallTask extends PackagingTask { + @Nonnull + private final List myRequirements; + @Nonnull + private final List myExtraArgs; + + public InstallTask( + @Nullable Project project, + @Nonnull Sdk sdk, + @Nonnull List requirements, + @Nonnull List extraArgs, + @Nullable Listener listener + ) { + super(project, sdk, "Installing packages", listener); + myRequirements = requirements; + myExtraArgs = extraArgs; + } - @Nonnull - @Override - protected String getSuccessDescription() { - return "Installed packages: " + PyPackageUtil.requirementsToString(myRequirements); - } + @Nonnull + @Override + protected List runTask(@Nonnull ProgressIndicator indicator) { + List exceptions = new ArrayList<>(); + int size = myRequirements.size(); + PyPackageManager manager = PyPackageManagers.getInstance().forSdk(mySdk); + for (int i = 0; i < size; i++) { + PyRequirement requirement = myRequirements.get(i); + indicator.setText(String.format("Installing package '%s'...", requirement)); + if (i == 0) { + indicator.setIndeterminate(true); + } + else { + indicator.setIndeterminate(false); + indicator.setFraction((double) i / size); + } + try { + manager.install(Collections.singletonList(requirement), myExtraArgs); + } + catch (RunCanceledByUserException e) { + exceptions.add(e); + break; + } + catch (ExecutionException e) { + exceptions.add(e); + } + } + manager.refresh(); + return exceptions; + } - @Nonnull - @Override - protected String getFailureTitle() { - return "Install packages failed"; - } - } + @Nonnull + @Override + protected String getSuccessTitle() { + return "Packages installed successfully"; + } - private static class InstallManagementTask extends InstallTask { + @Nonnull + @Override + protected String getSuccessDescription() { + return "Installed packages: " + PyPackageUtil.requirementsToString(myRequirements); + } - public InstallManagementTask(@Nullable Project project, @Nonnull Sdk sdk, @Nullable Listener listener) { - super(project, sdk, Collections.emptyList(), Collections.emptyList(), listener); + @Nonnull + @Override + protected String getFailureTitle() { + return "Install packages failed"; + } } - @Nonnull - @Override - protected List runTask(@Nonnull ProgressIndicator indicator) { - final List exceptions = new ArrayList<>(); - final PyPackageManager manager = PyPackageManagers.getInstance().forSdk(mySdk); - indicator.setText("Installing packaging tools..."); - indicator.setIndeterminate(true); - try { - manager.installManagement(); - } - catch (ExecutionException e) { - exceptions.add(e); - } - manager.refresh(); - return exceptions; - } + private static class InstallManagementTask extends InstallTask { - @Nonnull - @Override - protected String getSuccessDescription() { - return "Installed Python packaging tools"; - } - } + public InstallManagementTask(@Nullable Project project, @Nonnull Sdk sdk, @Nullable Listener listener) { + super(project, sdk, Collections.emptyList(), Collections.emptyList(), listener); + } - private static class UninstallTask extends PackagingTask { - @Nonnull - private final List myPackages; + @Nonnull + @Override + protected List runTask(@Nonnull ProgressIndicator indicator) { + List exceptions = new ArrayList<>(); + PyPackageManager manager = PyPackageManagers.getInstance().forSdk(mySdk); + indicator.setText("Installing packaging tools..."); + indicator.setIndeterminate(true); + try { + manager.installManagement(); + } + catch (ExecutionException e) { + exceptions.add(e); + } + manager.refresh(); + return exceptions; + } - public UninstallTask(@Nullable Project project, @Nonnull Sdk sdk, @Nullable Listener listener, @Nonnull List packages) { - super(project, sdk, "Uninstalling packages", listener); - myPackages = packages; + @Nonnull + @Override + protected String getSuccessDescription() { + return "Installed Python packaging tools"; + } } - @Nonnull - @Override - protected List runTask(@Nonnull ProgressIndicator indicator) { - final PyPackageManager manager = PyPackageManagers.getInstance().forSdk(mySdk); - indicator.setIndeterminate(true); - try { - manager.uninstall(myPackages); - return Collections.emptyList(); - } - catch (ExecutionException e) { - return Collections.singletonList(e); - } - finally { - manager.refresh(); - } - } + private static class UninstallTask extends PackagingTask { + @Nonnull + private final List myPackages; - @Nonnull - @Override - protected String getSuccessTitle() { - return "Packages uninstalled successfully"; - } + public UninstallTask(@Nullable Project project, @Nonnull Sdk sdk, @Nullable Listener listener, @Nonnull List packages) { + super(project, sdk, "Uninstalling packages", listener); + myPackages = packages; + } - @Nonnull - @Override - protected String getSuccessDescription() { - final String packagesString = StringUtil.join(myPackages, pkg -> "'" + pkg.getName() + "'", ", "); - return "Uninstalled packages: " + packagesString; - } + @Nonnull + @Override + protected List runTask(@Nonnull ProgressIndicator indicator) { + PyPackageManager manager = PyPackageManagers.getInstance().forSdk(mySdk); + indicator.setIndeterminate(true); + try { + manager.uninstall(myPackages); + return Collections.emptyList(); + } + catch (ExecutionException e) { + return Collections.singletonList(e); + } + finally { + manager.refresh(); + } + } - @Nonnull - @Override - protected String getFailureTitle() { - return "Uninstall packages failed"; + @Nonnull + @Override + protected String getSuccessTitle() { + return "Packages uninstalled successfully"; + } + + @Nonnull + @Override + protected String getSuccessDescription() { + String packagesString = StringUtil.join(myPackages, pkg -> "'" + pkg.getName() + "'", ", "); + return "Uninstalled packages: " + packagesString; + } + + @Nonnull + @Override + protected String getFailureTitle() { + return "Uninstall packages failed"; + } } - } } diff --git a/python-impl/src/main/java/com/jetbrains/python/impl/psi/PyUtil.java b/python-impl/src/main/java/com/jetbrains/python/impl/psi/PyUtil.java index 05e6e513..eb650334 100644 --- a/python-impl/src/main/java/com/jetbrains/python/impl/psi/PyUtil.java +++ b/python-impl/src/main/java/com/jetbrains/python/impl/psi/PyUtil.java @@ -21,7 +21,6 @@ import com.jetbrains.python.PyTokenTypes; import com.jetbrains.python.codeInsight.controlflow.ScopeOwner; import com.jetbrains.python.impl.NotNullPredicate; -import com.jetbrains.python.impl.PyBundle; import com.jetbrains.python.impl.codeInsight.completion.OverwriteEqualsInsertHandler; import com.jetbrains.python.impl.codeInsight.dataflow.scope.ScopeUtil; import com.jetbrains.python.impl.codeInsight.stdlib.PyNamedTupleType; @@ -45,9 +44,9 @@ import com.jetbrains.python.psi.resolve.PyResolveContext; import com.jetbrains.python.psi.resolve.RatedResolveResult; import com.jetbrains.python.psi.types.*; -import consulo.application.AllIcons; +import consulo.annotation.access.RequiredReadAction; +import consulo.annotation.access.RequiredWriteAction; import consulo.application.Application; -import consulo.application.ApplicationManager; import consulo.application.progress.ProgressIndicator; import consulo.application.progress.ProgressManager; import consulo.application.progress.Task; @@ -78,13 +77,15 @@ import consulo.language.psi.util.QualifiedName; import consulo.language.scratch.ScratchFileService; import consulo.language.util.IncorrectOperationException; -import consulo.language.util.ModuleUtilCore; import consulo.module.Module; import consulo.module.ModuleManager; import consulo.module.content.ModuleRootManager; +import consulo.platform.base.icon.PlatformIconGroup; import consulo.project.Project; import consulo.project.ui.wm.WindowManager; +import consulo.python.impl.localize.PyLocalize; import consulo.ui.NotificationType; +import consulo.ui.annotation.RequiredUIAccess; import consulo.ui.ex.RelativePoint; import consulo.ui.ex.popup.Balloon; import consulo.ui.ex.popup.JBPopupFactory; @@ -96,12 +97,11 @@ import consulo.util.lang.StringUtil; import consulo.virtualFileSystem.LocalFileSystem; import consulo.virtualFileSystem.VirtualFile; +import jakarta.annotation.Nonnull; +import jakarta.annotation.Nullable; import org.jetbrains.annotations.Contract; import org.jetbrains.annotations.Nls; -import org.jetbrains.annotations.NonNls; -import jakarta.annotation.Nonnull; -import jakarta.annotation.Nullable; import javax.swing.*; import java.awt.*; import java.io.File; @@ -115,1980 +115,2015 @@ import static com.jetbrains.python.psi.PyFunction.Modifier.STATICMETHOD; public class PyUtil { - private PyUtil() { - } - - @Nonnull - public static T[] getAllChildrenOfType(@Nonnull PsiElement element, @Nonnull Class aClass) { - List result = new SmartList<>(); - for (PsiElement child : element.getChildren()) { - if (instanceOf(child, aClass)) { - //noinspection unchecked - result.add((T)child); - } - else { - ContainerUtil.addAll(result, getAllChildrenOfType(child, aClass)); - } - } - return ArrayUtil.toObjectArray(result, aClass); - } - - /** - * @see PyUtil#flattenedParensAndTuples - */ - protected static List unfoldParentheses(PyExpression[] targets, - List receiver, - boolean unfoldListLiterals, - boolean unfoldStarExpressions) { - // NOTE: this proliferation of instanceofs is not very beautiful. Maybe rewrite using a visitor. - for (PyExpression exp : targets) { - if (exp instanceof PyParenthesizedExpression) { - final PyParenthesizedExpression parenExpr = (PyParenthesizedExpression)exp; - unfoldParentheses(new PyExpression[]{parenExpr.getContainedExpression()}, receiver, unfoldListLiterals, unfoldStarExpressions); - } - else if (exp instanceof PyTupleExpression) { - final PyTupleExpression tupleExpr = (PyTupleExpression)exp; - unfoldParentheses(tupleExpr.getElements(), receiver, unfoldListLiterals, unfoldStarExpressions); - } - else if (exp instanceof PyListLiteralExpression && unfoldListLiterals) { - final PyListLiteralExpression listLiteral = (PyListLiteralExpression)exp; - unfoldParentheses(listLiteral.getElements(), receiver, true, unfoldStarExpressions); - } - else if (exp instanceof PyStarExpression && unfoldStarExpressions) { - unfoldParentheses(new PyExpression[]{((PyStarExpression)exp).getExpression()}, receiver, unfoldListLiterals, true); - } - else if (exp != null) { - receiver.add(exp); - } - } - return receiver; - } - - /** - * Flattens the representation of every element in targets, and puts all results together. - * Elements of every tuple nested in target item are brought to the top level: (a, (b, (c, d))) -> (a, b, c, d) - * Typical usage: flattenedParensAndTuples(some_tuple.getExpressions()). - * - * @param targets target elements. - * @return the list of flattened expressions. - */ - @Nonnull - public static List flattenedParensAndTuples(PyExpression... targets) { - return unfoldParentheses(targets, new ArrayList<>(targets.length), false, false); - } - - @Nonnull - public static List flattenedParensAndLists(PyExpression... targets) { - return unfoldParentheses(targets, new ArrayList<>(targets.length), true, true); - } - - @Nonnull - public static List flattenedParensAndStars(PyExpression... targets) { - return unfoldParentheses(targets, new ArrayList<>(targets.length), false, true); - } - - // Poor man's filter - // TODO: move to a saner place - - public static boolean instanceOf(Object obj, Class... possibleClasses) { - if (obj == null || possibleClasses == null) { - return false; - } - for (Class cls : possibleClasses) { - if (cls.isInstance(obj)) { - return true; - } - } - return false; - } - - - /** - * Produce a reasonable representation of a PSI element, good for debugging. - * - * @param elt element to represent; nulls and invalid nodes are ok. - * @param cutAtEOL if true, representation stops at nearest EOL inside the element. - * @return the representation. - */ - @Nonnull - @NonNls - public static String getReadableRepr(PsiElement elt, final boolean cutAtEOL) { - if (elt == null) { - return "null!"; - } - ASTNode node = elt.getNode(); - if (node == null) { - return "null"; - } - else { - String s = node.getText(); - int cut_pos; - if (cutAtEOL) { - cut_pos = s.indexOf('\n'); - } - else { - cut_pos = -1; - } - if (cut_pos < 0) { - cut_pos = s.length(); - } - return s.substring(0, Math.min(cut_pos, s.length())); - } - } - - @Nullable - public static PyClass getContainingClassOrSelf(final PsiElement element) { - PsiElement current = element; - while (current != null && !(current instanceof PyClass)) { - current = current.getParent(); - } - return (PyClass)current; - } - - /** - * @param element for which to obtain the file - * @return PyFile, or null, if there's no containing file, or it is not a PyFile. - */ - @Nullable - public static PyFile getContainingPyFile(PyElement element) { - final PsiFile containingFile = element.getContainingFile(); - return containingFile instanceof PyFile ? (PyFile)containingFile : null; - } - - /** - * Shows an information balloon in a reasonable place at the top right of the window. - * - * @param project our project - * @param message the text, HTML markup allowed - * @param notificationType message type, changes the icon and the background. - */ - // TODO: move to a better place - public static void showBalloon(Project project, String message, NotificationType notificationType) { - // ripped from com.intellij.openapi.vcs.changes.ui.ChangesViewBalloonProblemNotifier - final JFrame frame = WindowManager.getInstance().getFrame(project.isDefault() ? null : project); - if (frame == null) { - return; - } - final JComponent component = frame.getRootPane(); - if (component == null) { - return; - } - final Rectangle rect = component.getVisibleRect(); - final Point p = new Point(rect.x + rect.width - 10, rect.y + 10); - final RelativePoint point = new RelativePoint(component, p); - - JBPopupFactory.getInstance() - .createHtmlTextBalloonBuilder(message, notificationType, null) - .setShowCallout(false) - .setCloseButtonEnabled(true) - .createBalloon() - .show(point, Balloon.Position.atLeft); - } - - @NonNls - /** - * Returns a quoted string representation, or "null". - */ - public static String nvl(Object s) { - if (s != null) { - return "'" + s.toString() + "'"; - } - else { - return "null"; - } - } - - /** - * Adds an item into a comma-separated list in a PSI tree. E.g. can turn "foo, bar" into "foo, bar, baz", adding commas as needed. - * - * @param parent the element to represent the list; we're adding a child to it. - * @param newItem the element we're inserting (the "baz" in the example). - * @param beforeThis node to mark the insertion point inside the list; must belong to a child of target. Set to null to add first element. - * @param isFirst true if we don't need a comma before the element we're adding. - * @param isLast true if we don't need a comma after the element we're adding. - */ - public static void addListNode(PsiElement parent, - PsiElement newItem, - ASTNode beforeThis, - boolean isFirst, - boolean isLast, - boolean addWhitespace) { - if (!FileModificationService.getInstance().preparePsiElementForWrite(parent)) { - return; - } - ASTNode node = parent.getNode(); - assert node != null; - ASTNode itemNode = newItem.getNode(); - assert itemNode != null; - Project project = parent.getProject(); - PyElementGenerator gen = PyElementGenerator.getInstance(project); - if (!isFirst) { - node.addChild(gen.createComma(), beforeThis); - } - node.addChild(itemNode, beforeThis); - if (!isLast) { - node.addChild(gen.createComma(), beforeThis); - } - if (addWhitespace) { - node.addChild(ASTFactory.whitespace(" "), beforeThis); - } - } - - /** - * Collects superclasses of a class all the way up the inheritance chain. The order is not necessarily the MRO. - */ - @Nonnull - public static List getAllSuperClasses(@Nonnull PyClass pyClass) { - List superClasses = new ArrayList<>(); - for (PyClass ancestor : pyClass.getAncestorClasses(null)) { - if (!PyNames.FAKE_OLD_BASE.equals(ancestor.getName())) { - superClasses.add(ancestor); - } - } - return superClasses; - } - - - // TODO: move to a more proper place? - - /** - * Determine the type of a special attribute. Currently supported: {@code __class__} and {@code __dict__}. - * - * @param ref reference to a possible attribute; only qualified references make sense. - * @return type, or null (if type cannot be determined, reference is not to a known attribute, etc.) - */ - @Nullable - public static PyType getSpecialAttributeType(@Nullable PyReferenceExpression ref, TypeEvalContext context) { - if (ref != null) { - PyExpression qualifier = ref.getQualifier(); - if (qualifier != null) { - String attr_name = ref.getReferencedName(); - if (PyNames.__CLASS__.equals(attr_name)) { - PyType qualifierType = context.getType(qualifier); - if (qualifierType instanceof PyClassType) { - return new PyClassTypeImpl(((PyClassType)qualifierType).getPyClass(), true); // always as class, never instance - } - } - else if (PyNames.DICT.equals(attr_name)) { - PyType qualifierType = context.getType(qualifier); - if (qualifierType instanceof PyClassType && ((PyClassType)qualifierType).isDefinition()) { - return PyBuiltinCache.getInstance(ref).getDictType(); - } - } - } - } - return null; - } - - /** - * Makes sure that 'thing' is not null; else throws an {@link IncorrectOperationException}. - * - * @param thing what we check. - * @return thing, if not null. - */ - @Nonnull - public static T sure(T thing) { - if (thing == null) { - throw new IncorrectOperationException(); - } - return thing; - } - - /** - * Makes sure that the 'thing' is true; else throws an {@link IncorrectOperationException}. - * - * @param thing what we check. - */ - public static void sure(boolean thing) { - if (!thing) { - throw new IncorrectOperationException(); - } - } - - public static boolean isAttribute(PyTargetExpression ex) { - return isInstanceAttribute(ex) || isClassAttribute(ex); - } - - public static boolean isInstanceAttribute(PyExpression target) { - if (!(target instanceof PyTargetExpression)) { - return false; - } - final ScopeOwner owner = ScopeUtil.getScopeOwner(target); - if (owner instanceof PyFunction) { - final PyFunction method = (PyFunction)owner; - if (method.getContainingClass() != null) { - if (method.getStub() != null) { - return true; - } - final PyParameter[] params = method.getParameterList().getParameters(); - if (params.length > 0) { - final PyTargetExpression targetExpr = (PyTargetExpression)target; - final PyExpression qualifier = targetExpr.getQualifier(); - return qualifier != null && qualifier.getText().equals(params[0].getName()); - } - } - } - return false; - } - - public static boolean isClassAttribute(PsiElement element) { - return element instanceof PyTargetExpression && ScopeUtil.getScopeOwner(element) instanceof PyClass; - } - - public static boolean isIfNameEqualsMain(PyIfStatement ifStatement) { - final PyExpression condition = ifStatement.getIfPart().getCondition(); - return isNameEqualsMain(condition); - } - - private static boolean isNameEqualsMain(PyExpression condition) { - if (condition instanceof PyParenthesizedExpression) { - return isNameEqualsMain(((PyParenthesizedExpression)condition).getContainedExpression()); - } - if (condition instanceof PyBinaryExpression) { - PyBinaryExpression binaryExpression = (PyBinaryExpression)condition; - if (binaryExpression.getOperator() == PyTokenTypes.OR_KEYWORD) { - return isNameEqualsMain(binaryExpression.getLeftExpression()) || isNameEqualsMain(binaryExpression.getRightExpression()); - } - final PyExpression rhs = binaryExpression.getRightExpression(); - return binaryExpression.getOperator() == PyTokenTypes.EQEQ && - binaryExpression.getLeftExpression().getText().equals(PyNames.NAME) && - rhs != null && rhs.getText().contains("__main__"); - } - return false; - } - - /** - * Searches for a method wrapping given element. - * - * @param start element presumably inside a method - * @param deep if true, allow 'start' to be inside functions nested in a method; else, 'start' must be directly inside a method. - * @return if not 'deep', [0] is the method and [1] is the class; if 'deep', first several elements may be the nested functions, - * the last but one is the method, and the last is the class. - */ - @Nullable - public static List searchForWrappingMethod(PsiElement start, boolean deep) { - PsiElement seeker = start; - List ret = new ArrayList<>(2); - while (seeker != null) { - PyFunction func = PsiTreeUtil.getParentOfType(seeker, PyFunction.class, true, PyClass.class); - if (func != null) { - PyClass cls = func.getContainingClass(); - if (cls != null) { - ret.add(func); - ret.add(cls); - return ret; - } - else if (deep) { - ret.add(func); - seeker = func; - } - else { - return null; // no immediate class - } - } - else { - return null; // no function - } - } - return null; - } - - public static boolean inSameFile(@Nonnull PsiElement e1, @Nonnull PsiElement e2) { - final PsiFile f1 = e1.getContainingFile(); - final PsiFile f2 = e2.getContainingFile(); - if (f1 == null || f2 == null) { - return false; - } - return f1 == f2; - } - - public static boolean onSameLine(@Nonnull PsiElement e1, @Nonnull PsiElement e2) { - final PsiDocumentManager documentManager = PsiDocumentManager.getInstance(e1.getProject()); - final Document document = documentManager.getDocument(e1.getContainingFile()); - if (document == null || document != documentManager.getDocument(e2.getContainingFile())) { - return false; - } - return document.getLineNumber(e1.getTextOffset()) == document.getLineNumber(e2.getTextOffset()); - } - - public static boolean isTopLevel(@Nonnull PsiElement element) { - if (element instanceof StubBasedPsiElement) { - final StubElement stub = ((StubBasedPsiElement)element).getStub(); - if (stub != null) { - final StubElement parentStub = stub.getParentStub(); - if (parentStub != null) { - return parentStub.getPsi() instanceof PsiFile; - } - } - } - return ScopeUtil.getScopeOwner(element) instanceof PsiFile; - } - - public static void deletePycFiles(String pyFilePath) { - if (pyFilePath.endsWith(PyNames.DOT_PY)) { - List filesToDelete = new ArrayList<>(); - File pyc = new File(pyFilePath + "c"); - if (pyc.exists()) { - filesToDelete.add(pyc); - } - File pyo = new File(pyFilePath + "o"); - if (pyo.exists()) { - filesToDelete.add(pyo); - } - final File file = new File(pyFilePath); - File pycache = new File(file.getParentFile(), PyNames.PYCACHE); - if (pycache.isDirectory()) { - final String shortName = FileUtil.getNameWithoutExtension(file); - Collections.addAll(filesToDelete, pycache.listFiles(pathname -> { - if (!FileUtil.extensionEquals(pathname.getName(), "pyc")) { - return false; - } - String nameWithMagic = FileUtil.getNameWithoutExtension(pathname); - return FileUtil.getNameWithoutExtension(nameWithMagic).equals(shortName); - })); - } - Application.get().getInstance(AsyncFileService.class).asyncDelete(filesToDelete); - } - } - - public static String getElementNameWithoutExtension(PsiNamedElement psiNamedElement) { - return psiNamedElement instanceof PyFile ? FileUtil.getNameWithoutExtension(psiNamedElement.getName()) : psiNamedElement.getName(); - } - - public static boolean hasUnresolvedAncestors(@Nonnull PyClass cls, @Nonnull TypeEvalContext context) { - for (PyClassLikeType type : cls.getAncestorTypes(context)) { - if (type == null) { - return true; - } - } - return false; - } - - @Nonnull - public static AccessDirection getPropertyAccessDirection(@Nonnull PyFunction function) { - final Property property = function.getProperty(); - if (property != null) { - if (property.getGetter().valueOrNull() == function) { - return AccessDirection.READ; - } - if (property.getSetter().valueOrNull() == function) { - return AccessDirection.WRITE; - } - else if (property.getDeleter().valueOrNull() == function) { - return AccessDirection.DELETE; - } - } - return AccessDirection.READ; - } - - public static void removeQualifier(@Nonnull final PyReferenceExpression element) { - final PyExpression qualifier = element.getQualifier(); - if (qualifier == null) { - return; - } - - if (qualifier instanceof PyCallExpression) { - final PyExpression callee = ((PyCallExpression)qualifier).getCallee(); - if (callee instanceof PyReferenceExpression) { - final PyExpression calleeQualifier = ((PyReferenceExpression)callee).getQualifier(); - if (calleeQualifier != null) { - qualifier.replace(calleeQualifier); - return; - } - } - } - final PsiElement dot = PyPsiUtils.getNextNonWhitespaceSibling(qualifier); - if (dot != null) { - dot.delete(); - } - qualifier.delete(); - } - - /** - * Returns string that represents element in string search. - * - * @param element element to search - * @return string that represents element - */ - @Nonnull - public static String computeElementNameForStringSearch(@Nonnull final PsiElement element) { - if (element instanceof PyFile) { - return FileUtil.getNameWithoutExtension(((PyFile)element).getName()); - } - if (element instanceof PsiDirectory) { - return ((PsiDirectory)element).getName(); - } - // Magic literals are always represented by their string values - if ((element instanceof PyStringLiteralExpression) && PyMagicLiteralTools.isMagicLiteral(element)) { - final String name = ((StringLiteralExpression)element).getStringValue(); - if (name != null) { - return name; - } - } - if (element instanceof PyElement) { - final String name = ((PyElement)element).getName(); - if (name != null) { - return name; - } - } - return element.getNode().getText(); - } - - public static boolean isOwnScopeComprehension(@Nonnull PyComprehensionElement comprehension) { - final boolean isAtLeast30 = LanguageLevel.forElement(comprehension).isAtLeast(LanguageLevel.PYTHON30); - final boolean isListComprehension = comprehension instanceof PyListCompExpression; - return !isListComprehension || isAtLeast30; - } - - public static boolean hasCustomDecorators(@Nonnull PyDecoratable decoratable) { - return PyKnownDecoratorUtil.hasNonBuiltinDecorator(decoratable, TypeEvalContext.codeInsightFallback(null)); - } - - public static boolean isDecoratedAsAbstract(@Nonnull final PyDecoratable decoratable) { - return PyKnownDecoratorUtil.hasAbstractDecorator(decoratable, TypeEvalContext.codeInsightFallback(null)); - } - - public static ASTNode createNewName(PyElement element, String name) { - return PyElementGenerator.getInstance(element.getProject()).createNameIdentifier(name, LanguageLevel.forElement(element)); - } - - /** - * Finds element declaration by resolving its references top the top but not further than file (to prevent un-stubbing) - * - * @param elementToResolve element to resolve - * @return its declaration - */ - @Nonnull - public static PsiElement resolveToTheTop(@Nonnull final PsiElement elementToResolve) { - PsiElement currentElement = elementToResolve; - final Set checkedElements = new HashSet<>(); // To prevent PY-20553 - while (true) { - final PsiReference reference = currentElement.getReference(); - if (reference == null) { - break; - } - final PsiElement resolve = reference.resolve(); - if ((resolve == null) || checkedElements.contains(resolve) || resolve.equals(currentElement) || !inSameFile(resolve, - currentElement)) { - break; - } - currentElement = resolve; - checkedElements.add(resolve); - } - return currentElement; - } - - /** - * Note that returned list may contain {@code null} items, e.g. for unresolved import elements, originally wrapped - * in {@link ImportedResolveResult}. - */ - @Nonnull - public static List multiResolveTopPriority(@Nonnull PsiElement element, @Nonnull PyResolveContext resolveContext) { - if (element instanceof PyReferenceOwner) { - final PsiPolyVariantReference ref = ((PyReferenceOwner)element).getReference(resolveContext); - return filterTopPriorityResults(ref.multiResolve(false)); - } - else { - final PsiReference reference = element.getReference(); - return reference != null ? Collections.singletonList(reference.resolve()) : Collections.emptyList(); - } - } - - @Nonnull - public static List multiResolveTopPriority(@Nonnull PsiPolyVariantReference reference) { - return filterTopPriorityResults(reference.multiResolve(false)); - } - - @Nonnull - public static List filterTopPriorityResults(@Nonnull ResolveResult[] resolveResults) { - if (resolveResults.length == 0) { - return Collections.emptyList(); - } - final List filtered = new ArrayList<>(); - final int maxRate = getMaxRate(resolveResults); - for (ResolveResult resolveResult : resolveResults) { - final int rate = resolveResult instanceof RatedResolveResult ? ((RatedResolveResult)resolveResult).getRate() : 0; - if (rate >= maxRate) { - final PsiElement element = resolveResult.getElement(); - if (element != null) { - filtered.add(element); - } - } - } - return filtered; - } - - private static int getMaxRate(@Nonnull ResolveResult[] resolveResults) { - int maxRate = Integer.MIN_VALUE; - for (ResolveResult resolveResult : resolveResults) { - if (resolveResult instanceof RatedResolveResult) { - final int rate = ((RatedResolveResult)resolveResult).getRate(); - if (rate > maxRate) { - maxRate = rate; - } - } - } - return maxRate; - } - - /** - * Gets class init method - * - * @param pyClass class where to find init - * @return class init method if any - */ - @Nullable - public static PyFunction getInitMethod(@Nonnull final PyClass pyClass) { - return pyClass.findMethodByName(PyNames.INIT, false, null); - } - - /** - * Returns Python language level for a virtual file. - * - * @see {@link LanguageLevel#forElement} - */ - @Nonnull - public static LanguageLevel getLanguageLevelForVirtualFile(@Nonnull Project project, @Nonnull VirtualFile virtualFile) { - if (virtualFile instanceof VirtualFileWindow) { - virtualFile = ((VirtualFileWindow)virtualFile).getDelegate(); - } - - // Most of the cases should be handled by this one, PyLanguageLevelPusher pushes folders only - final VirtualFile folder = virtualFile.getParent(); - if (folder != null) { - final LanguageLevel folderLevel = folder.getUserData(LanguageLevel.KEY); - if (folderLevel != null) { - return folderLevel; - } - final LanguageLevel fileLevel = PythonLanguageLevelPusher.getFileLanguageLevel(project, virtualFile); - if (fileLevel != null) { - return fileLevel; - } - } - else { - // However this allows us to setup language level per file manually - // in case when it is LightVirtualFile - final LanguageLevel level = virtualFile.getUserData(LanguageLevel.KEY); - if (level != null) { - return level; - } - - if (ApplicationManager.getApplication().isUnitTestMode()) { - final LanguageLevel languageLevel = LanguageLevel.FORCE_LANGUAGE_LEVEL; - if (languageLevel != null) { - return languageLevel; - } - } - } - return guessLanguageLevelWithCaching(project); - } - - public static void invalidateLanguageLevelCache(@Nonnull Project project) { - project.putUserData(PythonLanguageLevelPusher.PYTHON_LANGUAGE_LEVEL, null); - } - - @Nonnull - public static LanguageLevel guessLanguageLevelWithCaching(@Nonnull Project project) { - LanguageLevel languageLevel = project.getUserData(PythonLanguageLevelPusher.PYTHON_LANGUAGE_LEVEL); - if (languageLevel == null) { - languageLevel = guessLanguageLevel(project); - project.putUserData(PythonLanguageLevelPusher.PYTHON_LANGUAGE_LEVEL, languageLevel); - } - - return languageLevel; - } - - @Nonnull - public static LanguageLevel guessLanguageLevel(@Nonnull Project project) { - final ModuleManager moduleManager = ModuleManager.getInstance(project); - LanguageLevel maxLevel = null; - for (Module projectModule : moduleManager.getModules()) { - final Sdk sdk = PythonSdkType.findPythonSdk(projectModule); - if (sdk != null) { - final LanguageLevel level = PythonSdkType.getLanguageLevelForSdk(sdk); - if (maxLevel == null || maxLevel.isOlderThan(level)) { - maxLevel = level; - } - } - } - if (maxLevel != null) { - return maxLevel; - } - return LanguageLevel.getDefault(); - } - - /** - * Clone of C# "as" operator. - * Checks if expression has correct type and casts it if it has. Returns null otherwise. - * It saves coder from "instanceof / cast" chains. - * - * @param expression expression to check - * @param clazz class to cast - * @param class to cast - * @return expression casted to appropriate type (if could be casted). Null otherwise. - */ - @Nullable - @SuppressWarnings("unchecked") - public static T as(@Nullable final Object expression, @Nonnull final Class clazz) { - return ObjectUtil.tryCast(expression, clazz); - } - - // TODO: Move to PsiElement? - - /** - * Searches for references injected to element with certain type - * - * @param element element to search injected references for - * @param expectedClass expected type of element reference resolved to - * @param expected type of element reference resolved to - * @return resolved element if found or null if not found - */ - @Nullable - public static T findReference(@Nonnull final PsiElement element, @Nonnull final Class expectedClass) { - for (final PsiReference reference : element.getReferences()) { - final T result = as(reference.resolve(), expectedClass); - if (result != null) { - return result; - } - } - return null; - } - - - /** - * Converts collection to list of certain type - * - * @param expression expression of collection type - * @param elementClass expected element type - * @param expected element type - * @return list of elements of expected element type - */ - @Nonnull - public static List asList(@Nullable final Collection expression, @Nonnull final Class elementClass) { - if ((expression == null) || expression.isEmpty()) { - return Collections.emptyList(); - } - final List result = new ArrayList<>(); - for (final Object element : expression) { - final T toAdd = as(element, elementClass); - if (toAdd != null) { - result.add(toAdd); - } - } - return result; - } - - /** - * Force re-highlighting in all open editors that belong to specified project. - */ - public static void rehighlightOpenEditors(final @Nonnull Project project) { - ApplicationManager.getApplication().runWriteAction(() -> { - - for (Editor editor : EditorFactory.getInstance().getAllEditors()) { - if (editor instanceof EditorEx && editor.getProject() == project) { - final VirtualFile vFile = ((EditorEx)editor).getVirtualFile(); - if (vFile != null) { - final EditorHighlighter highlighter = EditorHighlighterFactory.getInstance().createEditorHighlighter(project, vFile); - ((EditorEx)editor).setHighlighter(highlighter); - } - } - } - }); - } - - public static T getParameterizedCachedValue(@Nonnull PsiElement element, @Nonnull P param, @Nonnull Function f) { - final Map cache = LanguageCachedValueUtil.getCachedValue(element, new CachedValueProvider>() { - @Nullable - @Override - public Result> compute() { - return Result.create(Maps.newHashMap(), PsiModificationTracker.MODIFICATION_COUNT); - } - }); - T result = cache.get(param); - if (result == null) { - result = f.apply(param); - cache.put(param, result); - } - return result; - } - - /** - * This method is allowed to be called from any thread, but in general you should not set {@code modal=true} if you're calling it - * from the write action, because in this case {@code function} will be executed right in the current thread (presumably EDT) - * without any progress whatsoever to avoid possible deadlock. - */ - public static void runWithProgress(@Nullable Project project, - @Nls(capitalization = Nls.Capitalization.Title) @Nonnull String title, - boolean modal, - boolean canBeCancelled, - @Nonnull final Consumer function) { - if (modal) { - ProgressManager.getInstance().run(new Task.Modal(project, title, canBeCancelled) { - @Override - public void run(@Nonnull ProgressIndicator indicator) { - function.accept(indicator); + private PyUtil() { + } + + @Nonnull + @RequiredReadAction + public static T[] getAllChildrenOfType(@Nonnull PsiElement element, @Nonnull Class aClass) { + List result = new SmartList<>(); + for (PsiElement child : element.getChildren()) { + if (instanceOf(child, aClass)) { + //noinspection unchecked + result.add((T) child); + } + else { + ContainerUtil.addAll(result, getAllChildrenOfType(child, aClass)); + } } - }); + return ArrayUtil.toObjectArray(result, aClass); } - else { - ProgressManager.getInstance().run(new Task.Backgroundable(project, title, canBeCancelled) { - @Override - public void run(@Nonnull ProgressIndicator indicator) { - function.accept(indicator); - } - }); - } - } - - /** - * Executes code only if

_PYCHARM_VERBOSE_MODE
is set in env (which should be done for debug purposes only) - * - * @param runnable code to call - */ - public static void verboseOnly(@Nonnull final Runnable runnable) { - if (System.getenv().get("_PYCHARM_VERBOSE_MODE") != null) { - runnable.run(); - } - } - - /** - * Returns the line comment that immediately precedes statement list of the given compound statement. Python parser ensures - * that it follows the statement header, i.e. it's directly after the colon, not on its own line. - */ - @Nullable - public static PsiComment getCommentOnHeaderLine(@Nonnull PyStatementListContainer container) { - final PyStatementList statementList = container.getStatementList(); - return as(PyPsiUtils.getPrevNonWhitespaceSibling(statementList), PsiComment.class); - } - - /** - * If argument is a PsiDirectory, turn it into a PsiFile that points to __init__.py in that directory. - * If there's no __init__.py there, null is returned, there's no point to resolve to a dir which is not a package. - * Alas, resolve() and multiResolve() can't return anything but a PyFile or PsiFileImpl.isPsiUpToDate() would fail. - * This is because isPsiUpToDate() relies on identity of objects returned by FileViewProvider.getPsi(). - * If we ever need to exactly tell a dir from __init__.py, that logic has to change. - * - * @param target a resolve candidate. - * @return a PsiFile if target was a PsiDirectory, or null, or target unchanged. - */ - @Nullable - public static PsiElement turnDirIntoInit(@Nullable PsiElement target) { - if (target instanceof PsiDirectory) { - final PsiDirectory dir = (PsiDirectory)target; - final PsiFile file = dir.findFile(PyNames.INIT_DOT_PY); - if (file != null) { - return file; // ResolveImportUtil will extract directory part as needed, everyone else are better off with a file. - } - else { - return null; - } // dir without __init__.py does not resolve - } - else { - return target; - } // don't touch non-dirs - } - - /** - * If directory is a PsiDirectory, that is also a valid Python package, return PsiFile that points to __init__.py, - * if such file exists, or directory itself (i.e. namespace package). Otherwise, return {@code null}. - * Unlike {@link #turnDirIntoInit(PsiElement)} this function handles namespace packages and - * accepts only PsiDirectories as target. - * - * @param directory directory to check - * @param anchor optional PSI element to determine language level as for {@link #isPackage(PsiDirectory, PsiElement)} - * @return PsiFile or PsiDirectory, if target is a Python package and {@code null} null otherwise - */ - @Nullable - public static PsiElement getPackageElement(@Nonnull PsiDirectory directory, @Nullable PsiElement anchor) { - if (isPackage(directory, anchor)) { - final PsiElement init = turnDirIntoInit(directory); - if (init != null) { - return init; - } - return directory; - } - return null; - } - - /** - * If target is a Python module named __init__.py file, return its directory. Otherwise return target unchanged. - * - * @param target PSI element to check - * @return PsiDirectory or target unchanged - */ - @Contract("null -> null; !null -> !null") - @Nullable - public static PsiElement turnInitIntoDir(@Nullable PsiElement target) { - if (target instanceof PyFile && isPackage((PsiFile)target)) { - return ((PsiFile)target).getContainingDirectory(); - } - return target; - } - - /** - * @see #isPackage(PsiDirectory, boolean, PsiElement) - */ - public static boolean isPackage(@Nonnull PsiDirectory directory, @Nullable PsiElement anchor) { - return isPackage(directory, true, anchor); - } - - /** - * Checks that given PsiDirectory can be treated as Python package, i.e. it's either contains __init__.py or it's a namespace package - * (effectively any directory in Python 3.3 and above). Setuptools namespace packages can be checked as well, but it requires access to - * {@link PySetuptoolsNamespaceIndex} and may slow things down during update of project indexes. - * Also note that this method does not check that directory itself and its parents have valid importable names, - * use {@link PyNames#isIdentifier(String)} for this purpose. - * - * @param directory PSI directory to check - * @param checkSetupToolsPackages whether setuptools namespace packages should be considered as well - * @param anchor optional anchor element to determine language level - * @return whether given directory is Python package - * @see PyNames#isIdentifier(String) - */ - public static boolean isPackage(@Nonnull PsiDirectory directory, boolean checkSetupToolsPackages, @Nullable PsiElement anchor) { - if (directory.findFile(PyNames.INIT_DOT_PY) != null) { - return true; - } - final LanguageLevel level = anchor != null ? LanguageLevel.forElement(anchor) : getLanguageLevelForVirtualFile(directory.getProject(), - directory.getVirtualFile()); - if (level.isAtLeast(LanguageLevel.PYTHON33)) { - return true; - } - return checkSetupToolsPackages && isSetuptoolsNamespacePackage(directory); - } - - public static boolean isPackage(@Nonnull PsiFile file) { - return PyNames.INIT_DOT_PY.equals(file.getName()); - } - - private static boolean isSetuptoolsNamespacePackage(@Nonnull PsiDirectory directory) { - final String packagePath = getPackagePath(directory); - return packagePath != null && !PySetuptoolsNamespaceIndex.find(packagePath, directory.getProject()).isEmpty(); - } - - @Nullable - private static String getPackagePath(@Nonnull PsiDirectory directory) { - final QualifiedName name = QualifiedNameFinder.findShortestImportableQName(directory); - return name != null ? name.toString() : null; - } - - /** - * Counts initial underscores of an identifier. - * - * @param name identifier - * @return 0 if no initial underscores found, 1 if there's only one underscore, 2 if there's two or more initial underscores. - */ - public static int getInitialUnderscores(String name) { - if (name == null) { - return 0; - } - int underscores = 0; - if (name.startsWith("__")) { - underscores = 2; - } - else if (name.startsWith("_")) { - underscores = 1; - } - return underscores; - } - - /** - * @param name - * @return true iff the name looks like a class-private one, starting with two underscores but not ending with two underscores. - */ - public static boolean isClassPrivateName(@Nonnull String name) { - return name.startsWith("__") && !name.endsWith("__"); - } - - public static boolean isSpecialName(@Nonnull String name) { - return name.length() > 4 && name.startsWith("__") && name.endsWith("__"); - } - - /** - * Constructs new lookup element for completion of keyword argument with equals sign appended. - * - * @param name name of the parameter - * @param project project instance to check code style settings and surround equals sign with spaces if necessary - * @return lookup element - */ - @Nonnull - public static LookupElement createNamedParameterLookup(@Nonnull String name, @Nullable Project project) { - final String suffix; - if (CodeStyleSettingsManager.getSettings(project).getCustomSettings(PyCodeStyleSettings.class).SPACE_AROUND_EQ_IN_KEYWORD_ARGUMENT) { - suffix = " = "; - } - else { - suffix = "="; - } - LookupElementBuilder lookupElementBuilder = LookupElementBuilder.create(name + suffix).withIcon(AllIcons.Nodes.Parameter); - lookupElementBuilder = lookupElementBuilder.withInsertHandler(OverwriteEqualsInsertHandler.INSTANCE); - return PrioritizedLookupElement.withGrouping(lookupElementBuilder, 1); - } - - /** - * Peels argument expression of parentheses and of keyword argument wrapper - * - * @param expr an item of getArguments() array - * @return expression actually passed as argument - */ - @Nullable - public static PyExpression peelArgument(PyExpression expr) { - while (expr instanceof PyParenthesizedExpression) { - expr = ((PyParenthesizedExpression)expr).getContainedExpression(); - } - if (expr instanceof PyKeywordArgument) { - expr = ((PyKeywordArgument)expr).getValueExpression(); - } - return expr; - } - - public static String getFirstParameterName(PyFunction container) { - String selfName = PyNames.CANONICAL_SELF; - if (container != null) { - final PyParameter[] params = container.getParameterList().getParameters(); - if (params.length > 0) { - final PyNamedParameter named = params[0].getAsNamed(); - if (named != null) { - selfName = named.getName(); - } - } - } - return selfName; - } - - /** - * @return Source roots and content roots for element's project - */ - @Nonnull - public static Collection getSourceRoots(@Nonnull PsiElement foothold) { - final Module module = ModuleUtilCore.findModuleForPsiElement(foothold); - if (module != null) { - return getSourceRoots(module); - } - return Collections.emptyList(); - } - - /** - * @return Source roots and content roots for module - */ - @Nonnull - public static Collection getSourceRoots(@Nonnull Module module) { - final Set result = new LinkedHashSet<>(); - final ModuleRootManager manager = ModuleRootManager.getInstance(module); - Collections.addAll(result, manager.getSourceRoots()); - Collections.addAll(result, manager.getContentRoots()); - return result; - } - - @Nullable - public static VirtualFile findInRoots(Module module, String path) { - if (module != null) { - for (VirtualFile root : getSourceRoots(module)) { - VirtualFile file = root.findFileByRelativePath(path); - if (file != null) { - return file; - } - } - } - return null; - } - - @Nullable - public static List getStringListFromTargetExpression(PyTargetExpression attr) { - return strListValue(attr.findAssignedValue()); - } - - @Nullable - public static List strListValue(PyExpression value) { - while (value instanceof PyParenthesizedExpression) { - value = ((PyParenthesizedExpression)value).getContainedExpression(); - } - if (value instanceof PySequenceExpression) { - final PyExpression[] elements = ((PySequenceExpression)value).getElements(); - List result = new ArrayList<>(elements.length); - for (PyExpression element : elements) { - if (!(element instanceof PyStringLiteralExpression)) { - return null; - } - result.add(((PyStringLiteralExpression)element).getStringValue()); - } - return result; - } - return null; - } - - @Nonnull - public static Map dictValue(@Nonnull PyDictLiteralExpression dict) { - Map result = Maps.newLinkedHashMap(); - for (PyKeyValueExpression keyValue : dict.getElements()) { - PyExpression key = keyValue.getKey(); - PyExpression value = keyValue.getValue(); - if (key instanceof PyStringLiteralExpression) { - result.put(((PyStringLiteralExpression)key).getStringValue(), value); - } - } - return result; - } - - /** - * @param what thing to search for - * @param variants things to search among - * @return true iff what.equals() one of the variants. - */ - public static boolean among(@Nonnull T what, T... variants) { - for (T s : variants) { - if (what.equals(s)) { - return true; - } - } - return false; - } - - @Nullable - public static String getKeywordArgumentString(PyCallExpression expr, String keyword) { - return PyPsiUtils.strValue(expr.getKeywordArgument(keyword)); - } - - public static boolean isExceptionClass(PyClass pyClass) { - if (isBaseException(pyClass.getQualifiedName())) { - return true; - } - for (PyClassLikeType type : pyClass.getAncestorTypes(TypeEvalContext.codeInsightFallback(pyClass.getProject()))) { - if (type != null && isBaseException(type.getClassQName())) { - return true; - } - } - return false; - } - - private static boolean isBaseException(String name) { - return name != null && (name.contains("BaseException") || name.startsWith("exceptions.")); - } - - public static class MethodFlags { - - private final boolean myIsStaticMethod; - private final boolean myIsMetaclassMethod; - private final boolean myIsSpecialMetaclassMethod; - private final boolean myIsClassMethod; /** - * @return true iff the method belongs to a metaclass (an ancestor of 'type'). + * @see PyUtil#flattenedParensAndTuples */ - public boolean isMetaclassMethod() { - return myIsMetaclassMethod; + protected static List unfoldParentheses( + PyExpression[] targets, + List receiver, + boolean unfoldListLiterals, + boolean unfoldStarExpressions + ) { + // NOTE: this proliferation of instanceofs is not very beautiful. Maybe rewrite using a visitor. + for (PyExpression exp : targets) { + if (exp instanceof PyParenthesizedExpression parenExpr) { + unfoldParentheses( + new PyExpression[]{parenExpr.getContainedExpression()}, + receiver, + unfoldListLiterals, + unfoldStarExpressions + ); + } + else if (exp instanceof PyTupleExpression tupleExpr) { + unfoldParentheses(tupleExpr.getElements(), receiver, unfoldListLiterals, unfoldStarExpressions); + } + else if (exp instanceof PyListLiteralExpression listLiteral && unfoldListLiterals) { + unfoldParentheses(listLiteral.getElements(), receiver, true, unfoldStarExpressions); + } + else if (exp instanceof PyStarExpression starExpr && unfoldStarExpressions) { + unfoldParentheses(new PyExpression[]{starExpr.getExpression()}, receiver, unfoldListLiterals, true); + } + else if (exp != null) { + receiver.add(exp); + } + } + return receiver; } /** - * @return iff isMetaclassMethod and the method is either __init__ or __call__. + * Flattens the representation of every element in targets, and puts all results together. + * Elements of every tuple nested in target item are brought to the top level: (a, (b, (c, d))) -> (a, b, c, d) + * Typical usage: flattenedParensAndTuples(some_tuple.getExpressions()). + * + * @param targets target elements. + * @return the list of flattened expressions. */ - public boolean isSpecialMetaclassMethod() { - return myIsSpecialMetaclassMethod; + @Nonnull + public static List flattenedParensAndTuples(PyExpression... targets) { + return unfoldParentheses(targets, new ArrayList<>(targets.length), false, false); } - public boolean isStaticMethod() { - return myIsStaticMethod; + @Nonnull + public static List flattenedParensAndLists(PyExpression... targets) { + return unfoldParentheses(targets, new ArrayList<>(targets.length), true, true); } - public boolean isClassMethod() { - return myIsClassMethod; + @Nonnull + public static List flattenedParensAndStars(PyExpression... targets) { + return unfoldParentheses(targets, new ArrayList<>(targets.length), false, true); } - private MethodFlags(boolean isClassMethod, boolean isStaticMethod, boolean isMetaclassMethod, boolean isSpecialMetaclassMethod) { - myIsClassMethod = isClassMethod; - myIsStaticMethod = isStaticMethod; - myIsMetaclassMethod = isMetaclassMethod; - myIsSpecialMetaclassMethod = isSpecialMetaclassMethod; + // Poor man's filter + // TODO: move to a saner place + + public static boolean instanceOf(Object obj, Class... possibleClasses) { + if (obj == null || possibleClasses == null) { + return false; + } + for (Class cls : possibleClasses) { + if (cls.isInstance(obj)) { + return true; + } + } + return false; } + /** - * @param node a function - * @return a new flags object, or null if the function is not a method + * Produce a reasonable representation of a PSI element, good for debugging. + * + * @param elt element to represent; nulls and invalid nodes are ok. + * @param cutAtEOL if true, representation stops at nearest EOL inside the element. + * @return the representation. */ - @Nullable - public static MethodFlags of(@Nonnull PyFunction node) { - PyClass cls = node.getContainingClass(); - if (cls != null) { - PyFunction.Modifier modifier = node.getModifier(); - boolean isMetaclassMethod = false; - PyClass type_cls = PyBuiltinCache.getInstance(node).getClass("type"); - for (PyClass ancestor_cls : cls.getAncestorClasses(null)) { - if (ancestor_cls == type_cls) { - isMetaclassMethod = true; - break; - } - } - final String method_name = node.getName(); - boolean isSpecialMetaclassMethod = isMetaclassMethod && method_name != null && among(method_name, PyNames.INIT, "__call__"); - return new MethodFlags(modifier == CLASSMETHOD, modifier == STATICMETHOD, isMetaclassMethod, isSpecialMetaclassMethod); - } - return null; - } - - //TODO: Doc - public boolean isInstanceMethod() { - return !(myIsClassMethod || myIsStaticMethod); - } - } - - public static boolean isSuperCall(@Nonnull PyCallExpression node) { - PyClass klass = PsiTreeUtil.getParentOfType(node, PyClass.class); - if (klass == null) { - return false; - } - PyExpression callee = node.getCallee(); - if (callee == null) { - return false; - } - String name = callee.getName(); - if (PyNames.SUPER.equals(name)) { - PsiReference reference = callee.getReference(); - if (reference == null) { - return false; - } - PsiElement resolved = reference.resolve(); - PyBuiltinCache cache = PyBuiltinCache.getInstance(node); - if (resolved != null && cache.isBuiltin(resolved)) { - PyExpression[] args = node.getArguments(); - if (args.length > 0) { - String firstArg = args[0].getText(); - if (firstArg.equals(klass.getName()) || firstArg.equals(PyNames.CANONICAL_SELF + "." + PyNames.__CLASS__)) { - return true; - } - for (PyClass s : klass.getAncestorClasses(null)) { - if (firstArg.equals(s.getName())) { - return true; - } - } + @Nonnull + public static String getReadableRepr(PsiElement elt, boolean cutAtEOL) { + if (elt == null) { + return "null!"; + } + ASTNode node = elt.getNode(); + if (node == null) { + return "null"; } else { - return true; - } - } - } - return false; - } - - @Nonnull - public static PyFile getOrCreateFile(String path, Project project) { - final VirtualFile vfile = LocalFileSystem.getInstance().findFileByIoFile(new File(path)); - final PsiFile psi; - if (vfile == null) { - final File file = new File(path); - try { - final VirtualFile baseDir = project.getBaseDir(); - final FileTemplateManager fileTemplateManager = FileTemplateManager.getInstance(project); - final FileTemplate template = fileTemplateManager.getInternalTemplate("Python Script"); - final Properties properties = fileTemplateManager.getDefaultProperties(); - properties.setProperty("NAME", FileUtil.getNameWithoutExtension(file.getName())); - final String content = (template != null) ? template.getText(properties) : null; - psi = PyExtractSuperclassHelper.placeFile(project, - StringUtil.notNullize(file.getParent(), baseDir != null ? baseDir.getPath() : "."), - file.getName(), - content); - } - catch (IOException e) { - throw new IncorrectOperationException(String.format("Cannot create file '%s'", path), e); - } - } - else { - psi = PsiManager.getInstance(project).findFile(vfile); - } - if (!(psi instanceof PyFile)) { - throw new IncorrectOperationException(PyBundle.message( - "refactoring.move.module.members.error.cannot.place.elements.into.nonpython.file")); - } - return (PyFile)psi; - } - - /** - * counts elements in iterable - * - * @param expression to count containing elements (iterable) - * @return element count - */ - public static int getElementsCount(PyExpression expression, TypeEvalContext evalContext) { - int valuesLength = -1; - PyType type = evalContext.getType(expression); - if (type instanceof PyTupleType) { - valuesLength = ((PyTupleType)type).getElementCount(); - } - else if (type instanceof PyNamedTupleType) { - valuesLength = ((PyNamedTupleType)type).getElementCount(); - } - else if (expression instanceof PySequenceExpression) { - valuesLength = ((PySequenceExpression)expression).getElements().length; - } - else if (expression instanceof PyStringLiteralExpression) { - valuesLength = ((PyStringLiteralExpression)expression).getStringValue().length(); - } - else if (expression instanceof PyNumericLiteralExpression) { - valuesLength = 1; - } - else if (expression instanceof PyCallExpression) { - PyCallExpression call = (PyCallExpression)expression; - if (call.isCalleeText("dict")) { - valuesLength = call.getArguments().length; - } - else if (call.isCalleeText("tuple")) { - PyExpression[] arguments = call.getArguments(); - if (arguments.length > 0 && arguments[0] instanceof PySequenceExpression) { - valuesLength = ((PySequenceExpression)arguments[0]).getElements().length; - } - } - } - return valuesLength; - } - - @Nullable - public static PsiElement findPrevAtOffset(PsiFile psiFile, int caretOffset, Class... toSkip) { - PsiElement element; - if (caretOffset < 0) { - return null; - } - int lineStartOffset = 0; - final Document document = PsiDocumentManager.getInstance(psiFile.getProject()).getDocument(psiFile); - if (document != null) { - int lineNumber = document.getLineNumber(caretOffset); - lineStartOffset = document.getLineStartOffset(lineNumber); - } - do { - caretOffset--; - element = psiFile.findElementAt(caretOffset); - } - while (caretOffset >= lineStartOffset && instanceOf(element, toSkip)); - return instanceOf(element, toSkip) ? null : element; - } - - @Nullable - public static PsiElement findNonWhitespaceAtOffset(PsiFile psiFile, int caretOffset) { - PsiElement element = findNextAtOffset(psiFile, caretOffset, PsiWhiteSpace.class); - if (element == null) { - element = findPrevAtOffset(psiFile, caretOffset - 1, PsiWhiteSpace.class); - } - return element; - } - - @Nullable - public static PsiElement findElementAtOffset(PsiFile psiFile, int caretOffset) { - PsiElement element = findPrevAtOffset(psiFile, caretOffset); - if (element == null) { - element = findNextAtOffset(psiFile, caretOffset); - } - return element; - } - - @Nullable - public static PsiElement findNextAtOffset(@Nonnull final PsiFile psiFile, int caretOffset, Class... toSkip) { - PsiElement element = psiFile.findElementAt(caretOffset); - if (element == null) { - return null; - } - - final Document document = PsiDocumentManager.getInstance(psiFile.getProject()).getDocument(psiFile); - int lineEndOffset = 0; - if (document != null) { - int lineNumber = document.getLineNumber(caretOffset); - lineEndOffset = document.getLineEndOffset(lineNumber); - } - while (caretOffset < lineEndOffset && instanceOf(element, toSkip)) { - caretOffset++; - element = psiFile.findElementAt(caretOffset); - } - return instanceOf(element, toSkip) ? null : element; - } - - /** - * Adds element to statement list to the correct place according to its dependencies. - * - * @param element to insert - * @param statementList where element should be inserted - * @return inserted element - */ - public static T addElementToStatementList(@Nonnull final T element, @Nonnull final PyStatementList statementList) { - PsiElement before = null; - PsiElement after = null; - for (final PyStatement statement : statementList.getStatements()) { - if (PyDependenciesComparator.depends(element, statement)) { - after = statement; - } - else if (PyDependenciesComparator.depends(statement, element)) { - before = statement; - } - } - final PsiElement result; - if (after != null) { - - result = statementList.addAfter(element, after); - } - else if (before != null) { - result = statementList.addBefore(element, before); - } - else { - result = addElementToStatementList(element, statementList, true); - } - @SuppressWarnings("unchecked") // Inserted element can't have different type - final T resultCasted = (T)result; - return resultCasted; - } - - - /** - * Inserts specified element into the statement list either at the beginning or at its end. If new element is going to be - * inserted at the beginning, any preceding docstrings and/or calls to super methods will be skipped. - * Moreover if statement list previously didn't contain any statements, explicit new line and indentation will be inserted in - * front of it. - * - * @param element element to insert - * @param statementList statement list - * @param toTheBeginning whether to insert element at the beginning or at the end of the statement list - * @return actually inserted element as for {@link PsiElement#add(PsiElement)} - */ - @Nonnull - public static PsiElement addElementToStatementList(@Nonnull PsiElement element, - @Nonnull PyStatementList statementList, - boolean toTheBeginning) { - final PsiElement prevElem = PyPsiUtils.getPrevNonWhitespaceSibling(statementList); - // If statement list is on the same line as previous element (supposedly colon), move its only statement on the next line - if (prevElem != null && onSameLine(statementList, prevElem)) { - final PsiDocumentManager manager = PsiDocumentManager.getInstance(statementList.getProject()); - final Document document = manager.getDocument(statementList.getContainingFile()); - if (document != null) { - final PyStatementListContainer container = (PyStatementListContainer)statementList.getParent(); - manager.doPostponedOperationsAndUnblockDocument(document); - final String indentation = "\n" + PyIndentUtil.getElementIndent(statementList); - // If statement list was empty initially, we need to add some anchor statement ("pass"), so that preceding new line was not - // parsed as following entire StatementListContainer (e.g. function). It's going to be replaced anyway. - final String text = statementList.getStatements().length == 0 ? indentation + PyNames.PASS : indentation; - document.insertString(statementList.getTextRange().getStartOffset(), text); - manager.commitDocument(document); - statementList = container.getStatementList(); - } - } - final PsiElement firstChild = statementList.getFirstChild(); - if (firstChild == statementList.getLastChild() && firstChild instanceof PyPassStatement) { - element = firstChild.replace(element); - } - else { - final PyStatement[] statements = statementList.getStatements(); - if (toTheBeginning && statements.length > 0) { - final PyDocStringOwner docStringOwner = PsiTreeUtil.getParentOfType(statementList, PyDocStringOwner.class); - PyStatement anchor = statements[0]; - if (docStringOwner != null && anchor instanceof PyExpressionStatement && - ((PyExpressionStatement)anchor).getExpression() == docStringOwner.getDocStringExpression()) { - final PyStatement next = PsiTreeUtil.getNextSiblingOfType(anchor, PyStatement.class); - if (next == null) { - return statementList.addAfter(element, anchor); - } - anchor = next; - } - while (anchor instanceof PyExpressionStatement) { - final PyExpression expression = ((PyExpressionStatement)anchor).getExpression(); - if (expression instanceof PyCallExpression) { - final PyExpression callee = ((PyCallExpression)expression).getCallee(); - if ((isSuperCall((PyCallExpression)expression) || (callee != null && PyNames.INIT.equals(callee.getName())))) { - final PyStatement next = PsiTreeUtil.getNextSiblingOfType(anchor, PyStatement.class); - if (next == null) { - return statementList.addAfter(element, anchor); - } - anchor = next; - continue; - } - } - break; - } - element = statementList.addBefore(element, anchor); - } - else { - element = statementList.add(element); - } - } - return element; - } - - @Nonnull - public static List> getOverloadedParametersSet(@Nonnull PyCallable callable, @Nonnull TypeEvalContext context) { - final List> parametersSet = getOverloadedParametersSet(context.getType(callable), context); - return parametersSet != null ? parametersSet : Collections.singletonList(Arrays.asList(callable.getParameterList().getParameters())); - } - - @Nullable - private static List getParametersOfCallableType(@Nonnull PyCallableType type, @Nonnull TypeEvalContext context) { - final List callableTypeParameters = type.getParameters(context); - if (callableTypeParameters != null) { - boolean allParametersDefined = true; - final List parameters = new ArrayList<>(); - for (PyCallableParameter callableParameter : callableTypeParameters) { - final PyParameter parameter = callableParameter.getParameter(); - if (parameter == null) { - allParametersDefined = false; - break; - } - parameters.add(parameter); - } - if (allParametersDefined) { - return parameters; - } - } - return null; - } - - @Nullable - private static List> getOverloadedParametersSet(@Nullable PyType type, @Nonnull TypeEvalContext context) { - if (type instanceof PyUnionType) { - type = ((PyUnionType)type).excludeNull(context); - } - - if (type instanceof PyCallableType) { - final List results = getParametersOfCallableType((PyCallableType)type, context); - if (results != null) { - return Collections.singletonList(results); - } - } - else if (type instanceof PyUnionType) { - final List> results = new ArrayList<>(); - final Collection members = ((PyUnionType)type).getMembers(); - for (PyType member : members) { - if (member instanceof PyCallableType) { - final List parameters = getParametersOfCallableType((PyCallableType)member, context); - if (parameters != null) { - results.add(parameters); - } - } - } - if (!results.isEmpty()) { - return results; - } - } - - return null; - } - - @Nonnull - public static List getParameters(@Nonnull PyCallable callable, @Nonnull TypeEvalContext context) { - final List> parametersSet = getOverloadedParametersSet(callable, context); - assert !parametersSet.isEmpty(); - return parametersSet.get(0); - } - - public static boolean isSignatureCompatibleTo(@Nonnull PyCallable callable, - @Nonnull PyCallable otherCallable, - @Nonnull TypeEvalContext context) { - final List parameters = getParameters(callable, context); - final List otherParameters = getParameters(otherCallable, context); - final int optionalCount = optionalParametersCount(parameters); - final int otherOptionalCount = optionalParametersCount(otherParameters); - final int requiredCount = requiredParametersCount(callable, parameters); - final int otherRequiredCount = requiredParametersCount(otherCallable, otherParameters); - if (hasPositionalContainer(otherParameters) || hasKeywordContainer(otherParameters)) { - if (otherParameters.size() == specialParametersCount(otherCallable, otherParameters)) { - return true; - } - } - if (hasPositionalContainer(parameters) || hasKeywordContainer(parameters)) { - return requiredCount <= otherRequiredCount; - } - return requiredCount <= otherRequiredCount && parameters.size() >= otherParameters.size() && optionalCount >= otherOptionalCount; - } - - private static int optionalParametersCount(@Nonnull List parameters) { - int n = 0; - for (PyParameter parameter : parameters) { - if (parameter.hasDefaultValue()) { - n++; - } - } - return n; - } - - private static int requiredParametersCount(@Nonnull PyCallable callable, @Nonnull List parameters) { - return parameters.size() - optionalParametersCount(parameters) - specialParametersCount(callable, parameters); - } - - private static int specialParametersCount(@Nonnull PyCallable callable, @Nonnull List parameters) { - int n = 0; - if (hasPositionalContainer(parameters)) { - n++; - } - if (hasKeywordContainer(parameters)) { - n++; - } - if (callable.asMethod() != null) { - n++; - } - else { - if (parameters.size() > 0) { - final PyParameter first = parameters.get(0); - if (PyNames.CANONICAL_SELF.equals(first.getName())) { - n++; - } - } - } - return n; - } - - private static boolean hasPositionalContainer(@Nonnull List parameters) { - for (PyParameter parameter : parameters) { - if (parameter instanceof PyNamedParameter && ((PyNamedParameter)parameter).isPositionalContainer()) { - return true; - } - } - return false; - } - - private static boolean hasKeywordContainer(@Nonnull List parameters) { - for (PyParameter parameter : parameters) { - if (parameter instanceof PyNamedParameter && ((PyNamedParameter)parameter).isKeywordContainer()) { - return true; - } - } - return false; - } - - public static boolean isInit(@Nonnull final PyFunction function) { - return PyNames.INIT.equals(function.getName()); - } - - /** - * Filters out {@link PyMemberInfo} - * that should not be displayed in this refactoring (like object) - * - * @param pyMemberInfos collection to sort - * @return sorted collection - */ - @Nonnull - public static Collection> filterOutObject(@Nonnull final Collection> pyMemberInfos) { - return Collections2.filter(pyMemberInfos, new ObjectPredicate(false)); - } - - public static boolean isStarImportableFrom(@Nonnull String name, @Nonnull PyFile file) { - final List dunderAll = file.getDunderAll(); - return dunderAll != null ? dunderAll.contains(name) : !name.startsWith("_"); - } - - /** - * Filters only PyClass object (new class) - */ - public static class ObjectPredicate extends NotNullPredicate> { - private final boolean myAllowObjects; + String s = node.getText(); + int cut_pos; + if (cutAtEOL) { + cut_pos = s.indexOf('\n'); + } + else { + cut_pos = -1; + } + if (cut_pos < 0) { + cut_pos = s.length(); + } + return s.substring(0, Math.min(cut_pos, s.length())); + } + } + + @Nullable + public static PyClass getContainingClassOrSelf(PsiElement element) { + PsiElement current = element; + while (current != null && !(current instanceof PyClass)) { + current = current.getParent(); + } + return (PyClass) current; + } /** - * @param allowObjects allows only objects if true. Allows all but objects otherwise. + * @param element for which to obtain the file + * @return PyFile, or null, if there's no containing file, or it is not a PyFile. */ - public ObjectPredicate(final boolean allowObjects) { - myAllowObjects = allowObjects; + @Nullable + public static PyFile getContainingPyFile(PyElement element) { + return element.getContainingFile() instanceof PyFile containingFile ? containingFile : null; } - @Override - public boolean applyNotNull(@Nonnull final PyMemberInfo input) { - return myAllowObjects == isObject(input); + /** + * Shows an information balloon in a reasonable place at the top right of the window. + * + * @param project our project + * @param message the text, HTML markup allowed + * @param notificationType message type, changes the icon and the background. + */ + // TODO: move to a better place + public static void showBalloon(Project project, String message, NotificationType notificationType) { + // ripped from com.intellij.openapi.vcs.changes.ui.ChangesViewBalloonProblemNotifier + JFrame frame = WindowManager.getInstance().getFrame(project.isDefault() ? null : project); + if (frame == null) { + return; + } + JComponent component = frame.getRootPane(); + if (component == null) { + return; + } + Rectangle rect = component.getVisibleRect(); + Point p = new Point(rect.x + rect.width - 10, rect.y + 10); + RelativePoint point = new RelativePoint(component, p); + + JBPopupFactory.getInstance() + .createHtmlTextBalloonBuilder(message, notificationType, null) + .setShowCallout(false) + .setCloseButtonEnabled(true) + .createBalloon() + .show(point, Balloon.Position.atLeft); } - private static boolean isObject(@Nonnull final PyMemberInfo classMemberInfo) { - final PyElement element = classMemberInfo.getMember(); - return (element instanceof PyClass) && PyNames.OBJECT.equals(element.getName()); + /** + * Returns a quoted string representation, or "null". + */ + public static String nvl(Object s) { + if (s != null) { + return "'" + s.toString() + "'"; + } + else { + return "null"; + } } - } - /** - * Sometimes you do not know real FQN of some class, but you know class name and its package. - * I.e. django.apps.conf.AppConfig is not documented, but you know - * AppConfig and django package. - * - * @param symbol element to check (class or function) - * @param expectedPackage package like "django" - * @param expectedName expected name (i.e. AppConfig) - * @return true if element in package - */ - public static boolean isSymbolInPackage(@Nonnull final PyQualifiedNameOwner symbol, - @Nonnull final String expectedPackage, - @Nonnull final String expectedName) { - final String qualifiedNameString = symbol.getQualifiedName(); - if (qualifiedNameString == null) { - return false; + /** + * Adds an item into a comma-separated list in a PSI tree. E.g. can turn "foo, bar" into "foo, bar, baz", adding commas as needed. + * + * @param parent the element to represent the list; we're adding a child to it. + * @param newItem the element we're inserting (the "baz" in the example). + * @param beforeThis node to mark the insertion point inside the list; must belong to a child of target. Set to null to add first element. + * @param isFirst true if we don't need a comma before the element we're adding. + * @param isLast true if we don't need a comma after the element we're adding. + */ + public static void addListNode( + PsiElement parent, + PsiElement newItem, + ASTNode beforeThis, + boolean isFirst, + boolean isLast, + boolean addWhitespace + ) { + if (!FileModificationService.getInstance().preparePsiElementForWrite(parent)) { + return; + } + ASTNode node = parent.getNode(); + assert node != null; + ASTNode itemNode = newItem.getNode(); + assert itemNode != null; + Project project = parent.getProject(); + PyElementGenerator gen = PyElementGenerator.getInstance(project); + if (!isFirst) { + node.addChild(gen.createComma(), beforeThis); + } + node.addChild(itemNode, beforeThis); + if (!isLast) { + node.addChild(gen.createComma(), beforeThis); + } + if (addWhitespace) { + node.addChild(ASTFactory.whitespace(" "), beforeThis); + } } - final QualifiedName qualifiedName = QualifiedName.fromDottedString(qualifiedNameString); - final String aPackage = qualifiedName.getFirstComponent(); - if (!(expectedPackage.equals(aPackage))) { - return false; + + /** + * Collects superclasses of a class all the way up the inheritance chain. The order is not necessarily the MRO. + */ + @Nonnull + @RequiredReadAction + public static List getAllSuperClasses(@Nonnull PyClass pyClass) { + List superClasses = new ArrayList<>(); + for (PyClass ancestor : pyClass.getAncestorClasses(null)) { + if (!PyNames.FAKE_OLD_BASE.equals(ancestor.getName())) { + superClasses.add(ancestor); + } + } + return superClasses; } - final String symbolName = qualifiedName.getLastComponent(); - return expectedName.equals(symbolName); - } - /** - * Checks that given class is the root of class hierarchy, i.e. it's either {@code object} or - * special {@link PyNames#FAKE_OLD_BASE} class for old-style classes. - * - * @param cls Python class to check - * @see PyBuiltinCache - * @see PyNames#FAKE_OLD_BASE - */ - public static boolean isObjectClass(@Nonnull PyClass cls) { - final PyBuiltinCache builtinCache = PyBuiltinCache.getInstance(cls); - return cls == builtinCache.getClass(PyNames.OBJECT) || cls == builtinCache.getClass(PyNames.FAKE_OLD_BASE); - } + // TODO: move to a more proper place? - public static boolean isInScratchFile(@Nonnull PsiElement element) { - return ScratchFileService.isInScratchRoot(PsiUtilCore.getVirtualFile(element)); - } + /** + * Determine the type of a special attribute. Currently supported: {@code __class__} and {@code __dict__}. + * + * @param ref reference to a possible attribute; only qualified references make sense. + * @return type, or null (if type cannot be determined, reference is not to a known attribute, etc.) + */ + @Nullable + public static PyType getSpecialAttributeType(@Nullable PyReferenceExpression ref, TypeEvalContext context) { + if (ref != null) { + PyExpression qualifier = ref.getQualifier(); + if (qualifier != null) { + String attr_name = ref.getReferencedName(); + if (PyNames.__CLASS__.equals(attr_name)) { + PyType qualifierType = context.getType(qualifier); + if (qualifierType instanceof PyClassType classType) { + return new PyClassTypeImpl(classType.getPyClass(), true); // always as class, never instance + } + } + else if (PyNames.DICT.equals(attr_name)) { + PyType qualifierType = context.getType(qualifier); + if (qualifierType instanceof PyClassType classType && classType.isDefinition()) { + return PyBuiltinCache.getInstance(ref).getDictType(); + } + } + } + } + return null; + } - @Nullable - public static PyType getReturnTypeOfMember(@Nonnull PyType type, - @Nonnull String memberName, - @Nullable PyExpression location, - @Nonnull TypeEvalContext context) { - final PyResolveContext resolveContext = PyResolveContext.noImplicits().withTypeEvalContext(context); - final List resolveResults = - type.resolveMember(memberName, location, AccessDirection.READ, resolveContext); + /** + * Makes sure that 'thing' is not null; else throws an {@link IncorrectOperationException}. + * + * @param thing what we check. + * @return thing, if not null. + */ + @Nonnull + public static T sure(T thing) { + if (thing == null) { + throw new IncorrectOperationException(); + } + return thing; + } - if (resolveResults != null) { - final List types = new ArrayList<>(); + /** + * Makes sure that the 'thing' is true; else throws an {@link IncorrectOperationException}. + * + * @param thing what we check. + */ + public static void sure(boolean thing) { + if (!thing) { + throw new IncorrectOperationException(); + } + } - for (RatedResolveResult resolveResult : resolveResults) { - final PyType returnType = getReturnType(resolveResult.getElement(), context); - - if (returnType != null) { - types.add(returnType); - } - } + @RequiredReadAction + public static boolean isAttribute(PyTargetExpression ex) { + return isInstanceAttribute(ex) || isClassAttribute(ex); + } - return PyUnionType.union(types); - } + @RequiredReadAction + public static boolean isInstanceAttribute(PyExpression target) { + if (!(target instanceof PyTargetExpression targetExpr)) { + return false; + } + ScopeOwner owner = ScopeUtil.getScopeOwner(target); + if (owner instanceof PyFunction method && method.getContainingClass() != null) { + if (method.getStub() != null) { + return true; + } + PyParameter[] params = method.getParameterList().getParameters(); + if (params.length > 0) { + PyExpression qualifier = targetExpr.getQualifier(); + return qualifier != null && qualifier.getText().equals(params[0].getName()); + } + } + return false; + } - return null; - } + public static boolean isClassAttribute(PsiElement element) { + return element instanceof PyTargetExpression && ScopeUtil.getScopeOwner(element) instanceof PyClass; + } - @Nullable - private static PyType getReturnType(@Nullable PsiElement element, @Nonnull TypeEvalContext context) { - if (element instanceof PyTypedElement) { - final PyType type = context.getType((PyTypedElement)element); + @RequiredReadAction + public static boolean isIfNameEqualsMain(PyIfStatement ifStatement) { + PyExpression condition = ifStatement.getIfPart().getCondition(); + return isNameEqualsMain(condition); + } - return getReturnType(type, context); + @RequiredReadAction + private static boolean isNameEqualsMain(PyExpression condition) { + if (condition instanceof PyParenthesizedExpression parenExpr) { + return isNameEqualsMain(parenExpr.getContainedExpression()); + } + if (condition instanceof PyBinaryExpression binaryExpression) { + if (binaryExpression.getOperator() == PyTokenTypes.OR_KEYWORD) { + return isNameEqualsMain(binaryExpression.getLeftExpression()) || isNameEqualsMain(binaryExpression.getRightExpression()); + } + PyExpression rhs = binaryExpression.getRightExpression(); + return binaryExpression.getOperator() == PyTokenTypes.EQEQ + && binaryExpression.getLeftExpression().getText().equals(PyNames.NAME) + && rhs != null && rhs.getText().contains("__main__"); + } + return false; } - return null; - } + /** + * Searches for a method wrapping given element. + * + * @param start element presumably inside a method + * @param deep if true, allow 'start' to be inside functions nested in a method; else, 'start' must be directly inside a method. + * @return if not 'deep', [0] is the method and [1] is the class; if 'deep', first several elements may be the nested functions, + * the last but one is the method, and the last is the class. + */ + @Nullable + public static List searchForWrappingMethod(PsiElement start, boolean deep) { + PsiElement seeker = start; + List ret = new ArrayList<>(2); + while (seeker != null) { + PyFunction func = PsiTreeUtil.getParentOfType(seeker, PyFunction.class, true, PyClass.class); + if (func != null) { + PyClass cls = func.getContainingClass(); + if (cls != null) { + ret.add(func); + ret.add(cls); + return ret; + } + else if (deep) { + ret.add(func); + seeker = func; + } + else { + return null; // no immediate class + } + } + else { + return null; // no function + } + } + return null; + } - @Nullable - private static PyType getReturnType(@Nullable PyType type, @Nonnull TypeEvalContext context) { - if (type instanceof PyCallableType) { - return ((PyCallableType)type).getReturnType(context); + public static boolean inSameFile(@Nonnull PsiElement e1, @Nonnull PsiElement e2) { + PsiFile f1 = e1.getContainingFile(); + PsiFile f2 = e2.getContainingFile(); + return !(f1 == null || f2 == null) && f1 == f2; } - if (type instanceof PyUnionType) { - final List types = new ArrayList<>(); + public static boolean onSameLine(@Nonnull PsiElement e1, @Nonnull PsiElement e2) { + PsiDocumentManager documentManager = PsiDocumentManager.getInstance(e1.getProject()); + Document document = documentManager.getDocument(e1.getContainingFile()); + if (document == null || document != documentManager.getDocument(e2.getContainingFile())) { + return false; + } + return document.getLineNumber(e1.getTextOffset()) == document.getLineNumber(e2.getTextOffset()); + } - for (PyType pyType : ((PyUnionType)type).getMembers()) { - final PyType returnType = getReturnType(pyType, context); + public static boolean isTopLevel(@Nonnull PsiElement element) { + if (element instanceof StubBasedPsiElement stubBasedPsiElem) { + StubElement stub = stubBasedPsiElem.getStub(); + if (stub != null) { + StubElement parentStub = stub.getParentStub(); + if (parentStub != null) { + return parentStub.getPsi() instanceof PsiFile; + } + } + } + return ScopeUtil.getScopeOwner(element) instanceof PsiFile; + } - if (returnType != null) { - types.add(returnType); - } - } + public static void deletePycFiles(String pyFilePath) { + if (pyFilePath.endsWith(PyNames.DOT_PY)) { + List filesToDelete = new ArrayList<>(); + File pyc = new File(pyFilePath + "c"); + if (pyc.exists()) { + filesToDelete.add(pyc); + } + File pyo = new File(pyFilePath + "o"); + if (pyo.exists()) { + filesToDelete.add(pyo); + } + File file = new File(pyFilePath); + File pyCache = new File(file.getParentFile(), PyNames.PYCACHE); + if (pyCache.isDirectory()) { + String shortName = FileUtil.getNameWithoutExtension(file); + Collections.addAll(filesToDelete, pyCache.listFiles(pathname -> { + if (!FileUtil.extensionEquals(pathname.getName(), "pyc")) { + return false; + } + String nameWithMagic = FileUtil.getNameWithoutExtension(pathname); + return FileUtil.getNameWithoutExtension(nameWithMagic).equals(shortName); + })); + } + Application.get().getInstance(AsyncFileService.class).asyncDelete(filesToDelete); + } + } - return PyUnionType.union(types); - } - - return null; - } - - public static boolean isEmptyFunction(@Nonnull PyFunction function) { - final PyStatementList statementList = function.getStatementList(); - final PyStatement[] statements = statementList.getStatements(); - if (statements.length == 0) { - return true; + @RequiredReadAction + public static String getElementNameWithoutExtension(PsiNamedElement psiNamedElement) { + return psiNamedElement instanceof PyFile ? FileUtil.getNameWithoutExtension(psiNamedElement.getName()) : psiNamedElement.getName(); } - else if (statements.length == 1) { - if (isStringLiteral(statements[0]) || isPassOrRaiseOrEmptyReturn(statements[0])) { - return true; - } + + public static boolean hasUnresolvedAncestors(@Nonnull PyClass cls, @Nonnull TypeEvalContext context) { + for (PyClassLikeType type : cls.getAncestorTypes(context)) { + if (type == null) { + return true; + } + } + return false; } - else if (statements.length == 2) { - if (isStringLiteral(statements[0]) && (isPassOrRaiseOrEmptyReturn(statements[1]))) { - return true; - } + + @Nonnull + public static AccessDirection getPropertyAccessDirection(@Nonnull PyFunction function) { + Property property = function.getProperty(); + if (property != null) { + if (property.getGetter().valueOrNull() == function) { + return AccessDirection.READ; + } + if (property.getSetter().valueOrNull() == function) { + return AccessDirection.WRITE; + } + else if (property.getDeleter().valueOrNull() == function) { + return AccessDirection.DELETE; + } + } + return AccessDirection.READ; } - return false; - } - private static boolean isPassOrRaiseOrEmptyReturn(PyStatement stmt) { - if (stmt instanceof PyPassStatement || stmt instanceof PyRaiseStatement) { - return true; + @RequiredWriteAction + public static void removeQualifier(@Nonnull PyReferenceExpression element) { + PyExpression qualifier = element.getQualifier(); + if (qualifier == null) { + return; + } + + if (qualifier instanceof PyCallExpression call && call.getCallee() instanceof PyReferenceExpression callee) { + PyExpression calleeQualifier = callee.getQualifier(); + if (calleeQualifier != null) { + qualifier.replace(calleeQualifier); + return; + } + } + PsiElement dot = PyPsiUtils.getNextNonWhitespaceSibling(qualifier); + if (dot != null) { + dot.delete(); + } + qualifier.delete(); } - if (stmt instanceof PyReturnStatement && ((PyReturnStatement)stmt).getExpression() == null) { - return true; + + /** + * Returns string that represents element in string search. + * + * @param element element to search + * @return string that represents element + */ + @Nonnull + @RequiredReadAction + public static String computeElementNameForStringSearch(@Nonnull PsiElement element) { + if (element instanceof PyFile file) { + return FileUtil.getNameWithoutExtension(file.getName()); + } + if (element instanceof PsiDirectory directory) { + return directory.getName(); + } + // Magic literals are always represented by their string values + if (element instanceof PyStringLiteralExpression stringLiteral && PyMagicLiteralTools.isMagicLiteral(element)) { + String name = stringLiteral.getStringValue(); + if (name != null) { + return name; + } + } + if (element instanceof PyElement pyElement) { + String name = pyElement.getName(); + if (name != null) { + return name; + } + } + return element.getNode().getText(); } - return false; - } - private static boolean isStringLiteral(PyStatement stmt) { - if (stmt instanceof PyExpressionStatement) { - final PyExpression expr = ((PyExpressionStatement)stmt).getExpression(); - if (expr instanceof PyStringLiteralExpression) { - return true; - } + public static boolean isOwnScopeComprehension(@Nonnull PyComprehensionElement comprehension) { + boolean isAtLeast30 = LanguageLevel.forElement(comprehension).isAtLeast(LanguageLevel.PYTHON30); + boolean isListComprehension = comprehension instanceof PyListCompExpression; + return !isListComprehension || isAtLeast30; } - return false; - } - /** - * This helper class allows to collect various information about AST nodes composing {@link PyStringLiteralExpression}. - */ - public static final class StringNodeInfo { - private final ASTNode myNode; - private final String myPrefix; - private final String myQuote; - private final TextRange myContentRange; + public static boolean hasCustomDecorators(@Nonnull PyDecoratable decoratable) { + return PyKnownDecoratorUtil.hasNonBuiltinDecorator(decoratable, TypeEvalContext.codeInsightFallback(null)); + } - public StringNodeInfo(@Nonnull ASTNode node) { - if (!PyTokenTypes.STRING_NODES.contains(node.getElementType())) { - throw new IllegalArgumentException("Node must be valid Python string literal token, but " + node.getElementType() + " was given"); - } - myNode = node; - final String nodeText = node.getText(); - final int prefixLength = PyStringLiteralExpressionImpl.getPrefixLength(nodeText); - myPrefix = nodeText.substring(0, prefixLength); - myContentRange = PyStringLiteralExpressionImpl.getNodeTextRange(nodeText); - myQuote = nodeText.substring(prefixLength, myContentRange.getStartOffset()); + public static boolean isDecoratedAsAbstract(@Nonnull PyDecoratable decoratable) { + return PyKnownDecoratorUtil.hasAbstractDecorator(decoratable, TypeEvalContext.codeInsightFallback(null)); } - public StringNodeInfo(@Nonnull PsiElement element) { - this(element.getNode()); + public static ASTNode createNewName(PyElement element, String name) { + return PyElementGenerator.getInstance(element.getProject()).createNameIdentifier(name, LanguageLevel.forElement(element)); } + /** + * Finds element declaration by resolving its references top the top but not further than file (to prevent un-stubbing) + * + * @param elementToResolve element to resolve + * @return its declaration + */ @Nonnull - public ASTNode getNode() { - return myNode; + @RequiredReadAction + public static PsiElement resolveToTheTop(@Nonnull PsiElement elementToResolve) { + PsiElement currentElement = elementToResolve; + Set checkedElements = new HashSet<>(); // To prevent PY-20553 + while (true) { + PsiReference reference = currentElement.getReference(); + if (reference == null) { + break; + } + PsiElement resolve = reference.resolve(); + if (resolve == null + || checkedElements.contains(resolve) + || resolve.equals(currentElement) + || !inSameFile(resolve, currentElement)) { + break; + } + currentElement = resolve; + checkedElements.add(resolve); + } + return currentElement; } /** - * @return string prefix, e.g. "UR", "b" etc. + * Note that returned list may contain {@code null} items, e.g. for unresolved import elements, originally wrapped + * in {@link ImportedResolveResult}. */ @Nonnull - public String getPrefix() { - return myPrefix; + @RequiredReadAction + public static List multiResolveTopPriority(@Nonnull PsiElement element, @Nonnull PyResolveContext resolveContext) { + if (element instanceof PyReferenceOwner refOwner) { + PsiPolyVariantReference ref = refOwner.getReference(resolveContext); + return filterTopPriorityResults(ref.multiResolve(false)); + } + else { + PsiReference reference = element.getReference(); + return reference != null ? Collections.singletonList(reference.resolve()) : Collections.emptyList(); + } + } + + @Nonnull + @RequiredReadAction + public static List multiResolveTopPriority(@Nonnull PsiPolyVariantReference reference) { + return filterTopPriorityResults(reference.multiResolve(false)); + } + + @Nonnull + public static List filterTopPriorityResults(@Nonnull ResolveResult[] resolveResults) { + if (resolveResults.length == 0) { + return Collections.emptyList(); + } + List filtered = new ArrayList<>(); + int maxRate = getMaxRate(resolveResults); + for (ResolveResult resolveResult : resolveResults) { + int rate = resolveResult instanceof RatedResolveResult ratedResolveResult ? ratedResolveResult.getRate() : 0; + if (rate >= maxRate) { + PsiElement element = resolveResult.getElement(); + if (element != null) { + filtered.add(element); + } + } + } + return filtered; + } + + private static int getMaxRate(@Nonnull ResolveResult[] resolveResults) { + int maxRate = Integer.MIN_VALUE; + for (ResolveResult resolveResult : resolveResults) { + if (resolveResult instanceof RatedResolveResult ratedResolveResult) { + int rate = ratedResolveResult.getRate(); + if (rate > maxRate) { + maxRate = rate; + } + } + } + return maxRate; } /** - * @return content of the string node between quotes + * Gets class init method + * + * @param pyClass class where to find init + * @return class init method if any */ - @Nonnull - public String getContent() { - return myContentRange.substring(myNode.getText()); + @Nullable + public static PyFunction getInitMethod(@Nonnull PyClass pyClass) { + return pyClass.findMethodByName(PyNames.INIT, false, null); } /** - * @return relative range of the content (excluding prefix and quotes) - * @see #getAbsoluteContentRange() + * Returns Python language level for a virtual file. + * + * @see {@link LanguageLevel#forElement} */ @Nonnull - public TextRange getContentRange() { - return myContentRange; + @RequiredReadAction + public static LanguageLevel getLanguageLevelForVirtualFile(@Nonnull Project project, @Nonnull VirtualFile virtualFile) { + if (virtualFile instanceof VirtualFileWindow virtualFileWindow) { + virtualFile = virtualFileWindow.getDelegate(); + } + + // Most of the cases should be handled by this one, PyLanguageLevelPusher pushes folders only + VirtualFile folder = virtualFile.getParent(); + if (folder != null) { + LanguageLevel folderLevel = folder.getUserData(LanguageLevel.KEY); + if (folderLevel != null) { + return folderLevel; + } + LanguageLevel fileLevel = PythonLanguageLevelPusher.getFileLanguageLevel(project, virtualFile); + if (fileLevel != null) { + return fileLevel; + } + } + else { + // However this allows us to setup language level per file manually + // in case when it is LightVirtualFile + LanguageLevel level = virtualFile.getUserData(LanguageLevel.KEY); + if (level != null) { + return level; + } + + if (project.getApplication().isUnitTestMode()) { + LanguageLevel languageLevel = LanguageLevel.FORCE_LANGUAGE_LEVEL; + if (languageLevel != null) { + return languageLevel; + } + } + } + return guessLanguageLevelWithCaching(project); + } + + public static void invalidateLanguageLevelCache(@Nonnull Project project) { + project.putUserData(PythonLanguageLevelPusher.PYTHON_LANGUAGE_LEVEL, null); + } + + @Nonnull + @RequiredReadAction + public static LanguageLevel guessLanguageLevelWithCaching(@Nonnull Project project) { + LanguageLevel languageLevel = project.getUserData(PythonLanguageLevelPusher.PYTHON_LANGUAGE_LEVEL); + if (languageLevel == null) { + languageLevel = guessLanguageLevel(project); + project.putUserData(PythonLanguageLevelPusher.PYTHON_LANGUAGE_LEVEL, languageLevel); + } + + return languageLevel; + } + + @Nonnull + @RequiredReadAction + public static LanguageLevel guessLanguageLevel(@Nonnull Project project) { + ModuleManager moduleManager = ModuleManager.getInstance(project); + LanguageLevel maxLevel = null; + for (Module projectModule : moduleManager.getModules()) { + Sdk sdk = PythonSdkType.findPythonSdk(projectModule); + if (sdk != null) { + LanguageLevel level = PythonSdkType.getLanguageLevelForSdk(sdk); + if (maxLevel == null || maxLevel.isOlderThan(level)) { + maxLevel = level; + } + } + } + if (maxLevel != null) { + return maxLevel; + } + return LanguageLevel.getDefault(); } /** - * @return absolute content range that accounts offset of the {@link #getNode() node} in the document + * Clone of C# "as" operator. + * Checks if expression has correct type and casts it if it has. Returns null otherwise. + * It saves coder from "instanceof / cast" chains. + * + * @param expression expression to check + * @param clazz class to cast + * @param class to cast + * @return expression casted to appropriate type (if could be casted). Null otherwise. */ - @Nonnull - public TextRange getAbsoluteContentRange() { - return getContentRange().shiftRight(myNode.getStartOffset()); + @Nullable + @SuppressWarnings("unchecked") + public static T as(@Nullable Object expression, @Nonnull Class clazz) { + return ObjectUtil.tryCast(expression, clazz); } + // TODO: Move to PsiElement? + /** - * @return the first character of {@link #getQuote()} + * Searches for references injected to element with certain type + * + * @param element element to search injected references for + * @param expectedClass expected type of element reference resolved to + * @param expected type of element reference resolved to + * @return resolved element if found or null if not found */ - public char getSingleQuote() { - return myQuote.charAt(0); + @Nullable + @RequiredReadAction + public static T findReference(@Nonnull PsiElement element, @Nonnull Class expectedClass) { + for (PsiReference reference : element.getReferences()) { + T result = as(reference.resolve(), expectedClass); + if (result != null) { + return result; + } + } + return null; } + /** + * Converts collection to list of certain type + * + * @param expression expression of collection type + * @param elementClass expected element type + * @param expected element type + * @return list of elements of expected element type + */ @Nonnull - public String getQuote() { - return myQuote; + public static List asList(@Nullable Collection expression, @Nonnull Class elementClass) { + if ((expression == null) || expression.isEmpty()) { + return Collections.emptyList(); + } + List result = new ArrayList<>(); + for (Object element : expression) { + T toAdd = as(element, elementClass); + if (toAdd != null) { + result.add(toAdd); + } + } + return result; } - public boolean isTripleQuoted() { - return myQuote.length() == 3; + /** + * Force re-highlighting in all open editors that belong to specified project. + */ + @RequiredUIAccess + public static void rehighlightOpenEditors(@Nonnull Project project) { + project.getApplication().runWriteAction(() -> { + for (Editor editor : EditorFactory.getInstance().getAllEditors()) { + if (editor instanceof EditorEx editorEx && editorEx.getProject() == project) { + VirtualFile vFile = editorEx.getVirtualFile(); + if (vFile != null) { + EditorHighlighter highlighter = EditorHighlighterFactory.getInstance().createEditorHighlighter(project, vFile); + editorEx.setHighlighter(highlighter); + } + } + } + }); + } + + public static T getParameterizedCachedValue(@Nonnull PsiElement element, @Nonnull P param, @Nonnull Function f) { + Map cache = LanguageCachedValueUtil.getCachedValue( + element, + new CachedValueProvider>() { + @Nullable + @Override + public Result> compute() { + return Result.create(Maps.newHashMap(), PsiModificationTracker.MODIFICATION_COUNT); + } + } + ); + T result = cache.get(param); + if (result == null) { + result = f.apply(param); + cache.put(param, result); + } + return result; } /** - * @return true if string literal ends with starting quote + * This method is allowed to be called from any thread, but in general you should not set {@code modal=true} if you're calling it + * from the write action, because in this case {@code function} will be executed right in the current thread (presumably EDT) + * without any progress whatsoever to avoid possible deadlock. */ - public boolean isTerminated() { - final String text = myNode.getText(); - return text.length() - myPrefix.length() >= myQuote.length() * 2 && text.endsWith(myQuote); + public static void runWithProgress( + @Nullable Project project, + @Nls(capitalization = Nls.Capitalization.Title) @Nonnull String title, + boolean modal, + boolean canBeCancelled, + @Nonnull final Consumer function + ) { + if (modal) { + ProgressManager.getInstance().run(new Task.Modal(project, title, canBeCancelled) { + @Override + public void run(@Nonnull ProgressIndicator indicator) { + function.accept(indicator); + } + }); + } + else { + ProgressManager.getInstance().run(new Task.Backgroundable(project, title, canBeCancelled) { + @Override + public void run(@Nonnull ProgressIndicator indicator) { + function.accept(indicator); + } + }); + } } /** - * @return true if given string node contains "u" or "U" prefix + * Executes code only if
_PYCHARM_VERBOSE_MODE
is set in env (which should be done for debug purposes only) + * + * @param runnable code to call */ - public boolean isUnicode() { - return PyStringLiteralUtil.isUnicodePrefix(myPrefix); + public static void verboseOnly(@Nonnull Runnable runnable) { + if (System.getenv().get("_PYCHARM_VERBOSE_MODE") != null) { + runnable.run(); + } } /** - * @return true if given string node contains "r" or "R" prefix + * Returns the line comment that immediately precedes statement list of the given compound statement. Python parser ensures + * that it follows the statement header, i.e. it's directly after the colon, not on its own line. */ - public boolean isRaw() { - return PyStringLiteralUtil.isRawPrefix(myPrefix); + @Nullable + public static PsiComment getCommentOnHeaderLine(@Nonnull PyStatementListContainer container) { + PyStatementList statementList = container.getStatementList(); + return as(PyPsiUtils.getPrevNonWhitespaceSibling(statementList), PsiComment.class); } /** - * @return true if given string node contains "b" or "B" prefix + * If argument is a PsiDirectory, turn it into a PsiFile that points to __init__.py in that directory. + * If there's no __init__.py there, null is returned, there's no point to resolve to a dir which is not a package. + * Alas, resolve() and multiResolve() can't return anything but a PyFile or PsiFileImpl.isPsiUpToDate() would fail. + * This is because isPsiUpToDate() relies on identity of objects returned by FileViewProvider.getPsi(). + * If we ever need to exactly tell a dir from __init__.py, that logic has to change. + * + * @param target a resolve candidate. + * @return a PsiFile if target was a PsiDirectory, or null, or target unchanged. */ - public boolean isBytes() { - return PyStringLiteralUtil.isBytesPrefix(myPrefix); + @Nullable + public static PsiElement turnDirIntoInit(@Nullable PsiElement target) { + if (target instanceof PsiDirectory dir) { + PsiFile file = dir.findFile(PyNames.INIT_DOT_PY); + if (file != null) { + return file; // ResolveImportUtil will extract directory part as needed, everyone else are better off with a file. + } + else { + return null; + } // dir without __init__.py does not resolve + } + else { + return target; + } // don't touch non-dirs } /** - * @return true if given string node contains "f" or "F" prefix + * If directory is a PsiDirectory, that is also a valid Python package, return PsiFile that points to __init__.py, + * if such file exists, or directory itself (i.e. namespace package). Otherwise, return {@code null}. + * Unlike {@link #turnDirIntoInit(PsiElement)} this function handles namespace packages and + * accepts only PsiDirectories as target. + * + * @param directory directory to check + * @param anchor optional PSI element to determine language level as for {@link #isPackage(PsiDirectory, PsiElement)} + * @return PsiFile or PsiDirectory, if target is a Python package and {@code null} null otherwise */ - public boolean isFormatted() { - return PyStringLiteralUtil.isFormattedPrefix(myPrefix); + @Nullable + public static PsiElement getPackageElement(@Nonnull PsiDirectory directory, @Nullable PsiElement anchor) { + if (isPackage(directory, anchor)) { + PsiElement init = turnDirIntoInit(directory); + if (init != null) { + return init; + } + return directory; + } + return null; } /** - * @return true if other string node has the same decorations, i.e. quotes and prefix + * If target is a Python module named __init__.py file, return its directory. Otherwise return target unchanged. + * + * @param target PSI element to check + * @return PsiDirectory or target unchanged */ - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } + @Contract("null -> null; !null -> !null") + @Nullable + @RequiredReadAction + public static PsiElement turnInitIntoDir(@Nullable PsiElement target) { + if (target instanceof PyFile file && isPackage(file)) { + return file.getContainingDirectory(); + } + return target; + } - StringNodeInfo info = (StringNodeInfo)o; + /** + * @see #isPackage(PsiDirectory, boolean, PsiElement) + */ + public static boolean isPackage(@Nonnull PsiDirectory directory, @Nullable PsiElement anchor) { + return isPackage(directory, true, anchor); + } - return getQuote().equals(info.getQuote()) && - isRaw() == info.isRaw() && - isUnicode() == info.isUnicode() && - isBytes() == info.isBytes(); + /** + * Checks that given PsiDirectory can be treated as Python package, i.e. it's either contains __init__.py or it's a namespace package + * (effectively any directory in Python 3.3 and above). Setuptools namespace packages can be checked as well, but it requires access to + * {@link PySetuptoolsNamespaceIndex} and may slow things down during update of project indexes. + * Also note that this method does not check that directory itself and its parents have valid importable names, + * use {@link PyNames#isIdentifier(String)} for this purpose. + * + * @param directory PSI directory to check + * @param checkSetupToolsPackages whether setuptools namespace packages should be considered as well + * @param anchor optional anchor element to determine language level + * @return whether given directory is Python package + * @see PyNames#isIdentifier(String) + */ + @RequiredReadAction + public static boolean isPackage(@Nonnull PsiDirectory directory, boolean checkSetupToolsPackages, @Nullable PsiElement anchor) { + if (directory.findFile(PyNames.INIT_DOT_PY) != null) { + return true; + } + LanguageLevel level = anchor != null + ? LanguageLevel.forElement(anchor) + : getLanguageLevelForVirtualFile(directory.getProject(), directory.getVirtualFile()); + if (level.isAtLeast(LanguageLevel.PYTHON33)) { + return true; + } + return checkSetupToolsPackages && isSetuptoolsNamespacePackage(directory); } - } - public static class IterHelper { // TODO: rename sanely + @RequiredReadAction + public static boolean isPackage(@Nonnull PsiFile file) { + return PyNames.INIT_DOT_PY.equals(file.getName()); + } - private IterHelper() { + private static boolean isSetuptoolsNamespacePackage(@Nonnull PsiDirectory directory) { + String packagePath = getPackagePath(directory); + return packagePath != null && !PySetuptoolsNamespaceIndex.find(packagePath, directory.getProject()).isEmpty(); } @Nullable - public static PsiNamedElement findName(Iterable it, String name) { - PsiNamedElement ret = null; - for (PsiNamedElement elt : it) { - if (elt != null) { - // qualified refs don't match by last name, and we're not checking FQNs here - if (elt instanceof PyQualifiedExpression && ((PyQualifiedExpression)elt).isQualified()) { - continue; - } - if (name.equals(elt.getName())) { // plain name matches - ret = elt; - break; - } - } - } - return ret; - } - } + private static String getPackagePath(@Nonnull PsiDirectory directory) { + QualifiedName name = QualifiedNameFinder.findShortestImportableQName(directory); + return name != null ? name.toString() : null; + } + + /** + * Counts initial underscores of an identifier. + * + * @param name identifier + * @return 0 if no initial underscores found, 1 if there's only one underscore, 2 if there's two or more initial underscores. + */ + public static int getInitialUnderscores(String name) { + if (name == null) { + return 0; + } + int underscores = 0; + if (name.startsWith("__")) { + underscores = 2; + } + else if (name.startsWith("_")) { + underscores = 1; + } + return underscores; + } + + /** + * @param name + * @return true iff the name looks like a class-private one, starting with two underscores but not ending with two underscores. + */ + public static boolean isClassPrivateName(@Nonnull String name) { + return name.startsWith("__") && !name.endsWith("__"); + } + + public static boolean isSpecialName(@Nonnull String name) { + return name.length() > 4 && name.startsWith("__") && name.endsWith("__"); + } + + /** + * Constructs new lookup element for completion of keyword argument with equals sign appended. + * + * @param name name of the parameter + * @param project project instance to check code style settings and surround equals sign with spaces if necessary + * @return lookup element + */ + @Nonnull + public static LookupElement createNamedParameterLookup(@Nonnull String name, @Nullable Project project) { + String suffix; + if (CodeStyleSettingsManager.getSettings(project) + .getCustomSettings(PyCodeStyleSettings.class).SPACE_AROUND_EQ_IN_KEYWORD_ARGUMENT) { + suffix = " = "; + } + else { + suffix = "="; + } + LookupElementBuilder lookupElementBuilder = LookupElementBuilder.create(name + suffix).withIcon(PlatformIconGroup.nodesParameter()); + lookupElementBuilder = lookupElementBuilder.withInsertHandler(OverwriteEqualsInsertHandler.INSTANCE); + return PrioritizedLookupElement.withGrouping(lookupElementBuilder, 1); + } + + /** + * Peels argument expression of parentheses and of keyword argument wrapper + * + * @param expr an item of getArguments() array + * @return expression actually passed as argument + */ + @Nullable + public static PyExpression peelArgument(PyExpression expr) { + while (expr instanceof PyParenthesizedExpression parenExpr) { + expr = parenExpr.getContainedExpression(); + } + if (expr instanceof PyKeywordArgument keywordArg) { + expr = keywordArg.getValueExpression(); + } + return expr; + } + + public static String getFirstParameterName(PyFunction container) { + String selfName = PyNames.CANONICAL_SELF; + if (container != null) { + PyParameter[] params = container.getParameterList().getParameters(); + if (params.length > 0) { + PyNamedParameter named = params[0].getAsNamed(); + if (named != null) { + selfName = named.getName(); + } + } + } + return selfName; + } + + /** + * @return Source roots and content roots for element's project + */ + @Nonnull + @RequiredReadAction + public static Collection getSourceRoots(@Nonnull PsiElement foothold) { + Module module = foothold.getModule(); + if (module != null) { + return getSourceRoots(module); + } + return Collections.emptyList(); + } + + /** + * @return Source roots and content roots for module + */ + @Nonnull + public static Collection getSourceRoots(@Nonnull Module module) { + Set result = new LinkedHashSet<>(); + ModuleRootManager manager = ModuleRootManager.getInstance(module); + Collections.addAll(result, manager.getSourceRoots()); + Collections.addAll(result, manager.getContentRoots()); + return result; + } + + @Nullable + public static VirtualFile findInRoots(Module module, String path) { + if (module != null) { + for (VirtualFile root : getSourceRoots(module)) { + VirtualFile file = root.findFileByRelativePath(path); + if (file != null) { + return file; + } + } + } + return null; + } + + @Nullable + public static List getStringListFromTargetExpression(PyTargetExpression attr) { + return strListValue(attr.findAssignedValue()); + } + + @Nullable + public static List strListValue(PyExpression value) { + while (value instanceof PyParenthesizedExpression parenExpr) { + value = parenExpr.getContainedExpression(); + } + if (value instanceof PySequenceExpression sequenceExpr) { + PyExpression[] elements = sequenceExpr.getElements(); + List result = new ArrayList<>(elements.length); + for (PyExpression element : elements) { + if (!(element instanceof PyStringLiteralExpression)) { + return null; + } + result.add(((PyStringLiteralExpression) element).getStringValue()); + } + return result; + } + return null; + } + + @Nonnull + public static Map dictValue(@Nonnull PyDictLiteralExpression dict) { + Map result = Maps.newLinkedHashMap(); + for (PyKeyValueExpression keyValue : dict.getElements()) { + PyExpression key = keyValue.getKey(); + PyExpression value = keyValue.getValue(); + if (key instanceof PyStringLiteralExpression stringLiteral) { + result.put(stringLiteral.getStringValue(), value); + } + } + return result; + } + + /** + * @param what thing to search for + * @param variants things to search among + * @return true iff what.equals() one of the variants. + */ + @SafeVarargs + public static boolean among(@Nonnull T what, T... variants) { + for (T s : variants) { + if (what.equals(s)) { + return true; + } + } + return false; + } + + @Nullable + public static String getKeywordArgumentString(PyCallExpression expr, String keyword) { + return PyPsiUtils.strValue(expr.getKeywordArgument(keyword)); + } + + public static boolean isExceptionClass(PyClass pyClass) { + if (isBaseException(pyClass.getQualifiedName())) { + return true; + } + for (PyClassLikeType type : pyClass.getAncestorTypes(TypeEvalContext.codeInsightFallback(pyClass.getProject()))) { + if (type != null && isBaseException(type.getClassQName())) { + return true; + } + } + return false; + } + + private static boolean isBaseException(String name) { + return name != null && (name.contains("BaseException") || name.startsWith("exceptions.")); + } + + public static class MethodFlags { + private final boolean myIsStaticMethod; + private final boolean myIsMetaclassMethod; + private final boolean myIsSpecialMetaclassMethod; + private final boolean myIsClassMethod; + + /** + * @return true iff the method belongs to a metaclass (an ancestor of 'type'). + */ + public boolean isMetaclassMethod() { + return myIsMetaclassMethod; + } + + /** + * @return iff isMetaclassMethod and the method is either __init__ or __call__. + */ + public boolean isSpecialMetaclassMethod() { + return myIsSpecialMetaclassMethod; + } + + public boolean isStaticMethod() { + return myIsStaticMethod; + } + + public boolean isClassMethod() { + return myIsClassMethod; + } + + private MethodFlags(boolean isClassMethod, boolean isStaticMethod, boolean isMetaclassMethod, boolean isSpecialMetaclassMethod) { + myIsClassMethod = isClassMethod; + myIsStaticMethod = isStaticMethod; + myIsMetaclassMethod = isMetaclassMethod; + myIsSpecialMetaclassMethod = isSpecialMetaclassMethod; + } + + /** + * @param node a function + * @return a new flags object, or null if the function is not a method + */ + @Nullable + @RequiredReadAction + public static MethodFlags of(@Nonnull PyFunction node) { + PyClass cls = node.getContainingClass(); + if (cls != null) { + PyFunction.Modifier modifier = node.getModifier(); + boolean isMetaclassMethod = false; + PyClass type_cls = PyBuiltinCache.getInstance(node).getClass("type"); + for (PyClass ancestor_cls : cls.getAncestorClasses(null)) { + if (ancestor_cls == type_cls) { + isMetaclassMethod = true; + break; + } + } + String method_name = node.getName(); + boolean isSpecialMetaclassMethod = isMetaclassMethod && method_name != null && among(method_name, PyNames.INIT, "__call__"); + return new MethodFlags(modifier == CLASSMETHOD, modifier == STATICMETHOD, isMetaclassMethod, isSpecialMetaclassMethod); + } + return null; + } + + //TODO: Doc + public boolean isInstanceMethod() { + return !(myIsClassMethod || myIsStaticMethod); + } + } + + @RequiredReadAction + public static boolean isSuperCall(@Nonnull PyCallExpression node) { + PyClass pyClass = PsiTreeUtil.getParentOfType(node, PyClass.class); + if (pyClass == null) { + return false; + } + PyExpression callee = node.getCallee(); + if (callee == null) { + return false; + } + String name = callee.getName(); + if (PyNames.SUPER.equals(name)) { + PsiReference reference = callee.getReference(); + if (reference == null) { + return false; + } + PsiElement resolved = reference.resolve(); + PyBuiltinCache cache = PyBuiltinCache.getInstance(node); + if (resolved != null && cache.isBuiltin(resolved)) { + PyExpression[] args = node.getArguments(); + if (args.length > 0) { + String firstArg = args[0].getText(); + if (firstArg.equals(pyClass.getName()) || firstArg.equals(PyNames.CANONICAL_SELF + "." + PyNames.__CLASS__)) { + return true; + } + for (PyClass s : pyClass.getAncestorClasses(null)) { + if (firstArg.equals(s.getName())) { + return true; + } + } + } + else { + return true; + } + } + } + return false; + } + + @Nonnull + @RequiredReadAction + public static PyFile getOrCreateFile(String path, Project project) { + VirtualFile vFile = LocalFileSystem.getInstance().findFileByIoFile(new File(path)); + PsiFile psi; + if (vFile == null) { + File file = new File(path); + try { + VirtualFile baseDir = project.getBaseDir(); + FileTemplateManager fileTemplateManager = FileTemplateManager.getInstance(project); + FileTemplate template = fileTemplateManager.getInternalTemplate("Python Script"); + Properties properties = fileTemplateManager.getDefaultProperties(); + properties.setProperty("NAME", FileUtil.getNameWithoutExtension(file.getName())); + String content = (template != null) ? template.getText(properties) : null; + psi = PyExtractSuperclassHelper.placeFile( + project, + StringUtil.notNullize(file.getParent(), baseDir != null ? baseDir.getPath() : "."), + file.getName(), + content + ); + } + catch (IOException e) { + throw new IncorrectOperationException(String.format("Cannot create file '%s'", path), e); + } + } + else { + psi = PsiManager.getInstance(project).findFile(vFile); + } + if (psi instanceof PyFile pyFile) { + return pyFile; + } + throw new IncorrectOperationException(PyLocalize.refactoringMoveModuleMembersErrorCannotPlaceElementsIntoNonpythonFile().get()); + } + + /** + * counts elements in iterable + * + * @param expression to count containing elements (iterable) + * @return element count + */ + public static int getElementsCount(PyExpression expression, TypeEvalContext evalContext) { + int valuesLength = -1; + PyType type = evalContext.getType(expression); + if (type instanceof PyTupleType tupleType) { + valuesLength = tupleType.getElementCount(); + } + else if (type instanceof PyNamedTupleType namedTupleType) { + valuesLength = namedTupleType.getElementCount(); + } + else if (expression instanceof PySequenceExpression sequenceExpr) { + valuesLength = sequenceExpr.getElements().length; + } + else if (expression instanceof PyStringLiteralExpression stringLiteral) { + valuesLength = stringLiteral.getStringValue().length(); + } + else if (expression instanceof PyNumericLiteralExpression) { + valuesLength = 1; + } + else if (expression instanceof PyCallExpression call) { + if (call.isCalleeText("dict")) { + valuesLength = call.getArguments().length; + } + else if (call.isCalleeText("tuple")) { + PyExpression[] arguments = call.getArguments(); + if (arguments.length > 0 && arguments[0] instanceof PySequenceExpression sequenceExpr) { + valuesLength = sequenceExpr.getElements().length; + } + } + } + return valuesLength; + } + + @Nullable + @RequiredReadAction + public static PsiElement findPrevAtOffset(PsiFile psiFile, int caretOffset, Class... toSkip) { + PsiElement element; + if (caretOffset < 0) { + return null; + } + int lineStartOffset = 0; + Document document = PsiDocumentManager.getInstance(psiFile.getProject()).getDocument(psiFile); + if (document != null) { + int lineNumber = document.getLineNumber(caretOffset); + lineStartOffset = document.getLineStartOffset(lineNumber); + } + do { + caretOffset--; + element = psiFile.findElementAt(caretOffset); + } + while (caretOffset >= lineStartOffset && instanceOf(element, toSkip)); + return instanceOf(element, toSkip) ? null : element; + } + + @Nullable + @RequiredReadAction + public static PsiElement findNonWhitespaceAtOffset(PsiFile psiFile, int caretOffset) { + PsiElement element = findNextAtOffset(psiFile, caretOffset, PsiWhiteSpace.class); + if (element == null) { + element = findPrevAtOffset(psiFile, caretOffset - 1, PsiWhiteSpace.class); + } + return element; + } + + @Nullable + @RequiredReadAction + public static PsiElement findElementAtOffset(PsiFile psiFile, int caretOffset) { + PsiElement element = findPrevAtOffset(psiFile, caretOffset); + if (element == null) { + element = findNextAtOffset(psiFile, caretOffset); + } + return element; + } + + @Nullable + @RequiredReadAction + public static PsiElement findNextAtOffset(@Nonnull PsiFile psiFile, int caretOffset, Class... toSkip) { + PsiElement element = psiFile.findElementAt(caretOffset); + if (element == null) { + return null; + } + + Document document = PsiDocumentManager.getInstance(psiFile.getProject()).getDocument(psiFile); + int lineEndOffset = 0; + if (document != null) { + int lineNumber = document.getLineNumber(caretOffset); + lineEndOffset = document.getLineEndOffset(lineNumber); + } + while (caretOffset < lineEndOffset && instanceOf(element, toSkip)) { + caretOffset++; + element = psiFile.findElementAt(caretOffset); + } + return instanceOf(element, toSkip) ? null : element; + } + + /** + * Adds element to statement list to the correct place according to its dependencies. + * + * @param element to insert + * @param statementList where element should be inserted + * @return inserted element + */ + @RequiredWriteAction + public static T addElementToStatementList(@Nonnull T element, @Nonnull PyStatementList statementList) { + PsiElement before = null; + PsiElement after = null; + for (PyStatement statement : statementList.getStatements()) { + if (PyDependenciesComparator.depends(element, statement)) { + after = statement; + } + else if (PyDependenciesComparator.depends(statement, element)) { + before = statement; + } + } + PsiElement result; + if (after != null) { + result = statementList.addAfter(element, after); + } + else if (before != null) { + result = statementList.addBefore(element, before); + } + else { + result = addElementToStatementList(element, statementList, true); + } + @SuppressWarnings("unchecked") // Inserted element can't have different type + T resultCasted = (T) result; + return resultCasted; + } + + /** + * Inserts specified element into the statement list either at the beginning or at its end. If new element is going to be + * inserted at the beginning, any preceding docstrings and/or calls to super methods will be skipped. + * Moreover if statement list previously didn't contain any statements, explicit new line and indentation will be inserted in + * front of it. + * + * @param element element to insert + * @param statementList statement list + * @param toTheBeginning whether to insert element at the beginning or at the end of the statement list + * @return actually inserted element as for {@link PsiElement#add(PsiElement)} + */ + @Nonnull + @RequiredWriteAction + public static PsiElement addElementToStatementList( + @Nonnull PsiElement element, + @Nonnull PyStatementList statementList, + boolean toTheBeginning + ) { + PsiElement prevElem = PyPsiUtils.getPrevNonWhitespaceSibling(statementList); + // If statement list is on the same line as previous element (supposedly colon), move its only statement on the next line + if (prevElem != null && onSameLine(statementList, prevElem)) { + PsiDocumentManager manager = PsiDocumentManager.getInstance(statementList.getProject()); + Document document = manager.getDocument(statementList.getContainingFile()); + if (document != null) { + PyStatementListContainer container = (PyStatementListContainer) statementList.getParent(); + manager.doPostponedOperationsAndUnblockDocument(document); + String indentation = "\n" + PyIndentUtil.getElementIndent(statementList); + // If statement list was empty initially, we need to add some anchor statement ("pass"), so that preceding new line was not + // parsed as following entire StatementListContainer (e.g. function). It's going to be replaced anyway. + String text = statementList.getStatements().length == 0 ? indentation + PyNames.PASS : indentation; + document.insertString(statementList.getTextRange().getStartOffset(), text); + manager.commitDocument(document); + statementList = container.getStatementList(); + } + } + PsiElement firstChild = statementList.getFirstChild(); + if (firstChild == statementList.getLastChild() && firstChild instanceof PyPassStatement) { + element = firstChild.replace(element); + } + else { + PyStatement[] statements = statementList.getStatements(); + if (toTheBeginning && statements.length > 0) { + PyDocStringOwner docStringOwner = PsiTreeUtil.getParentOfType(statementList, PyDocStringOwner.class); + PyStatement anchor = statements[0]; + if (docStringOwner != null && anchor instanceof PyExpressionStatement expressionStmt + && expressionStmt.getExpression() == docStringOwner.getDocStringExpression()) { + PyStatement next = PsiTreeUtil.getNextSiblingOfType(expressionStmt, PyStatement.class); + if (next == null) { + return statementList.addAfter(element, anchor); + } + anchor = next; + } + while (anchor instanceof PyExpressionStatement expressionStmt) { + if (expressionStmt.getExpression() instanceof PyCallExpression call) { + PyExpression callee = call.getCallee(); + if (isSuperCall(call) || callee != null && PyNames.INIT.equals(callee.getName())) { + PyStatement next = PsiTreeUtil.getNextSiblingOfType(anchor, PyStatement.class); + if (next == null) { + return statementList.addAfter(element, anchor); + } + anchor = next; + continue; + } + } + break; + } + element = statementList.addBefore(element, anchor); + } + else { + element = statementList.add(element); + } + } + return element; + } + + @Nonnull + public static List> getOverloadedParametersSet(@Nonnull PyCallable callable, @Nonnull TypeEvalContext context) { + List> parametersSet = getOverloadedParametersSet(context.getType(callable), context); + return parametersSet != null + ? parametersSet + : Collections.singletonList(Arrays.asList(callable.getParameterList().getParameters())); + } + + @Nullable + private static List getParametersOfCallableType(@Nonnull PyCallableType type, @Nonnull TypeEvalContext context) { + List callableTypeParameters = type.getParameters(context); + if (callableTypeParameters != null) { + boolean allParametersDefined = true; + List parameters = new ArrayList<>(); + for (PyCallableParameter callableParameter : callableTypeParameters) { + PyParameter parameter = callableParameter.getParameter(); + if (parameter == null) { + allParametersDefined = false; + break; + } + parameters.add(parameter); + } + if (allParametersDefined) { + return parameters; + } + } + return null; + } + + @Nullable + private static List> getOverloadedParametersSet(@Nullable PyType type, @Nonnull TypeEvalContext context) { + if (type instanceof PyUnionType unionType) { + type = unionType.excludeNull(context); + } + + if (type instanceof PyCallableType callableType) { + List results = getParametersOfCallableType(callableType, context); + if (results != null) { + return Collections.singletonList(results); + } + } + else if (type instanceof PyUnionType unionType) { + List> results = new ArrayList<>(); + Collection members = unionType.getMembers(); + for (PyType member : members) { + if (member instanceof PyCallableType callableType) { + List parameters = getParametersOfCallableType(callableType, context); + if (parameters != null) { + results.add(parameters); + } + } + } + if (!results.isEmpty()) { + return results; + } + } + + return null; + } + + @Nonnull + public static List getParameters(@Nonnull PyCallable callable, @Nonnull TypeEvalContext context) { + List> parametersSet = getOverloadedParametersSet(callable, context); + assert !parametersSet.isEmpty(); + return parametersSet.get(0); + } + + public static boolean isSignatureCompatibleTo( + @Nonnull PyCallable callable, + @Nonnull PyCallable otherCallable, + @Nonnull TypeEvalContext context + ) { + List parameters = getParameters(callable, context); + List otherParameters = getParameters(otherCallable, context); + int optionalCount = optionalParametersCount(parameters); + int otherOptionalCount = optionalParametersCount(otherParameters); + int requiredCount = requiredParametersCount(callable, parameters); + int otherRequiredCount = requiredParametersCount(otherCallable, otherParameters); + if (hasPositionalContainer(otherParameters) || hasKeywordContainer(otherParameters)) { + if (otherParameters.size() == specialParametersCount(otherCallable, otherParameters)) { + return true; + } + } + if (hasPositionalContainer(parameters) || hasKeywordContainer(parameters)) { + return requiredCount <= otherRequiredCount; + } + return requiredCount <= otherRequiredCount && parameters.size() >= otherParameters.size() && optionalCount >= otherOptionalCount; + } + + private static int optionalParametersCount(@Nonnull List parameters) { + int n = 0; + for (PyParameter parameter : parameters) { + if (parameter.hasDefaultValue()) { + n++; + } + } + return n; + } + + private static int requiredParametersCount(@Nonnull PyCallable callable, @Nonnull List parameters) { + return parameters.size() - optionalParametersCount(parameters) - specialParametersCount(callable, parameters); + } + + private static int specialParametersCount(@Nonnull PyCallable callable, @Nonnull List parameters) { + int n = 0; + if (hasPositionalContainer(parameters)) { + n++; + } + if (hasKeywordContainer(parameters)) { + n++; + } + if (callable.asMethod() != null) { + n++; + } + else if (parameters.size() > 0) { + PyParameter first = parameters.get(0); + if (PyNames.CANONICAL_SELF.equals(first.getName())) { + n++; + } + } + return n; + } + + private static boolean hasPositionalContainer(@Nonnull List parameters) { + for (PyParameter parameter : parameters) { + if (parameter instanceof PyNamedParameter namedParam && namedParam.isPositionalContainer()) { + return true; + } + } + return false; + } + + private static boolean hasKeywordContainer(@Nonnull List parameters) { + for (PyParameter parameter : parameters) { + if (parameter instanceof PyNamedParameter namedParam && namedParam.isKeywordContainer()) { + return true; + } + } + return false; + } + + @RequiredReadAction + public static boolean isInit(@Nonnull PyFunction function) { + return PyNames.INIT.equals(function.getName()); + } + + /** + * Filters out {@link PyMemberInfo} + * that should not be displayed in this refactoring (like object) + * + * @param pyMemberInfos collection to sort + * @return sorted collection + */ + @Nonnull + public static Collection> filterOutObject(@Nonnull Collection> pyMemberInfos) { + return Collections2.filter(pyMemberInfos, new ObjectPredicate(false)); + } + + public static boolean isStarImportableFrom(@Nonnull String name, @Nonnull PyFile file) { + List dunderAll = file.getDunderAll(); + return dunderAll != null ? dunderAll.contains(name) : !name.startsWith("_"); + } + + /** + * Filters only PyClass object (new class) + */ + public static class ObjectPredicate extends NotNullPredicate> { + private final boolean myAllowObjects; + + /** + * @param allowObjects allows only objects if true. Allows all but objects otherwise. + */ + public ObjectPredicate(boolean allowObjects) { + myAllowObjects = allowObjects; + } + + @Override + public boolean applyNotNull(@Nonnull PyMemberInfo input) { + return myAllowObjects == isObject(input); + } + + private static boolean isObject(@Nonnull PyMemberInfo classMemberInfo) { + PyElement element = classMemberInfo.getMember(); + return element instanceof PyClass && PyNames.OBJECT.equals(element.getName()); + } + } + + /** + * Sometimes you do not know real FQN of some class, but you know class name and its package. + * I.e. django.apps.conf.AppConfig is not documented, but you know + * AppConfig and django package. + * + * @param symbol element to check (class or function) + * @param expectedPackage package like "django" + * @param expectedName expected name (i.e. AppConfig) + * @return true if element in package + */ + public static boolean isSymbolInPackage( + @Nonnull PyQualifiedNameOwner symbol, + @Nonnull String expectedPackage, + @Nonnull String expectedName + ) { + String qualifiedNameString = symbol.getQualifiedName(); + if (qualifiedNameString == null) { + return false; + } + QualifiedName qualifiedName = QualifiedName.fromDottedString(qualifiedNameString); + String aPackage = qualifiedName.getFirstComponent(); + if (!(expectedPackage.equals(aPackage))) { + return false; + } + String symbolName = qualifiedName.getLastComponent(); + return expectedName.equals(symbolName); + } + + /** + * Checks that given class is the root of class hierarchy, i.e. it's either {@code object} or + * special {@link PyNames#FAKE_OLD_BASE} class for old-style classes. + * + * @param cls Python class to check + * @see PyBuiltinCache + * @see PyNames#FAKE_OLD_BASE + */ + public static boolean isObjectClass(@Nonnull PyClass cls) { + PyBuiltinCache builtinCache = PyBuiltinCache.getInstance(cls); + return cls == builtinCache.getClass(PyNames.OBJECT) || cls == builtinCache.getClass(PyNames.FAKE_OLD_BASE); + } + + @RequiredReadAction + public static boolean isInScratchFile(@Nonnull PsiElement element) { + return ScratchFileService.isInScratchRoot(PsiUtilCore.getVirtualFile(element)); + } + + @Nullable + public static PyType getReturnTypeOfMember( + @Nonnull PyType type, + @Nonnull String memberName, + @Nullable PyExpression location, + @Nonnull TypeEvalContext context + ) { + PyResolveContext resolveContext = PyResolveContext.noImplicits().withTypeEvalContext(context); + List resolveResults = + type.resolveMember(memberName, location, AccessDirection.READ, resolveContext); + + if (resolveResults != null) { + List types = new ArrayList<>(); + + for (RatedResolveResult resolveResult : resolveResults) { + PyType returnType = getReturnType(resolveResult.getElement(), context); + + if (returnType != null) { + types.add(returnType); + } + } + + return PyUnionType.union(types); + } + + return null; + } + + @Nullable + private static PyType getReturnType(@Nullable PsiElement element, @Nonnull TypeEvalContext context) { + if (element instanceof PyTypedElement typedElem) { + PyType type = context.getType(typedElem); + + return getReturnType(type, context); + } + + return null; + } + + @Nullable + private static PyType getReturnType(@Nullable PyType type, @Nonnull TypeEvalContext context) { + if (type instanceof PyCallableType callableType) { + return callableType.getReturnType(context); + } + + if (type instanceof PyUnionType unionType) { + List types = new ArrayList<>(); + + for (PyType pyType : unionType.getMembers()) { + PyType returnType = getReturnType(pyType, context); + + if (returnType != null) { + types.add(returnType); + } + } + + return PyUnionType.union(types); + } + + return null; + } + + public static boolean isEmptyFunction(@Nonnull PyFunction function) { + PyStatementList statementList = function.getStatementList(); + PyStatement[] statements = statementList.getStatements(); + if (statements.length == 0) { + return true; + } + else if (statements.length == 1) { + if (isStringLiteral(statements[0]) || isPassOrRaiseOrEmptyReturn(statements[0])) { + return true; + } + } + else if (statements.length == 2) { + if (isStringLiteral(statements[0]) && (isPassOrRaiseOrEmptyReturn(statements[1]))) { + return true; + } + } + return false; + } + + @SuppressWarnings("RedundantIfStatement") + private static boolean isPassOrRaiseOrEmptyReturn(PyStatement stmt) { + if (stmt instanceof PyPassStatement || stmt instanceof PyRaiseStatement) { + return true; + } + if (stmt instanceof PyReturnStatement returnStmt && returnStmt.getExpression() == null) { + return true; + } + return false; + } + + private static boolean isStringLiteral(PyStatement stmt) { + if (stmt instanceof PyExpressionStatement expressionStmt) { + PyExpression expr = expressionStmt.getExpression(); + if (expr instanceof PyStringLiteralExpression) { + return true; + } + } + return false; + } + + /** + * This helper class allows to collect various information about AST nodes composing {@link PyStringLiteralExpression}. + */ + public static final class StringNodeInfo { + private final ASTNode myNode; + private final String myPrefix; + private final String myQuote; + private final TextRange myContentRange; + + public StringNodeInfo(@Nonnull ASTNode node) { + if (!PyTokenTypes.STRING_NODES.contains(node.getElementType())) { + throw new IllegalArgumentException("Node must be valid Python string literal token, but " + node.getElementType() + " was given"); + } + myNode = node; + String nodeText = node.getText(); + int prefixLength = PyStringLiteralExpressionImpl.getPrefixLength(nodeText); + myPrefix = nodeText.substring(0, prefixLength); + myContentRange = PyStringLiteralExpressionImpl.getNodeTextRange(nodeText); + myQuote = nodeText.substring(prefixLength, myContentRange.getStartOffset()); + } + + public StringNodeInfo(@Nonnull PsiElement element) { + this(element.getNode()); + } + + @Nonnull + public ASTNode getNode() { + return myNode; + } + + /** + * @return string prefix, e.g. "UR", "b" etc. + */ + @Nonnull + public String getPrefix() { + return myPrefix; + } + + /** + * @return content of the string node between quotes + */ + @Nonnull + public String getContent() { + return myContentRange.substring(myNode.getText()); + } + + /** + * @return relative range of the content (excluding prefix and quotes) + * @see #getAbsoluteContentRange() + */ + @Nonnull + public TextRange getContentRange() { + return myContentRange; + } + + /** + * @return absolute content range that accounts offset of the {@link #getNode() node} in the document + */ + @Nonnull + public TextRange getAbsoluteContentRange() { + return getContentRange().shiftRight(myNode.getStartOffset()); + } + + /** + * @return the first character of {@link #getQuote()} + */ + public char getSingleQuote() { + return myQuote.charAt(0); + } + + @Nonnull + public String getQuote() { + return myQuote; + } + + public boolean isTripleQuoted() { + return myQuote.length() == 3; + } + + /** + * @return true if string literal ends with starting quote + */ + public boolean isTerminated() { + String text = myNode.getText(); + return text.length() - myPrefix.length() >= myQuote.length() * 2 && text.endsWith(myQuote); + } + + /** + * @return true if given string node contains "u" or "U" prefix + */ + public boolean isUnicode() { + return PyStringLiteralUtil.isUnicodePrefix(myPrefix); + } + + /** + * @return true if given string node contains "r" or "R" prefix + */ + public boolean isRaw() { + return PyStringLiteralUtil.isRawPrefix(myPrefix); + } + + /** + * @return true if given string node contains "b" or "B" prefix + */ + public boolean isBytes() { + return PyStringLiteralUtil.isBytesPrefix(myPrefix); + } + + /** + * @return true if given string node contains "f" or "F" prefix + */ + public boolean isFormatted() { + return PyStringLiteralUtil.isFormattedPrefix(myPrefix); + } + + /** + * @return true if other string node has the same decorations, i.e. quotes and prefix + */ + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + StringNodeInfo info = (StringNodeInfo) o; + + return getQuote().equals(info.getQuote()) && + isRaw() == info.isRaw() && + isUnicode() == info.isUnicode() && + isBytes() == info.isBytes(); + } + } + + public static class IterHelper { // TODO: rename sanely + private IterHelper() { + } + + @Nullable + @RequiredReadAction + public static PsiNamedElement findName(Iterable it, String name) { + PsiNamedElement ret = null; + for (PsiNamedElement elt : it) { + if (elt != null) { + // qualified refs don't match by last name, and we're not checking FQNs here + if (elt instanceof PyQualifiedExpression qualifiedExpr && qualifiedExpr.isQualified()) { + continue; + } + if (name.equals(elt.getName())) { // plain name matches + ret = elt; + break; + } + } + } + return ret; + } + } } diff --git a/python-impl/src/main/java/com/jetbrains/python/impl/psi/impl/PyBoundFunction.java b/python-impl/src/main/java/com/jetbrains/python/impl/psi/impl/PyBoundFunction.java index dddb3732..dd4aae8c 100644 --- a/python-impl/src/main/java/com/jetbrains/python/impl/psi/impl/PyBoundFunction.java +++ b/python-impl/src/main/java/com/jetbrains/python/impl/psi/impl/PyBoundFunction.java @@ -13,10 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.jetbrains.python.impl.psi.impl; import com.jetbrains.python.psi.PyFunction; +import consulo.annotation.access.RequiredReadAction; +import consulo.language.psi.PsiElement; /** * @author yole @@ -25,4 +26,10 @@ public class PyBoundFunction extends PyFunctionImpl { public PyBoundFunction(PyFunction function) { super(function.getNode()); } + + @RequiredReadAction + @Override + public PsiElement getNameIdentifier() { + return super.getNameIdentifier(); + } } diff --git a/python-impl/src/main/java/com/jetbrains/python/impl/psi/impl/PyFileImpl.java b/python-impl/src/main/java/com/jetbrains/python/impl/psi/impl/PyFileImpl.java index c8938da7..dc6a82e7 100644 --- a/python-impl/src/main/java/com/jetbrains/python/impl/psi/impl/PyFileImpl.java +++ b/python-impl/src/main/java/com/jetbrains/python/impl/psi/impl/PyFileImpl.java @@ -36,9 +36,9 @@ import com.jetbrains.python.psi.stubs.PyFileStub; import com.jetbrains.python.psi.types.PyType; import com.jetbrains.python.psi.types.TypeEvalContext; -import consulo.application.AllIcons; +import consulo.annotation.access.RequiredReadAction; +import consulo.annotation.access.RequiredWriteAction; import consulo.application.util.RecursionManager; -import consulo.application.util.function.Processor; import consulo.ide.impl.idea.openapi.vfs.VfsUtilCore; import consulo.language.Language; import consulo.language.file.FileViewProvider; @@ -48,781 +48,811 @@ import consulo.language.psi.resolve.PsiScopeProcessor; import consulo.language.psi.resolve.ResolveState; import consulo.language.psi.stub.IndexingDataKeys; -import consulo.language.psi.stub.StubElement; import consulo.language.psi.util.PsiTreeUtil; import consulo.language.psi.util.QualifiedName; import consulo.language.util.IncorrectOperationException; import consulo.module.content.ProjectFileIndex; import consulo.navigation.ItemPresentation; +import consulo.platform.base.icon.PlatformIconGroup; import consulo.ui.image.Image; import consulo.util.collection.ContainerUtil; import consulo.util.dataholder.Key; import consulo.util.io.FileUtil; import consulo.util.lang.ref.SoftReference; import consulo.virtualFileSystem.VirtualFile; - import jakarta.annotation.Nonnull; import jakarta.annotation.Nullable; + import java.io.File; import java.util.*; +import java.util.function.Predicate; public class PyFileImpl extends PsiFileBase implements PyFile, PyExpression { - protected PyType myType; - - //private volatile Boolean myAbsoluteImportEnabled; - private final Map myFutureFeatures; - private List myDunderAll; - private boolean myDunderAllCalculated; - private volatile SoftReference myExportedNameCache = new SoftReference<>(null); - private final PsiModificationTracker myModificationTracker; - - private class ExportedNameCache { - private final List myNameDefinerNegativeCache = new ArrayList<>(); - private long myNameDefinerOOCBModCount = -1; - private final long myModificationStamp; - private final Map> myNamedElements = Maps.newHashMap(); - private final List myImportedNameDefiners = Lists.newArrayList(); - - private ExportedNameCache(long modificationStamp) { - myModificationStamp = modificationStamp; - - processDeclarations(PyPsiUtils.collectAllStubChildren(PyFileImpl.this, getStub()), element -> { - if (element instanceof PsiNamedElement && !(element instanceof PyKeywordArgument)) { - final PsiNamedElement namedElement = (PsiNamedElement)element; - final String name = namedElement.getName(); - if (!myNamedElements.containsKey(name)) { - myNamedElements.put(name, Lists.newArrayList()); - } - final List elements = myNamedElements.get(name); - elements.add(namedElement); - } - if (element instanceof PyImportedNameDefiner) { - myImportedNameDefiners.add((PyImportedNameDefiner)element); - } - if (element instanceof PyFromImportStatement) { - final PyFromImportStatement fromImportStatement = (PyFromImportStatement)element; - final PyStarImportElement starImportElement = fromImportStatement.getStarImportElement(); - if (starImportElement != null) { - myImportedNameDefiners.add(starImportElement); - } - else { - Collections.addAll(myImportedNameDefiners, fromImportStatement.getImportElements()); - } - } - else if (element instanceof PyImportStatement) { - final PyImportStatement importStatement = (PyImportStatement)element; - Collections.addAll(myImportedNameDefiners, importStatement.getImportElements()); + protected PyType myType; + + //private volatile Boolean myAbsoluteImportEnabled; + private final Map myFutureFeatures; + private List myDunderAll; + private boolean myDunderAllCalculated; + private volatile SoftReference myExportedNameCache = new SoftReference<>(null); + private final PsiModificationTracker myModificationTracker; + + private class ExportedNameCache { + private final List myNameDefinerNegativeCache = new ArrayList<>(); + private long myNameDefinerOOCBModCount = -1; + private final long myModificationStamp; + private final Map> myNamedElements = Maps.newHashMap(); + private final List myImportedNameDefiners = Lists.newArrayList(); + + @RequiredReadAction + private ExportedNameCache(long modificationStamp) { + myModificationStamp = modificationStamp; + + processDeclarations( + PyPsiUtils.collectAllStubChildren(PyFileImpl.this, getStub()), + element -> { + if (element instanceof PsiNamedElement namedElement && !(namedElement instanceof PyKeywordArgument)) { + String name = namedElement.getName(); + if (!myNamedElements.containsKey(name)) { + myNamedElements.put(name, new ArrayList<>()); + } + List elements = myNamedElements.get(name); + elements.add(namedElement); + } + if (element instanceof PyImportedNameDefiner importedNameDefiner) { + myImportedNameDefiners.add(importedNameDefiner); + } + if (element instanceof PyFromImportStatement fromImportStatement) { + PyStarImportElement starImportElement = fromImportStatement.getStarImportElement(); + if (starImportElement != null) { + myImportedNameDefiners.add(starImportElement); + } + else { + Collections.addAll(myImportedNameDefiners, fromImportStatement.getImportElements()); + } + } + else if (element instanceof PyImportStatement importStatement) { + Collections.addAll(myImportedNameDefiners, importStatement.getImportElements()); + } + return true; + } + ); + for (List elements : myNamedElements.values()) { + Collections.reverse(elements); + } + Collections.reverse(myImportedNameDefiners); } - return true; - }); - for (List elements : myNamedElements.values()) { - Collections.reverse(elements); - } - Collections.reverse(myImportedNameDefiners); - } - private boolean processDeclarations(@Nonnull List elements, @Nonnull Processor processor) { - for (PsiElement child : elements) { - if (!processor.process(child)) { - return false; + private boolean processDeclarations(@Nonnull List elements, @Nonnull Predicate processor) { + for (PsiElement child : elements) { + if (!processor.test(child)) { + return false; + } + if (child instanceof PyExceptPart part + && !processDeclarations(PyPsiUtils.collectAllStubChildren(part, part.getStub()), processor)) { + return false; + } + } + return true; + } + + @Nonnull + private List multiResolve(@Nonnull String name) { + synchronized (myNameDefinerNegativeCache) { + long modCount = myModificationTracker.getOutOfCodeBlockModificationCount(); + if (modCount != myNameDefinerOOCBModCount) { + myNameDefinerNegativeCache.clear(); + myNameDefinerOOCBModCount = modCount; + } + else if (myNameDefinerNegativeCache.contains(name)) { + return Collections.emptyList(); + } + } + + PyResolveProcessor processor = new PyResolveProcessor(name); + boolean stopped = false; + if (myNamedElements.containsKey(name)) { + for (PsiNamedElement element : myNamedElements.get(name)) { + if (!processor.execute(element, ResolveState.initial())) { + stopped = true; + break; + } + } + } + if (!stopped) { + for (PyImportedNameDefiner definer : myImportedNameDefiners) { + if (!processor.execute(definer, ResolveState.initial())) { + break; + } + } + } + Map results = processor.getResults(); + if (!results.isEmpty()) { + ResolveResultList resultList = new ResolveResultList(); + TypeEvalContext typeEvalContext = TypeEvalContext.codeInsightFallback(getProject()); + for (Map.Entry entry : results.entrySet()) { + PsiElement element = entry.getKey(); + PyImportedNameDefiner definer = entry.getValue(); + if (element != null) { + int elementRate = PyReferenceImpl.getRate(element, typeEvalContext); + if (definer != null) { + resultList.add(new ImportedResolveResult(element, elementRate, definer)); + } + else { + resultList.poke(element, elementRate); + } + } + } + return resultList; + } + + synchronized (myNameDefinerNegativeCache) { + myNameDefinerNegativeCache.add(name); + } + return Collections.emptyList(); } - if (child instanceof PyExceptPart) { - final PyExceptPart part = (PyExceptPart)child; - if (!processDeclarations(PyPsiUtils.collectAllStubChildren(part, part.getStub()), processor)) { - return false; - } + + public long getModificationStamp() { + return myModificationStamp; } - } - return true; + } + + public PyFileImpl(FileViewProvider viewProvider) { + this(viewProvider, PythonLanguage.getInstance()); + } + + public PyFileImpl(FileViewProvider viewProvider, Language language) { + super(viewProvider, language); + myFutureFeatures = new HashMap<>(); + myModificationTracker = PsiModificationTracker.SERVICE.getInstance(getProject()); + } + + @Override + @RequiredReadAction + public String toString() { + return "PyFile:" + getName(); + } + + @Override + @RequiredReadAction + public PyFunction findTopLevelFunction(String name) { + return findByName(name, getTopLevelFunctions()); + } + + @Override + @RequiredReadAction + public PyClass findTopLevelClass(String name) { + return findByName(name, getTopLevelClasses()); + } + + @Override + @RequiredReadAction + public PyTargetExpression findTopLevelAttribute(String name) { + return findByName(name, getTopLevelAttributes()); + } + + @Nullable + @RequiredReadAction + private static T findByName(String name, List namedElements) { + for (T namedElement : namedElements) { + if (name.equals(namedElement.getName())) { + return namedElement; + } + } + return null; } @Nonnull - private List multiResolve(@Nonnull String name) { - synchronized (myNameDefinerNegativeCache) { - final long modCount = myModificationTracker.getOutOfCodeBlockModificationCount(); - if (modCount != myNameDefinerOOCBModCount) { - myNameDefinerNegativeCache.clear(); - myNameDefinerOOCBModCount = modCount; + @Override + @RequiredReadAction + public LanguageLevel getLanguageLevel() { + if (myOriginalFile != null) { + return ((PyFileImpl) myOriginalFile).getLanguageLevel(); } - else { - if (myNameDefinerNegativeCache.contains(name)) { - return Collections.emptyList(); - } - } - } - - final PyResolveProcessor processor = new PyResolveProcessor(name); - boolean stopped = false; - if (myNamedElements.containsKey(name)) { - for (PsiNamedElement element : myNamedElements.get(name)) { - if (!processor.execute(element, ResolveState.initial())) { - stopped = true; - break; - } - } - } - if (!stopped) { - for (PyImportedNameDefiner definer : myImportedNameDefiners) { - if (!processor.execute(definer, ResolveState.initial())) { - break; - } - } - } - final Map results = processor.getResults(); - if (!results.isEmpty()) { - final ResolveResultList resultList = new ResolveResultList(); - final TypeEvalContext typeEvalContext = TypeEvalContext.codeInsightFallback(getProject()); - for (Map.Entry entry : results.entrySet()) { - final PsiElement element = entry.getKey(); - final PyImportedNameDefiner definer = entry.getValue(); - if (element != null) { - final int elementRate = PyReferenceImpl.getRate(element, typeEvalContext); - if (definer != null) { - resultList.add(new ImportedResolveResult(element, elementRate, definer)); + VirtualFile virtualFile = getVirtualFile(); + + if (virtualFile == null) { + virtualFile = getUserData(IndexingDataKeys.VIRTUAL_FILE); + } + if (virtualFile == null) { + virtualFile = getViewProvider().getVirtualFile(); + } + return PyUtil.getLanguageLevelForVirtualFile(getProject(), virtualFile); + } + + @Override + public void accept(@Nonnull PsiElementVisitor visitor) { + if (isAcceptedFor(visitor.getClass())) { + if (visitor instanceof PyElementVisitor elemVisitor) { + elemVisitor.visitPyFile(this); } else { - resultList.poke(element, elementRate); + super.accept(visitor); } - } } - return resultList; - } + } - synchronized (myNameDefinerNegativeCache) { - myNameDefinerNegativeCache.add(name); - } - return Collections.emptyList(); - } - - public long getModificationStamp() { - return myModificationStamp; + public boolean isAcceptedFor(@Nonnull Class visitorClass) { + for (Language lang : getViewProvider().getLanguages()) { + List filters = PythonVisitorFilter.forLanguage(lang); + for (PythonVisitorFilter filter : filters) { + if (!filter.isSupported(visitorClass, this)) { + return false; + } + } + } + return true; } - } - public PyFileImpl(FileViewProvider viewProvider) { - this(viewProvider, PythonLanguage.getInstance()); - } + private final Key> PROCESSED_FILES = Key.create("PyFileImpl.processDeclarations.processedFiles"); - public PyFileImpl(FileViewProvider viewProvider, Language language) { - super(viewProvider, language); - myFutureFeatures = new HashMap<>(); - myModificationTracker = PsiModificationTracker.SERVICE.getInstance(getProject()); - } + @Override + public boolean processDeclarations( + @Nonnull final PsiScopeProcessor processor, + @Nonnull ResolveState resolveState, + PsiElement lastParent, + @Nonnull PsiElement place + ) { + List dunderAll = getDunderAll(); + final List remainingDunderAll = dunderAll == null ? null : new ArrayList<>(dunderAll); + PsiScopeProcessor wrapper = new PsiScopeProcessor() { + @Override + public boolean execute(@Nonnull PsiElement element, @Nonnull ResolveState state) { + if (!processor.execute(element, state)) { + return false; + } + if (remainingDunderAll != null && element instanceof PyElement pyElement) { + remainingDunderAll.remove(pyElement.getName()); + } + return true; + } - public String toString() { - return "PyFile:" + getName(); - } + @Override + public T getHint(@Nonnull Key hintKey) { + return processor.getHint(hintKey); + } - @Override - public PyFunction findTopLevelFunction(String name) { - return findByName(name, getTopLevelFunctions()); - } + @Override + public void handleEvent(@Nonnull Event event, @Nullable Object associated) { + processor.handleEvent(event, associated); + } + }; + + Set pyFiles = resolveState.get(PROCESSED_FILES); + if (pyFiles == null) { + pyFiles = new HashSet<>(); + resolveState = resolveState.put(PROCESSED_FILES, pyFiles); + } + if (pyFiles.contains(this)) { + return true; + } + pyFiles.add(this); + for (PyClass c : getTopLevelClasses()) { + if (c == lastParent) { + continue; + } + if (!wrapper.execute(c, resolveState)) { + return false; + } + } + for (PyFunction f : getTopLevelFunctions()) { + if (f == lastParent) { + continue; + } + if (!wrapper.execute(f, resolveState)) { + return false; + } + } + for (PyTargetExpression e : getTopLevelAttributes()) { + if (e == lastParent) { + continue; + } + if (!wrapper.execute(e, resolveState)) { + return false; + } + } - @Override - public PyClass findTopLevelClass(String name) { - return findByName(name, getTopLevelClasses()); - } + for (PyImportElement e : getImportTargets()) { + if (e == lastParent) { + continue; + } + if (!wrapper.execute(e, resolveState)) { + return false; + } + } - @Override - public PyTargetExpression findTopLevelAttribute(String name) { - return findByName(name, getTopLevelAttributes()); - } + for (PyFromImportStatement e : getFromImports()) { + if (e == lastParent) { + continue; + } + if (!e.processDeclarations(wrapper, resolveState, null, this)) { + return false; + } + } - @Nullable - private static T findByName(String name, List namedElements) { - for (T namedElement : namedElements) { - if (name.equals(namedElement.getName())) { - return namedElement; - } + if (remainingDunderAll != null) { + for (String s : remainingDunderAll) { + if (!PyNames.isIdentifier(s)) { + continue; + } + if (!processor.execute(new LightNamedElement(myManager, PythonLanguage.getInstance(), s), resolveState)) { + return false; + } + } + } + return true; } - return null; - } - @Override - public LanguageLevel getLanguageLevel() { - if (myOriginalFile != null) { - return ((PyFileImpl)myOriginalFile).getLanguageLevel(); + @Override + @RequiredReadAction + public List getStatements() { + List stmts = new ArrayList<>(); + for (PsiElement child : getChildren()) { + if (child instanceof PyStatement statement) { + stmts.add(statement); + } + } + return stmts; } - VirtualFile virtualFile = getVirtualFile(); - if (virtualFile == null) { - virtualFile = getUserData(IndexingDataKeys.VIRTUAL_FILE); + @Override + @RequiredReadAction + public List getTopLevelClasses() { + return PyPsiUtils.collectStubChildren(this, this.getStub(), PyElementTypes.CLASS_DECLARATION, PyClass.class); } - if (virtualFile == null) { - virtualFile = getViewProvider().getVirtualFile(); + + @Nonnull + @Override + @RequiredReadAction + public List getTopLevelFunctions() { + return PyPsiUtils.collectStubChildren(this, this.getStub(), PyElementTypes.FUNCTION_DECLARATION, PyFunction.class); } - return PyUtil.getLanguageLevelForVirtualFile(getProject(), virtualFile); - } - @Override - public void accept(@Nonnull PsiElementVisitor visitor) { - if (isAcceptedFor(visitor.getClass())) { - if (visitor instanceof PyElementVisitor) { - ((PyElementVisitor)visitor).visitPyFile(this); - } - else { - super.accept(visitor); - } + @Override + @RequiredReadAction + public List getTopLevelAttributes() { + return PyPsiUtils.collectStubChildren(this, this.getStub(), PyElementTypes.TARGET_EXPRESSION, PyTargetExpression.class); } - } - public boolean isAcceptedFor(@Nonnull Class visitorClass) { - for (Language lang : getViewProvider().getLanguages()) { - final List filters = PythonVisitorFilter.forLanguage(lang); - for (PythonVisitorFilter filter : filters) { - if (!filter.isSupported(visitorClass, this)) { - return false; + @Nullable + @Override + @RequiredReadAction + public PsiElement findExportedName(String name) { + List results = multiResolveName(name); + List elements = Lists.newArrayList(); + for (RatedResolveResult result : results) { + PsiElement element = result.getElement(); + ImportedResolveResult importedResult = PyUtil.as(result, ImportedResolveResult.class); + if (importedResult != null) { + PyImportedNameDefiner definer = importedResult.getDefiner(); + if (definer != null) { + elements.add(definer); + } + } + else if (element != null && element.getContainingFile() == this) { + elements.add(element); + } + } + PsiElement element = elements.isEmpty() ? null : elements.get(elements.size() - 1); + if (element != null && !element.isValid()) { + throw new PsiInvalidElementAccessException(element); } - } + return element; } - return true; - } - private final Key> PROCESSED_FILES = Key.create("PyFileImpl.processDeclarations.processedFiles"); - - @Override - public boolean processDeclarations(@Nonnull final PsiScopeProcessor processor, - @Nonnull ResolveState resolveState, - PsiElement lastParent, - @Nonnull PsiElement place) { - final List dunderAll = getDunderAll(); - final List remainingDunderAll = dunderAll == null ? null : new ArrayList<>(dunderAll); - PsiScopeProcessor wrapper = new PsiScopeProcessor() { - @Override - public boolean execute(@Nonnull PsiElement element, @Nonnull ResolveState state) { - if (!processor.execute(element, state)) { - return false; - } - if (remainingDunderAll != null && element instanceof PyElement) { - remainingDunderAll.remove(((PyElement)element).getName()); + @Nonnull + @Override + @RequiredReadAction + public List multiResolveName(@Nonnull String name) { + List results = + RecursionManager.doPreventingRecursion(this, false, () -> getExportedNameCache().multiResolve(name)); + if (results != null && !results.isEmpty()) { + return results; } - return true; - } - - @Override - public T getHint(@Nonnull Key hintKey) { - return processor.getHint(hintKey); - } - - @Override - public void handleEvent(@Nonnull Event event, @Nullable Object associated) { - processor.handleEvent(event, associated); - } - }; - - Set pyFiles = resolveState.get(PROCESSED_FILES); - if (pyFiles == null) { - pyFiles = new HashSet<>(); - resolveState = resolveState.put(PROCESSED_FILES, pyFiles); - } - if (pyFiles.contains(this)) { - return true; - } - pyFiles.add(this); - for (PyClass c : getTopLevelClasses()) { - if (c == lastParent) { - continue; - } - if (!wrapper.execute(c, resolveState)) { - return false; - } + List allNames = getDunderAll(); + if (allNames != null && allNames.contains(name)) { + PsiElement allElement = findExportedName(PyNames.ALL); + ResolveResultList allFallbackResults = new ResolveResultList(); + allFallbackResults.poke(allElement, RatedResolveResult.RATE_LOW); + return allFallbackResults; + } + return Collections.emptyList(); } - for (PyFunction f : getTopLevelFunctions()) { - if (f == lastParent) { - continue; - } - if (!wrapper.execute(f, resolveState)) { - return false; - } + + @RequiredReadAction + private ExportedNameCache getExportedNameCache() { + ExportedNameCache cache; + cache = myExportedNameCache != null ? myExportedNameCache.get() : null; + long modificationStamp = getModificationStamp(); + if (myExportedNameCache != null && cache != null && modificationStamp != cache.getModificationStamp()) { + myExportedNameCache.clear(); + cache = null; + } + if (cache == null) { + cache = new ExportedNameCache(modificationStamp); + myExportedNameCache = new SoftReference<>(cache); + } + return cache; } - for (PyTargetExpression e : getTopLevelAttributes()) { - if (e == lastParent) { - continue; - } - if (!wrapper.execute(e, resolveState)) { - return false; - } + + @Nullable + @Override + @RequiredReadAction + public PsiElement getElementNamed(String name) { + List results = multiResolveName(name); + List elements = PyUtil.filterTopPriorityResults(results.toArray(new ResolveResult[results.size()])); + PsiElement element = elements.isEmpty() ? null : elements.get(elements.size() - 1); + if (element != null) { + if (!element.isValid()) { + throw new PsiInvalidElementAccessException(element); + } + return element; + } + return null; } - for (PyImportElement e : getImportTargets()) { - if (e == lastParent) { - continue; - } - if (!wrapper.execute(e, resolveState)) { - return false; - } + @Nonnull + @Override + @RequiredReadAction + public Iterable iterateNames() { + final List result = new ArrayList<>(); + VariantsProcessor processor = new VariantsProcessor(this) { + @Override + protected void addElement(String name, PsiElement element) { + if (PyUtil.turnDirIntoInit(element) instanceof PyElement pyElement) { + result.add(pyElement); + } + } + }; + processor.setAllowedNames(getDunderAll()); + processDeclarations(processor, ResolveState.initial(), null, this); + return result; } - for (PyFromImportStatement e : getFromImports()) { - if (e == lastParent) { - continue; - } - if (!e.processDeclarations(wrapper, resolveState, null, this)) { - return false; - } - } - - if (remainingDunderAll != null) { - for (String s : remainingDunderAll) { - if (!PyNames.isIdentifier(s)) { - continue; - } - if (!processor.execute(new LightNamedElement(myManager, PythonLanguage.getInstance(), s), resolveState)) { - return false; - } - } - } - return true; - } - - @Override - public List getStatements() { - List stmts = new ArrayList<>(); - for (PsiElement child : getChildren()) { - if (child instanceof PyStatement) { - PyStatement statement = (PyStatement)child; - stmts.add(statement); - } - } - return stmts; - } - - @Override - public List getTopLevelClasses() { - return PyPsiUtils.collectStubChildren(this, this.getStub(), PyElementTypes.CLASS_DECLARATION, PyClass.class); - } - - @Nonnull - @Override - public List getTopLevelFunctions() { - return PyPsiUtils.collectStubChildren(this, this.getStub(), PyElementTypes.FUNCTION_DECLARATION, PyFunction.class); - } - - @Override - public List getTopLevelAttributes() { - return PyPsiUtils.collectStubChildren(this, this.getStub(), PyElementTypes.TARGET_EXPRESSION, PyTargetExpression.class); - } - - @Override - @Nullable - public PsiElement findExportedName(final String name) { - final List results = multiResolveName(name); - final List elements = Lists.newArrayList(); - for (RatedResolveResult result : results) { - final PsiElement element = result.getElement(); - final ImportedResolveResult importedResult = PyUtil.as(result, ImportedResolveResult.class); - if (importedResult != null) { - final PyImportedNameDefiner definer = importedResult.getDefiner(); - if (definer != null) { - elements.add(definer); - } - } - else if (element != null && element.getContainingFile() == this) { - elements.add(element); - } - } - final PsiElement element = elements.isEmpty() ? null : elements.get(elements.size() - 1); - if (element != null && !element.isValid()) { - throw new PsiInvalidElementAccessException(element); - } - return element; - } - - @Nonnull - @Override - public List multiResolveName(@Nonnull final String name) { - final List results = - RecursionManager.doPreventingRecursion(this, false, () -> getExportedNameCache().multiResolve(name)); - if (results != null && !results.isEmpty()) { - return results; - } - final List allNames = getDunderAll(); - if (allNames != null && allNames.contains(name)) { - final PsiElement allElement = findExportedName(PyNames.ALL); - final ResolveResultList allFallbackResults = new ResolveResultList(); - allFallbackResults.poke(allElement, RatedResolveResult.RATE_LOW); - return allFallbackResults; - } - return Collections.emptyList(); - } - - private ExportedNameCache getExportedNameCache() { - ExportedNameCache cache; - cache = myExportedNameCache != null ? myExportedNameCache.get() : null; - final long modificationStamp = getModificationStamp(); - if (myExportedNameCache != null && cache != null && modificationStamp != cache.getModificationStamp()) { - myExportedNameCache.clear(); - cache = null; - } - if (cache == null) { - cache = new ExportedNameCache(modificationStamp); - myExportedNameCache = new SoftReference<>(cache); - } - return cache; - } - - @Nullable - public PsiElement getElementNamed(final String name) { - final List results = multiResolveName(name); - final List elements = PyUtil.filterTopPriorityResults(results.toArray(new ResolveResult[results.size()])); - final PsiElement element = elements.isEmpty() ? null : elements.get(elements.size() - 1); - if (element != null) { - if (!element.isValid()) { - throw new PsiInvalidElementAccessException(element); - } - return element; - } - return null; - } - - @Nonnull - public Iterable iterateNames() { - final List result = new ArrayList<>(); - VariantsProcessor processor = new VariantsProcessor(this) { - @Override - protected void addElement(String name, PsiElement element) { - element = PyUtil.turnDirIntoInit(element); - if (element instanceof PyElement) { - result.add((PyElement)element); - } - } - }; - processor.setAllowedNames(getDunderAll()); - processDeclarations(processor, ResolveState.initial(), null, this); - return result; - } - - @Override - @Nonnull - public List getImportTargets() { - List ret = new ArrayList<>(); - List imports = - PyPsiUtils.collectStubChildren(this, this.getStub(), PyElementTypes.IMPORT_STATEMENT, PyImportStatement.class); - for (PyImportStatement one : imports) { - ContainerUtil.addAll(ret, one.getImportElements()); - } - return ret; - } - - @Override - @Nonnull - public List getFromImports() { - return PyPsiUtils.collectStubChildren(this, getStub(), PyElementTypes.FROM_IMPORT_STATEMENT, PyFromImportStatement.class); - } - - @Override - public List getDunderAll() { - final StubElement stubElement = getStub(); - if (stubElement instanceof PyFileStub) { - return ((PyFileStub)stubElement).getDunderAll(); - } - if (!myDunderAllCalculated) { - final List dunderAll = calculateDunderAll(); - myDunderAll = dunderAll == null ? null : Collections.unmodifiableList(dunderAll); - myDunderAllCalculated = true; - } - return myDunderAll; - } - - @Nullable - public List calculateDunderAll() { - final DunderAllBuilder builder = new DunderAllBuilder(); - accept(builder); - return builder.result(); - } - - private static class DunderAllBuilder extends PyRecursiveElementVisitor { - private List myResult = null; - private boolean myDynamic = false; - private boolean myFoundDunderAll = false; - - // hashlib builds __all__ by concatenating multiple lists of strings, and we want to understand this - private final Map> myDunderLike = new HashMap<>(); + @Nonnull + @Override + @RequiredReadAction + public List getImportTargets() { + List ret = new ArrayList<>(); + List imports = + PyPsiUtils.collectStubChildren(this, this.getStub(), PyElementTypes.IMPORT_STATEMENT, PyImportStatement.class); + for (PyImportStatement one : imports) { + ContainerUtil.addAll(ret, one.getImportElements()); + } + return ret; + } + @Nonnull @Override - public void visitPyFile(PyFile node) { - if (node.getText().contains(PyNames.ALL)) { - super.visitPyFile(node); - } + @RequiredReadAction + public List getFromImports() { + return PyPsiUtils.collectStubChildren(this, getStub(), PyElementTypes.FROM_IMPORT_STATEMENT, PyFromImportStatement.class); } @Override - public void visitPyTargetExpression(PyTargetExpression node) { - if (PyNames.ALL.equals(node.getName())) { - myFoundDunderAll = true; - PyExpression value = node.findAssignedValue(); - if (value instanceof PyBinaryExpression) { - PyBinaryExpression binaryExpression = (PyBinaryExpression)value; - if (binaryExpression.isOperator("+")) { - List lhs = getStringListFromValue(binaryExpression.getLeftExpression()); - List rhs = getStringListFromValue(binaryExpression.getRightExpression()); - if (lhs != null && rhs != null) { - myResult = new ArrayList<>(lhs); - myResult.addAll(rhs); - } - } + @RequiredReadAction + public List getDunderAll() { + if (getStub() instanceof PyFileStub fileStub) { + return fileStub.getDunderAll(); } - else { - myResult = PyUtil.getStringListFromTargetExpression(node); + if (!myDunderAllCalculated) { + List dunderAll = calculateDunderAll(); + myDunderAll = dunderAll == null ? null : Collections.unmodifiableList(dunderAll); + myDunderAllCalculated = true; } - } - if (!myFoundDunderAll) { - List names = PyUtil.getStringListFromTargetExpression(node); - if (names != null) { - myDunderLike.put(node.getName(), names); + return myDunderAll; + } + + @Nullable + public List calculateDunderAll() { + DunderAllBuilder builder = new DunderAllBuilder(); + accept(builder); + return builder.result(); + } + + private static class DunderAllBuilder extends PyRecursiveElementVisitor { + private List myResult = null; + private boolean myDynamic = false; + private boolean myFoundDunderAll = false; + + // hashlib builds __all__ by concatenating multiple lists of strings, and we want to understand this + private final Map> myDunderLike = new HashMap<>(); + + @Override + @RequiredReadAction + public void visitPyFile(PyFile node) { + if (node.getText().contains(PyNames.ALL)) { + super.visitPyFile(node); + } + } + + @Override + public void visitPyTargetExpression(PyTargetExpression node) { + if (PyNames.ALL.equals(node.getName())) { + myFoundDunderAll = true; + if (node.findAssignedValue() instanceof PyBinaryExpression binaryExpression) { + if (binaryExpression.isOperator("+")) { + List lhs = getStringListFromValue(binaryExpression.getLeftExpression()); + List rhs = getStringListFromValue(binaryExpression.getRightExpression()); + if (lhs != null && rhs != null) { + myResult = new ArrayList<>(lhs); + myResult.addAll(rhs); + } + } + } + else { + myResult = PyUtil.getStringListFromTargetExpression(node); + } + } + if (!myFoundDunderAll) { + List names = PyUtil.getStringListFromTargetExpression(node); + if (names != null) { + myDunderLike.put(node.getName(), names); + } + } + } + + @Nullable + private List getStringListFromValue(PyExpression expression) { + if (expression instanceof PyReferenceExpression refExpr && !refExpr.isQualified()) { + return myDunderLike.get(refExpr.getReferencedName()); + } + return PyUtil.strListValue(expression); + } + + @Override + public void visitPyAugAssignmentStatement(PyAugAssignmentStatement node) { + if (PyNames.ALL.equals(node.getTarget().getName())) { + myDynamic = true; + } + } + + @Override + @RequiredReadAction + public void visitPyCallExpression(PyCallExpression node) { + if (node.getCallee() instanceof PyQualifiedExpression qualifiedExpr) { + PyExpression qualifier = qualifiedExpr.getQualifier(); + if (qualifier != null && PyNames.ALL.equals(qualifier.getText())) { + // TODO handle append and extend with constant arguments here + myDynamic = true; + } + } + } + + @Nullable + List result() { + return myDynamic ? null : myResult; } - } } @Nullable - private List getStringListFromValue(PyExpression expression) { - if (expression instanceof PyReferenceExpression && !((PyReferenceExpression)expression).isQualified()) { - return myDunderLike.get(((PyReferenceExpression)expression).getReferencedName()); - } - return PyUtil.strListValue(expression); + public static List getStringListFromTargetExpression(String name, List attrs) { + for (PyTargetExpression attr : attrs) { + if (name.equals(attr.getName())) { + return PyUtil.getStringListFromTargetExpression(attr); + } + } + return null; } @Override - public void visitPyAugAssignmentStatement(PyAugAssignmentStatement node) { - if (PyNames.ALL.equals(node.getTarget().getName())) { - myDynamic = true; - } + @RequiredReadAction + public boolean hasImportFromFuture(FutureFeature feature) { + if (getStub() instanceof PyFileStub fileStub) { + return fileStub.getFutureFeatures().get(feature.ordinal()); + } + Boolean enabled = myFutureFeatures.get(feature); + if (enabled == null) { + enabled = calculateImportFromFuture(feature); + myFutureFeatures.put(feature, enabled); + // NOTE: ^^^ not synchronized. if two threads will try to modify this, both can only be expected to set the same value. + } + return enabled; } @Override - public void visitPyCallExpression(PyCallExpression node) { - final PyExpression callee = node.getCallee(); - if (callee instanceof PyQualifiedExpression) { - final PyExpression qualifier = ((PyQualifiedExpression)callee).getQualifier(); - if (qualifier != null && PyNames.ALL.equals(qualifier.getText())) { - // TODO handle append and extend with constant arguments here - myDynamic = true; + @RequiredReadAction + public String getDeprecationMessage() { + if (getStub() instanceof PyFileStub fileStub) { + return fileStub.getDeprecationMessage(); } - } + return extractDeprecationMessage(); } - @Nullable - List result() { - return myDynamic ? null : myResult; - } - } - - @Nullable - public static List getStringListFromTargetExpression(final String name, List attrs) { - for (PyTargetExpression attr : attrs) { - if (name.equals(attr.getName())) { - return PyUtil.getStringListFromTargetExpression(attr); - } - } - return null; - } - - @Override - public boolean hasImportFromFuture(FutureFeature feature) { - final StubElement stub = getStub(); - if (stub instanceof PyFileStub) { - return ((PyFileStub)stub).getFutureFeatures().get(feature.ordinal()); - } - Boolean enabled = myFutureFeatures.get(feature); - if (enabled == null) { - enabled = calculateImportFromFuture(feature); - myFutureFeatures.put(feature, enabled); - // NOTE: ^^^ not synchronized. if two threads will try to modify this, both can only be expected to set the same value. - } - return enabled; - } - - @Override - public String getDeprecationMessage() { - final StubElement stub = getStub(); - if (stub instanceof PyFileStub) { - return ((PyFileStub)stub).getDeprecationMessage(); - } - return extractDeprecationMessage(); - } - - @Override - public List getImportBlock() { - final List result = new ArrayList<>(); - final PsiElement firstChild = getFirstChild(); - final PyImportStatementBase firstImport; - if (firstChild instanceof PyImportStatementBase) { - firstImport = (PyImportStatementBase)firstChild; - } - else { - firstImport = PsiTreeUtil.getNextSiblingOfType(firstChild, PyImportStatementBase.class); - } - if (firstImport != null) { - result.add(firstImport); - PsiElement nextImport = PyPsiUtils.getNextNonCommentSibling(firstImport, true); - while (nextImport instanceof PyImportStatementBase) { - result.add((PyImportStatementBase)nextImport); - nextImport = PyPsiUtils.getNextNonCommentSibling(nextImport, true); - } - } - return result; - } - - public String extractDeprecationMessage() { - if (canHaveDeprecationMessage(getText())) { - return PyFunctionImpl.extractDeprecationMessage(getStatements()); - } - else { - return null; - } - } - - private static boolean canHaveDeprecationMessage(String text) { - return text.contains(PyNames.DEPRECATION_WARNING) || text.contains(PyNames.PENDING_DEPRECATION_WARNING); - } - - public boolean calculateImportFromFuture(FutureFeature feature) { - if (getText().contains(feature.toString())) { - final List fromImports = getFromImports(); - for (PyFromImportStatement fromImport : fromImports) { - if (fromImport.isFromFuture()) { - final PyImportElement[] pyImportElements = fromImport.getImportElements(); - for (PyImportElement element : pyImportElements) { - final QualifiedName qName = element.getImportedQName(); - if (qName != null && qName.matches(feature.toString())) { - return true; - } - } - } - } - } - return false; - } - - - @Override - public PyType getType(@Nonnull TypeEvalContext context, @Nonnull TypeEvalContext.Key key) { - if (myType == null) { - myType = new PyModuleType(this); - } - return myType; - } - - @Nullable - @Override - public String getDocStringValue() { - return DocStringUtil.getDocStringValue(this); - } - - @Nullable - @Override - public StructuredDocString getStructuredDocString() { - return DocStringUtil.getStructuredDocString(this); - } - - @Nullable - @Override - public PyStringLiteralExpression getDocStringExpression() { - return DocStringUtil.findDocStringExpression(this); - } - - @Override - public void subtreeChanged() { - super.subtreeChanged(); - ControlFlowCache.clear(this); - myDunderAllCalculated = false; - myFutureFeatures.clear(); // probably no need to synchronize - myExportedNameCache.clear(); - } - - @Override - public void delete() throws IncorrectOperationException { - String path = getVirtualFile().getPath(); - super.delete(); - PyUtil.deletePycFiles(path); - } - - @Override - public PsiElement setName(@Nonnull String name) throws IncorrectOperationException { - String path = getVirtualFile().getPath(); - final PsiElement newElement = super.setName(name); - PyUtil.deletePycFiles(path); - return newElement; - } - - private static class ArrayListThreadLocal extends ThreadLocal> { @Override - protected List initialValue() { - return new ArrayList<>(); - } - } - - @Override - public ItemPresentation getPresentation() { - return new ItemPresentation() { - @Override - public String getPresentableText() { - return getModuleName(PyFileImpl.this); - } - - @Override - public String getLocationString() { - final String name = getLocationName(); - return name != null ? "(" + name + ")" : null; - } - - @Override - public Image getIcon() { - if (PyUtil.isPackage(PyFileImpl.this)) { - return AllIcons.Modules.SourceRoot; - } - return IconDescriptorUpdaters.getIcon(PyFileImpl.this, 0); - } - - @Nonnull - private String getModuleName(@Nonnull PyFile file) { - if (PyUtil.isPackage(file)) { - final PsiDirectory dir = file.getContainingDirectory(); - if (dir != null) { - return dir.getName(); - } - } - return FileUtil.getNameWithoutExtension(file.getName()); - } - - @Nullable - private String getLocationName() { - final QualifiedName name = QualifiedNameFinder.findShortestImportableQName(PyFileImpl.this); - if (name != null) { - final QualifiedName prefix = name.removeTail(1); - if (prefix.getComponentCount() > 0) { - return prefix.toString(); - } - } - final String relativePath = getRelativeContainerPath(); - if (relativePath != null) { - return relativePath; - } - final PsiDirectory psiDirectory = getParent(); - if (psiDirectory != null) { - return psiDirectory.getVirtualFile().getPresentableUrl(); + @RequiredReadAction + public List getImportBlock() { + List result = new ArrayList<>(); + PsiElement firstChild = getFirstChild(); + PyImportStatementBase firstImport; + if (firstChild instanceof PyImportStatementBase importStmtBase) { + firstImport = importStmtBase; } - return null; - } - - @Nullable - private String getRelativeContainerPath() { - final PsiDirectory psiDirectory = getParent(); - if (psiDirectory != null) { - final VirtualFile virtualFile = getVirtualFile(); - if (virtualFile != null) { - final VirtualFile root = ProjectFileIndex.SERVICE.getInstance(getProject()).getContentRootForFile(virtualFile); - if (root != null) { - final VirtualFile parent = virtualFile.getParent(); - final VirtualFile rootParent = root.getParent(); - if (rootParent != null && parent != null) { - return VfsUtilCore.getRelativePath(parent, rootParent, File.separatorChar); - } - } - } + else { + firstImport = PsiTreeUtil.getNextSiblingOfType(firstChild, PyImportStatementBase.class); } - return null; - } - }; - } + if (firstImport != null) { + result.add(firstImport); + PsiElement nextImport = PyPsiUtils.getNextNonCommentSibling(firstImport, true); + while (nextImport instanceof PyImportStatementBase importStmtBase) { + result.add(importStmtBase); + nextImport = PyPsiUtils.getNextNonCommentSibling(importStmtBase, true); + } + } + return result; + } + + @RequiredReadAction + public String extractDeprecationMessage() { + if (canHaveDeprecationMessage(getText())) { + return PyFunctionImpl.extractDeprecationMessage(getStatements()); + } + else { + return null; + } + } + + private static boolean canHaveDeprecationMessage(String text) { + return text.contains(PyNames.DEPRECATION_WARNING) || text.contains(PyNames.PENDING_DEPRECATION_WARNING); + } + + @RequiredReadAction + public boolean calculateImportFromFuture(FutureFeature feature) { + if (getText().contains(feature.toString())) { + List fromImports = getFromImports(); + for (PyFromImportStatement fromImport : fromImports) { + if (fromImport.isFromFuture()) { + PyImportElement[] pyImportElements = fromImport.getImportElements(); + for (PyImportElement element : pyImportElements) { + QualifiedName qName = element.getImportedQName(); + if (qName != null && qName.matches(feature.toString())) { + return true; + } + } + } + } + } + return false; + } + + + @Override + public PyType getType(@Nonnull TypeEvalContext context, @Nonnull TypeEvalContext.Key key) { + if (myType == null) { + myType = new PyModuleType(this); + } + return myType; + } + + @Nullable + @Override + public String getDocStringValue() { + return DocStringUtil.getDocStringValue(this); + } + + @Nullable + @Override + public StructuredDocString getStructuredDocString() { + return DocStringUtil.getStructuredDocString(this); + } + + @Nullable + @Override + public PyStringLiteralExpression getDocStringExpression() { + return DocStringUtil.findDocStringExpression(this); + } + + @Override + @RequiredReadAction + public void subtreeChanged() { + super.subtreeChanged(); + ControlFlowCache.clear(this); + myDunderAllCalculated = false; + myFutureFeatures.clear(); // probably no need to synchronize + myExportedNameCache.clear(); + } + + @Override + @RequiredWriteAction + public void delete() throws IncorrectOperationException { + String path = getVirtualFile().getPath(); + super.delete(); + PyUtil.deletePycFiles(path); + } + + @Override + @RequiredWriteAction + public PsiElement setName(@Nonnull String name) throws IncorrectOperationException { + String path = getVirtualFile().getPath(); + PsiElement newElement = super.setName(name); + PyUtil.deletePycFiles(path); + return newElement; + } + + private static class ArrayListThreadLocal extends ThreadLocal> { + @Override + protected List initialValue() { + return new ArrayList<>(); + } + } + + @Override + public ItemPresentation getPresentation() { + return new ItemPresentation() { + @Override + @RequiredReadAction + public String getPresentableText() { + return getModuleName(PyFileImpl.this); + } + + @Override + @RequiredReadAction + public String getLocationString() { + String name = getLocationName(); + return name != null ? "(" + name + ")" : null; + } + + @Override + @RequiredReadAction + public Image getIcon() { + if (PyUtil.isPackage(PyFileImpl.this)) { + return PlatformIconGroup.modulesSourceroot(); + } + return IconDescriptorUpdaters.getIcon(PyFileImpl.this, 0); + } + + @Nonnull + @RequiredReadAction + private String getModuleName(@Nonnull PyFile file) { + if (PyUtil.isPackage(file)) { + PsiDirectory dir = file.getContainingDirectory(); + if (dir != null) { + return dir.getName(); + } + } + return FileUtil.getNameWithoutExtension(file.getName()); + } + + @Nullable + @RequiredReadAction + private String getLocationName() { + QualifiedName name = QualifiedNameFinder.findShortestImportableQName(PyFileImpl.this); + if (name != null) { + QualifiedName prefix = name.removeTail(1); + if (prefix.getComponentCount() > 0) { + return prefix.toString(); + } + } + String relativePath = getRelativeContainerPath(); + if (relativePath != null) { + return relativePath; + } + PsiDirectory psiDirectory = getParent(); + if (psiDirectory != null) { + return psiDirectory.getVirtualFile().getPresentableUrl(); + } + return null; + } + + @Nullable + @RequiredReadAction + private String getRelativeContainerPath() { + PsiDirectory psiDirectory = getParent(); + if (psiDirectory != null) { + VirtualFile virtualFile = getVirtualFile(); + if (virtualFile != null) { + VirtualFile root = ProjectFileIndex.SERVICE.getInstance(getProject()).getContentRootForFile(virtualFile); + if (root != null) { + VirtualFile parent = virtualFile.getParent(); + VirtualFile rootParent = root.getParent(); + if (rootParent != null && parent != null) { + return VfsUtilCore.getRelativePath(parent, rootParent, File.separatorChar); + } + } + } + } + return null; + } + }; + } } diff --git a/python-impl/src/main/java/com/jetbrains/python/impl/psi/impl/PyFunctionImpl.java b/python-impl/src/main/java/com/jetbrains/python/impl/psi/impl/PyFunctionImpl.java index 37832232..2ec0ea85 100644 --- a/python-impl/src/main/java/com/jetbrains/python/impl/psi/impl/PyFunctionImpl.java +++ b/python-impl/src/main/java/com/jetbrains/python/impl/psi/impl/PyFunctionImpl.java @@ -17,9 +17,7 @@ import com.jetbrains.python.PyNames; import com.jetbrains.python.PyTokenTypes; -import com.jetbrains.python.codeInsight.controlflow.ScopeOwner; import com.jetbrains.python.impl.PyElementTypes; -import com.jetbrains.python.impl.PythonIcons; import com.jetbrains.python.impl.codeInsight.PyTypingTypeProvider; import com.jetbrains.python.impl.codeInsight.controlflow.ControlFlowCache; import com.jetbrains.python.impl.codeInsight.dataflow.scope.ScopeUtil; @@ -40,12 +38,13 @@ import com.jetbrains.python.psi.types.PyClassType; import com.jetbrains.python.psi.types.PyType; import com.jetbrains.python.psi.types.TypeEvalContext; -import consulo.application.AllIcons; +import consulo.annotation.access.RequiredReadAction; +import consulo.annotation.access.RequiredWriteAction; import consulo.application.Application; import consulo.application.util.CachedValue; import consulo.application.util.CachedValueProvider; import consulo.application.util.CachedValuesManager; -import consulo.component.extension.Extensions; +import consulo.component.extension.ExtensionPoint; import consulo.content.scope.SearchScope; import consulo.language.ast.ASTNode; import consulo.language.psi.*; @@ -57,16 +56,17 @@ import consulo.language.util.IncorrectOperationException; import consulo.navigation.ItemPresentation; import consulo.platform.base.icon.PlatformIconGroup; +import consulo.python.impl.icon.PythonImplIconGroup; import consulo.ui.image.Image; import consulo.util.collection.ArrayUtil; import consulo.util.collection.JBIterable; import consulo.util.dataholder.Key; import consulo.util.lang.StringUtil; -import consulo.util.lang.ref.Ref; +import consulo.util.lang.ref.SimpleReference; import consulo.virtualFileSystem.VirtualFile; - import jakarta.annotation.Nonnull; import jakarta.annotation.Nullable; + import java.util.*; import static com.jetbrains.python.impl.psi.PyUtil.as; @@ -80,758 +80,794 @@ */ public class PyFunctionImpl extends PyBaseElementImpl implements PyFunction { - private static final Key>> ATTRIBUTES_KEY = Key.create("attributes"); + private static final Key>> ATTRIBUTES_KEY = Key.create("attributes"); + + public PyFunctionImpl(ASTNode astNode) { + super(astNode); + } + + public PyFunctionImpl(PyFunctionStub stub) { + this(stub, PyElementTypes.FUNCTION_DECLARATION); + } - public PyFunctionImpl(ASTNode astNode) { - super(astNode); - } + public PyFunctionImpl(PyFunctionStub stub, IStubElementType nodeType) { + super(stub, nodeType); + } - public PyFunctionImpl(final PyFunctionStub stub) { - this(stub, PyElementTypes.FUNCTION_DECLARATION); - } + private class CachedStructuredDocStringProvider implements CachedValueProvider { + @Nullable + @Override + public Result compute() { + PyFunctionImpl f = PyFunctionImpl.this; + return Result.create(DocStringUtil.getStructuredDocString(f), f); + } + } - public PyFunctionImpl(PyFunctionStub stub, IStubElementType nodeType) { - super(stub, nodeType); - } + private CachedStructuredDocStringProvider myCachedStructuredDocStringProvider = new CachedStructuredDocStringProvider(); - private class CachedStructuredDocStringProvider implements CachedValueProvider { @Nullable @Override - public Result compute() { - final PyFunctionImpl f = PyFunctionImpl.this; - return Result.create(DocStringUtil.getStructuredDocString(f), f); - } - } - - private CachedStructuredDocStringProvider myCachedStructuredDocStringProvider = new CachedStructuredDocStringProvider(); - - @Nullable - @Override - public String getName() { - final PyFunctionStub stub = getStub(); - if (stub != null) { - return stub.getName(); - } - - ASTNode node = getNameNode(); - return node != null ? node.getText() : null; - } - - public PsiElement getNameIdentifier() { - final ASTNode nameNode = getNameNode(); - return nameNode != null ? nameNode.getPsi() : null; - } - - public PsiElement setName(@Nonnull String name) throws IncorrectOperationException { - final ASTNode nameElement = PyUtil.createNewName(this, name); - final ASTNode nameNode = getNameNode(); - if (nameNode != null) { - getNode().replaceChild(nameNode, nameElement); - } - return this; - } - - public Image getIcon(int flags) { - PyPsiUtils.assertValid(this); - final Property property = getProperty(); - if (property != null) { - if (property.getGetter().valueOrNull() == this) { - return PythonIcons.Python.PropertyGetter; - } - if (property.getSetter().valueOrNull() == this) { - return PythonIcons.Python.PropertySetter; - } - if (property.getDeleter().valueOrNull() == this) { - return PythonIcons.Python.PropertyDeleter; - } - return PlatformIconGroup.nodesProperty(); - } - if (getContainingClass() != null) { - return AllIcons.Nodes.Method; - } - return AllIcons.Nodes.Function; - } - - @Nullable - public ASTNode getNameNode() { - return getNode().findChildByType(PyTokenTypes.IDENTIFIER); - } - - @Nonnull - public PyParameterList getParameterList() { - return getRequiredStubOrPsiChild(PyElementTypes.PARAMETER_LIST); - } - - @Override - @Nonnull - public PyStatementList getStatementList() { - final PyStatementList statementList = childToPsi(PyElementTypes.STATEMENT_LIST); - assert statementList != null : "Statement list missing for function " + getText(); - return statementList; - } - - public PyClass getContainingClass() { - final PyFunctionStub stub = getStub(); - if (stub != null) { - final StubElement parentStub = stub.getParentStub(); - if (parentStub instanceof PyClassStub) { - return ((PyClassStub)parentStub).getPsi(); - } - - return null; - } - - final PsiElement parent = PsiTreeUtil.getParentOfType(this, StubBasedPsiElement.class); - if (parent instanceof PyClass) { - return (PyClass)parent; - } - return null; - } - - @Nullable - public PyDecoratorList getDecoratorList() { - return getStubOrPsiChild(PyElementTypes.DECORATOR_LIST); // PsiTreeUtil.getChildOfType(this, PyDecoratorList.class); - } - - @Nullable - @Override - public PyType getReturnType(@Nonnull TypeEvalContext context, @Nonnull TypeEvalContext.Key key) { - final PyType type = getReturnType(context); - return isAsync() && isAsyncAllowed() ? createCoroutineType(type) : type; - } - - @Nullable - private PyType getReturnType(@Nonnull TypeEvalContext context) { - for (PyTypeProvider typeProvider : Extensions.getExtensions(PyTypeProvider.EP_NAME)) { - final Ref returnTypeRef = typeProvider.getReturnType(this, context); - if (returnTypeRef != null) { - return derefType(returnTypeRef, typeProvider); - } - } - - if (context.allowReturnTypes(this)) { - final Ref yieldTypeRef = getYieldStatementType(context); - if (yieldTypeRef != null) { - return yieldTypeRef.get(); - } - return getReturnStatementType(context); - } - - return null; - } - - @Nullable - @Override - public PyType getCallType(@Nonnull TypeEvalContext context, @Nonnull PyCallSiteExpression callSite) { - for (PyTypeProvider typeProvider : Extensions.getExtensions(PyTypeProvider.EP_NAME)) { - final Ref typeRef = typeProvider.getCallType(this, callSite, context); - if (typeRef != null) { - return derefType(typeRef, typeProvider); - } - } - - final PyExpression receiver = PyTypeChecker.getReceiver(callSite, this); - final Map mapping = PyCallExpressionHelper.mapArguments(callSite, this, context); - return getCallType(receiver, mapping, context); - } - - @Nullable - private static PyType derefType(@Nonnull Ref typeRef, @Nonnull PyTypeProvider typeProvider) { - final PyType type = typeRef.get(); - if (type != null) { - type.assertValid(typeProvider.toString()); - } - return type; - } - - @Nullable - @Override - public PyType getCallType(@Nullable PyExpression receiver, - @Nonnull Map parameters, - @Nonnull TypeEvalContext context) { - return analyzeCallType(context.getReturnType(this), receiver, parameters, context); - } - - @Nullable - private PyType analyzeCallType(@Nullable PyType type, - @Nullable PyExpression receiver, - @Nonnull Map parameters, - @Nonnull TypeEvalContext context) { - if (PyTypeChecker.hasGenerics(type, context)) { - final Map substitutions = PyTypeChecker.unifyGenericCall(receiver, parameters, context); - if (substitutions != null) { - type = PyTypeChecker.substitute(type, substitutions, context); - } - else { - type = null; - } - } - if (receiver != null) { - type = replaceSelf(type, receiver, context); - } - if (type != null && isDynamicallyEvaluated(parameters.values(), context)) { - type = PyUnionType.createWeakType(type); - } - return type; - } - - @Override - public ItemPresentation getPresentation() { - return new PyElementPresentation(this) { - @Nullable - @Override - public String getPresentableText() { - return notNullize(getName(), PyNames.UNNAMED_ELEMENT) + getParameterList().getPresentableText(true); - } - - @Nullable - @Override - public String getLocationString() { - final PyClass containingClass = getContainingClass(); - if (containingClass != null) { - return "(" + containingClass.getName() + " in " + getPackageForFile(getContainingFile()) + ")"; - } - return super.getLocationString(); - } - }; - } - - @Nullable - private PyType replaceSelf(@Nullable PyType returnType, @Nullable PyExpression receiver, @Nonnull TypeEvalContext context) { - if (receiver != null) { - // TODO: Currently we substitute only simple subclass types, but we could handle union and collection types as well - if (returnType instanceof PyClassType) { - final PyClassType returnClassType = (PyClassType)returnType; - if (returnClassType.getPyClass() == getContainingClass()) { - final PyType receiverType = context.getType(receiver); - if (receiverType instanceof PyClassType && PyTypeChecker.match(returnType, receiverType, context)) { - return returnClassType.isDefinition() ? receiverType : ((PyClassType)receiverType).toInstance(); - } - } - } - } - return returnType; - } - - private static boolean isDynamicallyEvaluated(@Nonnull Collection parameters, @Nonnull TypeEvalContext context) { - for (PyNamedParameter parameter : parameters) { - final PyType type = context.getType(parameter); - if (type instanceof PyDynamicallyEvaluatedType) { - return true; - } - } - return false; - } - - @Nullable - private Ref getYieldStatementType(@Nonnull final TypeEvalContext context) { - Ref elementType = null; - final PyBuiltinCache cache = PyBuiltinCache.getInstance(this); - final PyStatementList statements = getStatementList(); - final Set types = new LinkedHashSet<>(); - statements.accept(new PyRecursiveElementVisitor() { - @Override - public void visitPyYieldExpression(PyYieldExpression node) { - final PyExpression expr = node.getExpression(); - final PyType type = expr != null ? context.getType(expr) : null; - if (node.isDelegating() && type instanceof PyCollectionType) { - final PyCollectionType collectionType = (PyCollectionType)type; - // TODO: Select the parameter types that matches T in Iterable[T] - final List elementTypes = collectionType.getElementTypes(context); - types.add(elementTypes.isEmpty() ? null : elementTypes.get(0)); + @RequiredReadAction + public String getName() { + PyFunctionStub stub = getStub(); + if (stub != null) { + return stub.getName(); + } + + ASTNode node = getNameNode(); + return node != null ? node.getText() : null; + } + + @Override + @RequiredReadAction + public PsiElement getNameIdentifier() { + ASTNode nameNode = getNameNode(); + return nameNode != null ? nameNode.getPsi() : null; + } + + @Override + @RequiredWriteAction + public PsiElement setName(@Nonnull String name) throws IncorrectOperationException { + ASTNode nameElement = PyUtil.createNewName(this, name); + ASTNode nameNode = getNameNode(); + if (nameNode != null) { + getNode().replaceChild(nameNode, nameElement); + } + return this; + } + + public Image getIcon(int flags) { + PyPsiUtils.assertValid(this); + Property property = getProperty(); + if (property != null) { + if (property.getGetter().valueOrNull() == this) { + return PythonImplIconGroup.pythonPropertygetter(); + } + if (property.getSetter().valueOrNull() == this) { + return PythonImplIconGroup.pythonPropertysetter(); + } + if (property.getDeleter().valueOrNull() == this) { + return PythonImplIconGroup.pythonPropertydeleter(); + } + return PlatformIconGroup.nodesProperty(); + } + if (getContainingClass() != null) { + return PlatformIconGroup.nodesMethod(); + } + return PlatformIconGroup.nodesFunction(); + } + + @Nullable + @Override + @RequiredReadAction + public ASTNode getNameNode() { + return getNode().findChildByType(PyTokenTypes.IDENTIFIER); + } + + @Nonnull + @Override + @RequiredReadAction + public PyParameterList getParameterList() { + return getRequiredStubOrPsiChild(PyElementTypes.PARAMETER_LIST); + } + + @Nonnull + @Override + @RequiredReadAction + public PyStatementList getStatementList() { + PyStatementList statementList = childToPsi(PyElementTypes.STATEMENT_LIST); + assert statementList != null : "Statement list missing for function " + getText(); + return statementList; + } + + @Override + public PyClass getContainingClass() { + PyFunctionStub stub = getStub(); + if (stub != null) { + return stub.getParentStub() instanceof PyClassStub classStub ? classStub.getPsi() : null; + } + + if (PsiTreeUtil.getParentOfType(this, StubBasedPsiElement.class) instanceof PyClass pyClass) { + return pyClass; + } + return null; + } + + @Nullable + @Override + @RequiredReadAction + public PyDecoratorList getDecoratorList() { + return getStubOrPsiChild(PyElementTypes.DECORATOR_LIST); // PsiTreeUtil.getChildOfType(this, PyDecoratorList.class); + } + + @Nullable + @Override + @RequiredReadAction + public PyType getReturnType(@Nonnull TypeEvalContext context, @Nonnull TypeEvalContext.Key key) { + PyType type = getReturnType(context); + return isAsync() && isAsyncAllowed() ? createCoroutineType(type) : type; + } + + @Nullable + @RequiredReadAction + private PyType getReturnType(@Nonnull TypeEvalContext context) { + for (PyTypeProvider typeProvider : Application.get().getExtensionList(PyTypeProvider.class)) { + SimpleReference returnTypeRef = typeProvider.getReturnType(this, context); + if (returnTypeRef != null) { + return derefType(returnTypeRef, typeProvider); + } + } + + if (context.allowReturnTypes(this)) { + SimpleReference yieldTypeRef = getYieldStatementType(context); + if (yieldTypeRef != null) { + return yieldTypeRef.get(); + } + return getReturnStatementType(context); + } + + return null; + } + + @Nullable + @Override + public PyType getCallType(@Nonnull TypeEvalContext context, @Nonnull PyCallSiteExpression callSite) { + for (PyTypeProvider typeProvider : callSite.getApplication().getExtensionList(PyTypeProvider.class)) { + SimpleReference typeRef = typeProvider.getCallType(this, callSite, context); + if (typeRef != null) { + return derefType(typeRef, typeProvider); + } + } + + PyExpression receiver = PyTypeChecker.getReceiver(callSite, this); + Map mapping = PyCallExpressionHelper.mapArguments(callSite, this, context); + return getCallType(receiver, mapping, context); + } + + @Nullable + private static PyType derefType(@Nonnull SimpleReference typeRef, @Nonnull PyTypeProvider typeProvider) { + PyType type = typeRef.get(); + if (type != null) { + type.assertValid(typeProvider.toString()); + } + return type; + } + + @Nullable + @Override + public PyType getCallType( + @Nullable PyExpression receiver, + @Nonnull Map parameters, + @Nonnull TypeEvalContext context + ) { + return analyzeCallType(context.getReturnType(this), receiver, parameters, context); + } + + @Nullable + private PyType analyzeCallType( + @Nullable PyType type, + @Nullable PyExpression receiver, + @Nonnull Map parameters, + @Nonnull TypeEvalContext context + ) { + if (PyTypeChecker.hasGenerics(type, context)) { + Map substitutions = PyTypeChecker.unifyGenericCall(receiver, parameters, context); + if (substitutions != null) { + type = PyTypeChecker.substitute(type, substitutions, context); + } + else { + type = null; + } + } + if (receiver != null) { + type = replaceSelf(type, receiver, context); + } + if (type != null && isDynamicallyEvaluated(parameters.values(), context)) { + type = PyUnionType.createWeakType(type); } - else { - types.add(type); - } - } - - @Override - public void visitPyFunction(PyFunction node) { - // Ignore nested functions - } - }); - final int n = types.size(); - if (n == 1) { - elementType = Ref.create(types.iterator().next()); - } - else if (n > 0) { - elementType = Ref.create(PyUnionType.union(types)); - } - if (elementType != null) { - final PyClass generator = cache.getClass(PyNames.FAKE_GENERATOR); - if (generator != null) { - final List parameters = Arrays.asList(elementType.get(), null, getReturnStatementType(context)); - return Ref.create(new PyCollectionTypeImpl(generator, false, parameters)); - } - } - if (!types.isEmpty()) { - return Ref.create(null); - } - return null; - } - - @Nullable - public PyType getReturnStatementType(TypeEvalContext typeEvalContext) { - final ReturnVisitor visitor = new ReturnVisitor(this, typeEvalContext); - final PyStatementList statements = getStatementList(); - statements.accept(visitor); - if (isGeneratedStub() && !visitor.myHasReturns) { - if (PyNames.INIT.equals(getName())) { - return PyNoneType.INSTANCE; - } - return null; - } - return visitor.result(); - } - - @Nullable - private PyType createCoroutineType(@Nullable PyType returnType) { - final PyBuiltinCache cache = PyBuiltinCache.getInstance(this); - if (returnType instanceof PyClassLikeType && PyNames.FAKE_COROUTINE.equals(((PyClassLikeType)returnType).getClassQName())) { - return returnType; - } - final PyClass generator = cache.getClass(PyNames.FAKE_COROUTINE); - return generator != null ? new PyCollectionTypeImpl(generator, false, Collections.singletonList(returnType)) : null; - } - - public PyFunction asMethod() { - if (getContainingClass() != null) { - return this; - } - else { - return null; - } - } - - @Nullable - @Override - public PyType getReturnTypeFromDocString() { - final String typeName = extractReturnType(); - return typeName != null ? PyTypeParser.getTypeByName(this, typeName) : null; - } - - @Nullable - @Override - public String getDeprecationMessage() { - PyFunctionStub stub = getStub(); - if (stub != null) { - return stub.getDeprecationMessage(); - } - return extractDeprecationMessage(); - } - - @Nullable - public String extractDeprecationMessage() { - PyStatementList statementList = getStatementList(); - return extractDeprecationMessage(Arrays.asList(statementList.getStatements())); - } - - @Override - public PyType getType(@Nonnull TypeEvalContext context, @Nonnull TypeEvalContext.Key key) { - for (PyTypeProvider provider : Extensions.getExtensions(PyTypeProvider.EP_NAME)) { - final PyType type = provider.getCallableType(this, context); - if (type != null) { return type; - } - } - final PyFunctionTypeImpl type = new PyFunctionTypeImpl(this); - if (PyKnownDecoratorUtil.hasUnknownDecorator(this, context) && getProperty() == null) { - return PyUnionType.createWeakType(type); - } - return type; - } - - @Nullable - public static String extractDeprecationMessage(List statements) { - for (PyStatement statement : statements) { - if (statement instanceof PyExpressionStatement) { - PyExpressionStatement expressionStatement = (PyExpressionStatement)statement; - if (expressionStatement.getExpression() instanceof PyCallExpression) { - PyCallExpression callExpression = (PyCallExpression)expressionStatement.getExpression(); - if (callExpression.isCalleeText(PyNames.WARN)) { - PyReferenceExpression warningClass = callExpression.getArgument(1, PyReferenceExpression.class); - if (warningClass != null && (PyNames.DEPRECATION_WARNING.equals(warningClass.getReferencedName()) || PyNames.PENDING_DEPRECATION_WARNING - .equals(warningClass.getReferencedName - ()))) { - return PyPsiUtils.strValue(callExpression.getArguments()[0]); + } + + @Override + public ItemPresentation getPresentation() { + return new PyElementPresentation(this) { + @Nullable + @Override + @RequiredReadAction + public String getPresentableText() { + return notNullize(getName(), PyNames.UNNAMED_ELEMENT) + getParameterList().getPresentableText(true); + } + + @Nullable + @Override + @RequiredReadAction + public String getLocationString() { + PyClass containingClass = getContainingClass(); + if (containingClass != null) { + return "(" + containingClass.getName() + " in " + getPackageForFile(getContainingFile()) + ")"; + } + return super.getLocationString(); + } + }; + } + + @Nullable + private PyType replaceSelf(@Nullable PyType returnType, @Nullable PyExpression receiver, @Nonnull TypeEvalContext context) { + // TODO: Currently we substitute only simple subclass types, but we could handle union and collection types as well + if (receiver != null + && returnType instanceof PyClassType returnClassType + && returnClassType.getPyClass() == getContainingClass() + && context.getType(receiver) instanceof PyClassType receiverClassType + && PyTypeChecker.match(returnClassType, receiverClassType, context)) { + return returnClassType.isDefinition() ? receiverClassType : receiverClassType.toInstance(); + } + return returnType; + } + + private static boolean isDynamicallyEvaluated(@Nonnull Collection parameters, @Nonnull TypeEvalContext context) { + for (PyNamedParameter parameter : parameters) { + if (context.getType(parameter) instanceof PyDynamicallyEvaluatedType) { + return true; + } + } + return false; + } + + @Nullable + @RequiredReadAction + private SimpleReference getYieldStatementType(@Nonnull final TypeEvalContext context) { + SimpleReference elementType = null; + PyBuiltinCache cache = PyBuiltinCache.getInstance(this); + PyStatementList statements = getStatementList(); + final Set types = new LinkedHashSet<>(); + statements.accept(new PyRecursiveElementVisitor() { + @Override + public void visitPyYieldExpression(PyYieldExpression node) { + PyExpression expr = node.getExpression(); + PyType type = expr != null ? context.getType(expr) : null; + if (node.isDelegating() && type instanceof PyCollectionType collectionType) { + // TODO: Select the parameter types that matches T in Iterable[T] + List elementTypes = collectionType.getElementTypes(context); + types.add(elementTypes.isEmpty() ? null : elementTypes.get(0)); + } + else { + types.add(type); + } + } + + @Override + public void visitPyFunction(PyFunction node) { + // Ignore nested functions + } + }); + int n = types.size(); + if (n == 1) { + elementType = SimpleReference.create(types.iterator().next()); + } + else if (n > 0) { + elementType = SimpleReference.create(PyUnionType.union(types)); + } + if (elementType != null) { + PyClass generator = cache.getClass(PyNames.FAKE_GENERATOR); + if (generator != null) { + List parameters = Arrays.asList(elementType.get(), null, getReturnStatementType(context)); + return SimpleReference.create(new PyCollectionTypeImpl(generator, false, parameters)); + } + } + if (!types.isEmpty()) { + return SimpleReference.create(null); + } + return null; + } + + @Nullable + @Override + @RequiredReadAction + public PyType getReturnStatementType(TypeEvalContext typeEvalContext) { + ReturnVisitor visitor = new ReturnVisitor(this, typeEvalContext); + PyStatementList statements = getStatementList(); + statements.accept(visitor); + if (isGeneratedStub() && !visitor.myHasReturns) { + if (PyNames.INIT.equals(getName())) { + return PyNoneType.INSTANCE; } - } - } - } - } - return null; - } - - @Override - public String getDocStringValue() { - final PyFunctionStub stub = getStub(); - if (stub != null) { - return stub.getDocString(); - } - return DocStringUtil.getDocStringValue(this); - } - - @Nullable - @Override - public StructuredDocString getStructuredDocString() { - return LanguageCachedValueUtil.getCachedValue(this, myCachedStructuredDocStringProvider); - } - - private boolean isGeneratedStub() { - VirtualFile vFile = getContainingFile().getVirtualFile(); - if (vFile != null) { - vFile = vFile.getParent(); - if (vFile != null) { - vFile = vFile.getParent(); - if (vFile != null && vFile.getName().equals(PythonSdkType.SKELETON_DIR_NAME)) { - return true; - } - } - } - return false; - } - - @Nullable - private String extractReturnType() { - final String ARROW = "->"; - final StructuredDocString structuredDocString = getStructuredDocString(); - if (structuredDocString != null) { - return structuredDocString.getReturnType(); - } - final String docString = getDocStringValue(); - if (docString != null && docString.contains(ARROW)) { - final List lines = StringUtil.split(docString, "\n"); - while (lines.size() > 0 && lines.get(0).trim().length() == 0) { - lines.remove(0); - } - if (lines.size() > 1 && lines.get(1).trim().length() == 0) { - String firstLine = lines.get(0); - int pos = firstLine.lastIndexOf(ARROW); - if (pos >= 0) { - return firstLine.substring(pos + 2).trim(); - } - } - } - return null; - } - - private static class ReturnVisitor extends PyRecursiveElementVisitor { - private final PyFunction myFunction; - private final TypeEvalContext myContext; - private PyType myResult = null; - private boolean myHasReturns = false; - private boolean myHasRaises = false; - - public ReturnVisitor(PyFunction function, final TypeEvalContext context) { - myFunction = function; - myContext = context; - } - - @Override - public void visitPyReturnStatement(PyReturnStatement node) { - if (ScopeUtil.getScopeOwner(node) == myFunction) { - final PyExpression expr = node.getExpression(); - PyType returnType; - returnType = expr == null ? PyNoneType.INSTANCE : myContext.getType(expr); - if (!myHasReturns) { - myResult = returnType; - myHasReturns = true; + return null; + } + return visitor.result(); + } + + @Nullable + private PyType createCoroutineType(@Nullable PyType returnType) { + PyBuiltinCache cache = PyBuiltinCache.getInstance(this); + if (returnType instanceof PyClassLikeType classLikeType && PyNames.FAKE_COROUTINE.equals(classLikeType.getClassQName())) { + return classLikeType; + } + PyClass generator = cache.getClass(PyNames.FAKE_COROUTINE); + return generator != null ? new PyCollectionTypeImpl(generator, false, Collections.singletonList(returnType)) : null; + } + + @Override + public PyFunction asMethod() { + if (getContainingClass() != null) { + return this; } else { - myResult = PyUnionType.union(myResult, returnType); + return null; } - } } + @Nullable + @Override + public PyType getReturnTypeFromDocString() { + String typeName = extractReturnType(); + return typeName != null ? PyTypeParser.getTypeByName(this, typeName) : null; + } + + @Nullable @Override - public void visitPyRaiseStatement(PyRaiseStatement node) { - myHasRaises = true; + @RequiredReadAction + public String getDeprecationMessage() { + PyFunctionStub stub = getStub(); + if (stub != null) { + return stub.getDeprecationMessage(); + } + return extractDeprecationMessage(); } @Nullable - PyType result() { - return myHasReturns || myHasRaises ? myResult : PyNoneType.INSTANCE; - } - } - - @Override - protected void acceptPyVisitor(PyElementVisitor pyVisitor) { - pyVisitor.visitPyFunction(this); - } - - public int getTextOffset() { - final ASTNode name = getNameNode(); - return name != null ? name.getStartOffset() : super.getTextOffset(); - } - - public PyStringLiteralExpression getDocStringExpression() { - final PyStatementList stmtList = getStatementList(); - return DocStringUtil.findDocStringExpression(stmtList); - } - - @Override - public String toString() { - return super.toString() + "('" + getName() + "')"; - } - - public void subtreeChanged() { - super.subtreeChanged(); - ControlFlowCache.clear(this); - } - - public Property getProperty() { - final PyClass containingClass = getContainingClass(); - if (containingClass != null) { - return containingClass.findPropertyByCallable(this); - } - return null; - } - - @Override - public PyAnnotation getAnnotation() { - return getStubOrPsiChild(PyElementTypes.ANNOTATION); - } - - @Nullable - @Override - public PsiComment getTypeComment() { - final PsiComment inlineComment = PyUtil.getCommentOnHeaderLine(this); - if (inlineComment != null && PyTypingTypeProvider.getTypeCommentValue(inlineComment.getText()) != null) { - return inlineComment; - } - - final PyStatementList statements = getStatementList(); - if (statements.getStatements().length != 0) { - final PsiComment comment = as(statements.getFirstChild(), PsiComment.class); - if (comment != null && PyTypingTypeProvider.getTypeCommentValue(comment.getText()) != null) { - return comment; - } - } - return null; - } - - @Nullable - @Override - public String getTypeCommentAnnotation() { - final PyFunctionStub stub = getStub(); - if (stub != null) { - return stub.getTypeComment(); - } - final PsiComment comment = getTypeComment(); - if (comment != null) { - return PyTypingTypeProvider.getTypeCommentValue(comment.getText()); - } - return null; - } - - @Nonnull - @Override - public SearchScope getUseScope() { - final ScopeOwner scopeOwner = ScopeUtil.getScopeOwner(this); - if (scopeOwner instanceof PyFunction) { - return new LocalSearchScope(scopeOwner); - } - return super.getUseScope(); - } - - /** - * Looks for two standard decorators to a function, or a wrapping assignment that closely follows it. - * - * @return a flag describing what was detected. - */ - @Nullable - public Modifier getModifier() { - final String deconame = getClassOrStaticMethodDecorator(); - if (PyNames.CLASSMETHOD.equals(deconame)) { - return CLASSMETHOD; - } - else if (PyNames.STATICMETHOD.equals(deconame)) { - return STATICMETHOD; - } - - // implicit staticmethod __new__ - final PyClass cls = getContainingClass(); - if (cls != null && PyNames.NEW.equals(getName()) && cls.isNewStyleClass(null)) { - return STATICMETHOD; - } - - final PyFunctionStub stub = getStub(); - if (stub != null) { - return getModifierFromStub(stub); - } - - final String funcName = getName(); - if (funcName != null) { - PyAssignmentStatement currentAssignment = PsiTreeUtil.getNextSiblingOfType(this, PyAssignmentStatement.class); - while (currentAssignment != null) { - final String modifier = currentAssignment.getTargetsToValuesMapping() - .stream() - .filter(pair -> pair.getFirst() instanceof PyTargetExpression && funcName.equals(pair.getFirst() - .getName - ())) - .filter(pair -> pair.getSecond() instanceof PyCallExpression) - .map(pair -> interpretAsModifierWrappingCall((PyCallExpression)pair.getSecond(), - this)) - .filter(interpreted -> interpreted != null && interpreted.getSecond() == this) - .map(interpreted -> interpreted.getFirst()) - .filter(wrapperName -> PyNames.CLASSMETHOD - .equals(wrapperName) || PyNames.STATICMETHOD.equals(wrapperName)) - .findAny() - .orElse(null); - - if (PyNames.CLASSMETHOD.equals(modifier)) { - return CLASSMETHOD; - } - else if (PyNames.STATICMETHOD.equals(modifier)) { - return STATICMETHOD; - } - - currentAssignment = PsiTreeUtil.getNextSiblingOfType(currentAssignment, PyAssignmentStatement.class); - } - } - - return null; - } - - @Override - public boolean isAsync() { - final PyFunctionStub stub = getStub(); - if (stub != null) { - return stub.isAsync(); - } - return getNode().findChildByType(PyTokenTypes.ASYNC_KEYWORD) != null; - } - - @Override - public boolean isAsyncAllowed() { - final LanguageLevel languageLevel = LanguageLevel.forElement(this); - final String functionName = getName(); - - return languageLevel.isAtLeast(LanguageLevel.PYTHON35) && (functionName == null || - ArrayUtil.contains(functionName, PyNames.AITER, PyNames.ANEXT, PyNames.AENTER, PyNames.AEXIT) || - !PyNames.getBuiltinMethods(languageLevel).containsKey(functionName)); - } - - @Nullable - private static Modifier getModifierFromStub(@Nonnull PyFunctionStub stub) { - final Optional> siblingsStubsOptional = - Optional.of(stub).map(StubElement::getParentStub).map(StubElement::getChildrenStubs); - - if (siblingsStubsOptional.isPresent()) { - return JBIterable.from(siblingsStubsOptional.get()) - .skipWhile(siblingStub -> !stub.equals(siblingStub)) - .transform(nextSiblingStub -> as(nextSiblingStub, - PyTargetExpressionStub.class)) - .filter(Objects::nonNull) - .filter(nextSiblingStub -> nextSiblingStub.getInitializerType() == PyTargetExpressionStub.InitializerType.CallExpression) - .transform(PyTargetExpressionStub::getInitializer) - .transform(initializerName -> { - if (initializerName.matches(PyNames.CLASSMETHOD)) { - return CLASSMETHOD; - } - else if (initializerName.matches(PyNames.STATICMETHOD)) { - return STATICMETHOD; - } - else { - return null; - } - }) - .find(Objects::nonNull); - } - - return null; - } - - /** - * When a function is decorated many decorators, finds the deepest builtin decorator: - *
-   * @foo
-   * @classmethod # <-- that's it
-   * @bar
-   * def moo(cls):
-   *   pass
-   * 
- * - * @return name of the built-in decorator, or null (even if there are non-built-in decorators). - */ - @Nullable - private String getClassOrStaticMethodDecorator() { - PyDecoratorList decolist = getDecoratorList(); - if (decolist != null) { - PyDecorator[] decos = decolist.getDecorators(); - if (decos.length > 0) { - for (int i = decos.length - 1; i >= 0; i -= 1) { - PyDecorator deco = decos[i]; - String deconame = deco.getName(); - if (PyNames.CLASSMETHOD.equals(deconame) || PyNames.STATICMETHOD.equals(deconame)) { - return deconame; - } - for (PyKnownDecoratorProvider provider : Application.get().getExtensionPoint(PyKnownDecoratorProvider.class)) { - String name = provider.toKnownDecorator(deconame); - if (name != null) { - return name; + @RequiredReadAction + public String extractDeprecationMessage() { + PyStatementList statementList = getStatementList(); + return extractDeprecationMessage(Arrays.asList(statementList.getStatements())); + } + + @Override + public PyType getType(@Nonnull TypeEvalContext context, @Nonnull TypeEvalContext.Key key) { + PyType callableType = Application.get().getExtensionPoint(PyTypeProvider.class) + .computeSafeIfAny(provider -> provider.getCallableType(this, context)); + if (callableType != null) { + return callableType; + } + PyFunctionTypeImpl type = new PyFunctionTypeImpl(this); + if (PyKnownDecoratorUtil.hasUnknownDecorator(this, context) && getProperty() == null) { + return PyUnionType.createWeakType(type); + } + return type; + } + + private static final Set DEPRECATION_WARNINGS = Set.of(PyNames.DEPRECATION_WARNING, PyNames.PENDING_DEPRECATION_WARNING); + @Nullable + public static String extractDeprecationMessage(List statements) { + for (PyStatement statement : statements) { + if (statement instanceof PyExpressionStatement exprStmt + && exprStmt.getExpression() instanceof PyCallExpression callExpr + && callExpr.isCalleeText(PyNames.WARN)) { + PyReferenceExpression warningClass = callExpr.getArgument(1, PyReferenceExpression.class); + if (warningClass != null && DEPRECATION_WARNINGS.contains(warningClass.getReferencedName())) { + return PyPsiUtils.strValue(callExpr.getArguments()[0]); + } + } + } + return null; + } + + @Override + public String getDocStringValue() { + PyFunctionStub stub = getStub(); + if (stub != null) { + return stub.getDocString(); + } + return DocStringUtil.getDocStringValue(this); + } + + @Nullable + @Override + public StructuredDocString getStructuredDocString() { + return LanguageCachedValueUtil.getCachedValue(this, myCachedStructuredDocStringProvider); + } + + private boolean isGeneratedStub() { + VirtualFile vFile = getContainingFile().getVirtualFile(); + if (vFile != null) { + vFile = vFile.getParent(); + if (vFile != null) { + vFile = vFile.getParent(); + if (vFile != null && vFile.getName().equals(PythonSdkType.SKELETON_DIR_NAME)) { + return true; + } + } + } + return false; + } + + @Nullable + private String extractReturnType() { + String ARROW = "->"; + StructuredDocString structuredDocString = getStructuredDocString(); + if (structuredDocString != null) { + return structuredDocString.getReturnType(); + } + String docString = getDocStringValue(); + if (docString != null && docString.contains(ARROW)) { + List lines = StringUtil.split(docString, "\n"); + while (lines.size() > 0 && lines.get(0).trim().length() == 0) { + lines.remove(0); + } + if (lines.size() > 1 && lines.get(1).trim().length() == 0) { + String firstLine = lines.get(0); + int pos = firstLine.lastIndexOf(ARROW); + if (pos >= 0) { + return firstLine.substring(pos + 2).trim(); + } } - } } - } + return null; + } + + private static class ReturnVisitor extends PyRecursiveElementVisitor { + private final PyFunction myFunction; + private final TypeEvalContext myContext; + private PyType myResult = null; + private boolean myHasReturns = false; + private boolean myHasRaises = false; + + public ReturnVisitor(PyFunction function, TypeEvalContext context) { + myFunction = function; + myContext = context; + } + + @Override + public void visitPyReturnStatement(PyReturnStatement node) { + if (ScopeUtil.getScopeOwner(node) == myFunction) { + PyExpression expr = node.getExpression(); + PyType returnType; + returnType = expr == null ? PyNoneType.INSTANCE : myContext.getType(expr); + if (!myHasReturns) { + myResult = returnType; + myHasReturns = true; + } + else { + myResult = PyUnionType.union(myResult, returnType); + } + } + } + + @Override + public void visitPyRaiseStatement(PyRaiseStatement node) { + myHasRaises = true; + } + + @Nullable + PyType result() { + return myHasReturns || myHasRaises ? myResult : PyNoneType.INSTANCE; + } + } + + @Override + protected void acceptPyVisitor(PyElementVisitor pyVisitor) { + pyVisitor.visitPyFunction(this); + } + + @Override + @RequiredReadAction + public int getTextOffset() { + ASTNode name = getNameNode(); + return name != null ? name.getStartOffset() : super.getTextOffset(); + } + + @Override + @RequiredReadAction + public PyStringLiteralExpression getDocStringExpression() { + PyStatementList stmtList = getStatementList(); + return DocStringUtil.findDocStringExpression(stmtList); + } + + @Override + @RequiredReadAction + public String toString() { + return super.toString() + "('" + getName() + "')"; + } + + @Override + public void subtreeChanged() { + super.subtreeChanged(); + ControlFlowCache.clear(this); + } + + @Override + public Property getProperty() { + PyClass containingClass = getContainingClass(); + if (containingClass != null) { + return containingClass.findPropertyByCallable(this); + } + return null; + } + + @Override + @RequiredReadAction + public PyAnnotation getAnnotation() { + return getStubOrPsiChild(PyElementTypes.ANNOTATION); } - return null; - } - @Nullable - @Override - public String getQualifiedName() { - return QualifiedNameFinder.getQualifiedName(this); - } + @Nullable + @Override + @RequiredReadAction + public PsiComment getTypeComment() { + PsiComment inlineComment = PyUtil.getCommentOnHeaderLine(this); + if (inlineComment != null && PyTypingTypeProvider.getTypeCommentValue(inlineComment.getText()) != null) { + return inlineComment; + } + + PyStatementList statements = getStatementList(); + if (statements.getStatements().length != 0) { + PsiComment comment = as(statements.getFirstChild(), PsiComment.class); + if (comment != null && PyTypingTypeProvider.getTypeCommentValue(comment.getText()) != null) { + return comment; + } + } + return null; + } + + @Nullable + @Override + @RequiredReadAction + public String getTypeCommentAnnotation() { + PyFunctionStub stub = getStub(); + if (stub != null) { + return stub.getTypeComment(); + } + PsiComment comment = getTypeComment(); + if (comment != null) { + return PyTypingTypeProvider.getTypeCommentValue(comment.getText()); + } + return null; + } + + @Nonnull + @Override + public SearchScope getUseScope() { + if (ScopeUtil.getScopeOwner(this) instanceof PyFunction function) { + return new LocalSearchScope(function); + } + return super.getUseScope(); + } - @Nonnull - @Override - public List findAttributes() { /** - * TODO: This method if insanely heavy since it unstubs foreign files. - * Need to save stubs and use them somehow. + * Looks for two standard decorators to a function, or a wrapping assignment that closely follows it. * + * @return a flag describing what was detected. */ - return CachedValuesManager.getManager(getProject()).getCachedValue(this, ATTRIBUTES_KEY, () -> { - final List result = findAttributesStatic(this); - return CachedValueProvider.Result.create(result, PsiModificationTracker.MODIFICATION_COUNT); - }, false); - } - - /** - * @param self should be this - */ - @Nonnull - private static List findAttributesStatic(@Nonnull final PsiElement self) { - final List result = new ArrayList<>(); - for (final PyAssignmentStatement statement : new PsiQuery(self).siblings(PyAssignmentStatement.class).getElements()) { - for (final PyQualifiedExpression targetExpression : new PsiQuery(statement.getTargets()).filter(PyQualifiedExpression.class) - .getElements()) { - final PyExpression qualifier = targetExpression.getQualifier(); - if (qualifier == null) { - continue; - } - final PsiReference qualifierReference = qualifier.getReference(); - if (qualifierReference == null) { - continue; - } - if (qualifierReference.isReferenceTo(self)) { - result.add(statement); - } - } - } - return result; - } - - @Nonnull - @Override - public ProtectionLevel getProtectionLevel() { - final int underscoreLevels = PyUtil.getInitialUnderscores(getName()); - for (final ProtectionLevel level : ProtectionLevel.values()) { - if (level.getUnderscoreLevel() == underscoreLevels) { - return level; - } - } - return ProtectionLevel.PRIVATE; - } + @Nullable + @Override + @RequiredReadAction + public Modifier getModifier() { + String decoName = getClassOrStaticMethodDecorator(); + if (PyNames.CLASSMETHOD.equals(decoName)) { + return CLASSMETHOD; + } + else if (PyNames.STATICMETHOD.equals(decoName)) { + return STATICMETHOD; + } + + // implicit staticmethod __new__ + PyClass cls = getContainingClass(); + if (cls != null && PyNames.NEW.equals(getName()) && cls.isNewStyleClass(null)) { + return STATICMETHOD; + } + + PyFunctionStub stub = getStub(); + if (stub != null) { + return getModifierFromStub(stub); + } + + String funcName = getName(); + if (funcName != null) { + PyAssignmentStatement currentAssignment = PsiTreeUtil.getNextSiblingOfType(this, PyAssignmentStatement.class); + while (currentAssignment != null) { + String modifier = currentAssignment.getTargetsToValuesMapping() + .stream() + .filter( + pair -> pair.getFirst() instanceof PyTargetExpression targetExpr + && funcName.equals(targetExpr.getName()) + ) + .filter(pair -> pair.getSecond() instanceof PyCallExpression) + .map(pair -> interpretAsModifierWrappingCall( + (PyCallExpression) pair.getSecond(), + this + )) + .filter(interpreted -> interpreted != null && interpreted.getSecond() == this) + .map(interpreted -> interpreted.getFirst()) + .filter(wrapperName -> PyNames.CLASSMETHOD.equals(wrapperName) || PyNames.STATICMETHOD.equals(wrapperName)) + .findAny() + .orElse(null); + + if (PyNames.CLASSMETHOD.equals(modifier)) { + return CLASSMETHOD; + } + else if (PyNames.STATICMETHOD.equals(modifier)) { + return STATICMETHOD; + } + + currentAssignment = PsiTreeUtil.getNextSiblingOfType(currentAssignment, PyAssignmentStatement.class); + } + } + + return null; + } + + @Override + @RequiredReadAction + public boolean isAsync() { + PyFunctionStub stub = getStub(); + if (stub != null) { + return stub.isAsync(); + } + return getNode().findChildByType(PyTokenTypes.ASYNC_KEYWORD) != null; + } + + @Override + @RequiredReadAction + public boolean isAsyncAllowed() { + LanguageLevel languageLevel = LanguageLevel.forElement(this); + String functionName = getName(); + + return languageLevel.isAtLeast(LanguageLevel.PYTHON35) + && (functionName == null + || ArrayUtil.contains(functionName, PyNames.AITER, PyNames.ANEXT, PyNames.AENTER, PyNames.AEXIT) + || !PyNames.getBuiltinMethods(languageLevel).containsKey(functionName)); + } + + @Nullable + private static Modifier getModifierFromStub(@Nonnull PyFunctionStub stub) { + Optional> siblingsStubsOptional = Optional.of(stub) + .map(StubElement::getParentStub) + .map(StubElement::getChildrenStubs); + + if (siblingsStubsOptional.isPresent()) { + return JBIterable.from(siblingsStubsOptional.get()) + .skipWhile(siblingStub -> !stub.equals(siblingStub)) + .transform(nextSiblingStub -> as( + nextSiblingStub, + PyTargetExpressionStub.class + )) + .filter(Objects::nonNull) + .filter(nextSiblingStub -> nextSiblingStub.getInitializerType() == PyTargetExpressionStub.InitializerType.CallExpression) + .transform(PyTargetExpressionStub::getInitializer) + .transform(initializerName -> { + if (initializerName.matches(PyNames.CLASSMETHOD)) { + return CLASSMETHOD; + } + else if (initializerName.matches(PyNames.STATICMETHOD)) { + return STATICMETHOD; + } + else { + return null; + } + }) + .find(Objects::nonNull); + } + + return null; + } + + /** + * When a function is decorated many decorators, finds the deepest builtin decorator: + *
+     * @foo
+     * @classmethod # <-- that's it
+     * @bar
+     * def moo(cls):
+     *   pass
+     * 
+ * + * @return name of the built-in decorator, or null (even if there are non-built-in decorators). + */ + @Nullable + @RequiredReadAction + private String getClassOrStaticMethodDecorator() { + PyDecoratorList decoList = getDecoratorList(); + if (decoList != null) { + PyDecorator[] decos = decoList.getDecorators(); + if (decos.length > 0) { + ExtensionPoint knownDecoratorProviders = + decoList.getApplication().getExtensionPoint(PyKnownDecoratorProvider.class); + for (int i = decos.length - 1; i >= 0; i -= 1) { + PyDecorator deco = decos[i]; + String decoName = deco.getName(); + if (PyNames.CLASSMETHOD.equals(decoName) || PyNames.STATICMETHOD.equals(decoName)) { + return decoName; + } + String name = knownDecoratorProviders.computeSafeIfAny(provider -> provider.toKnownDecorator(decoName)); + if (name != null) { + return name; + } + } + } + } + return null; + } + + @Nullable + @Override + public String getQualifiedName() { + return QualifiedNameFinder.getQualifiedName(this); + } + + @Nonnull + @Override + public List findAttributes() { + /** + * TODO: This method if insanely heavy since it unstubs foreign files. + * Need to save stubs and use them somehow. + * + */ + return CachedValuesManager.getManager(getProject()).getCachedValue( + this, + ATTRIBUTES_KEY, + () -> { + List result = findAttributesStatic(this); + return CachedValueProvider.Result.create(result, PsiModificationTracker.MODIFICATION_COUNT); + }, + false + ); + } + + /** + * @param self should be this + */ + @Nonnull + @RequiredReadAction + private static List findAttributesStatic(@Nonnull PsiElement self) { + List result = new ArrayList<>(); + for (PyAssignmentStatement statement : new PsiQuery(self).siblings(PyAssignmentStatement.class).getElements()) { + List elements = new PsiQuery(statement.getTargets()).filter(PyQualifiedExpression.class).getElements(); + for (PyQualifiedExpression targetExpression : elements) { + PyExpression qualifier = targetExpression.getQualifier(); + if (qualifier == null) { + continue; + } + PsiReference qualifierReference = qualifier.getReference(); + if (qualifierReference == null) { + continue; + } + if (qualifierReference.isReferenceTo(self)) { + result.add(statement); + } + } + } + return result; + } + + @Nonnull + @Override + @RequiredReadAction + public ProtectionLevel getProtectionLevel() { + int underscoreLevels = PyUtil.getInitialUnderscores(getName()); + for (ProtectionLevel level : ProtectionLevel.values()) { + if (level.getUnderscoreLevel() == underscoreLevels) { + return level; + } + } + return ProtectionLevel.PRIVATE; + } } diff --git a/python-impl/src/main/java/com/jetbrains/python/impl/psi/impl/PyStringLiteralExpressionImpl.java b/python-impl/src/main/java/com/jetbrains/python/impl/psi/impl/PyStringLiteralExpressionImpl.java index d972b534..8b7b6c26 100644 --- a/python-impl/src/main/java/com/jetbrains/python/impl/psi/impl/PyStringLiteralExpressionImpl.java +++ b/python-impl/src/main/java/com/jetbrains/python/impl/psi/impl/PyStringLiteralExpressionImpl.java @@ -22,7 +22,7 @@ import com.jetbrains.python.psi.*; import com.jetbrains.python.psi.types.PyType; import com.jetbrains.python.psi.types.TypeEvalContext; -import consulo.application.AllIcons; +import consulo.annotation.access.RequiredReadAction; import consulo.document.util.TextRange; import consulo.language.Language; import consulo.language.ast.ASTNode; @@ -31,478 +31,491 @@ import consulo.language.psi.*; import consulo.language.psi.util.PsiTreeUtil; import consulo.navigation.ItemPresentation; +import consulo.platform.base.icon.PlatformIconGroup; import consulo.ui.image.Image; import consulo.util.lang.Pair; +import jakarta.annotation.Nonnull; +import jakarta.annotation.Nullable; import org.intellij.lang.regexp.DefaultRegExpPropertiesProvider; import org.intellij.lang.regexp.RegExpLanguageHost; import org.intellij.lang.regexp.psi.RegExpChar; import org.intellij.lang.regexp.psi.RegExpGroup; import org.intellij.lang.regexp.psi.RegExpNamedGroupRef; -import jakarta.annotation.Nonnull; -import jakarta.annotation.Nullable; import java.util.*; import java.util.regex.Matcher; import java.util.regex.Pattern; public class PyStringLiteralExpressionImpl extends PyElementImpl implements PyStringLiteralExpression, RegExpLanguageHost { - public static final Pattern PATTERN_ESCAPE = - Pattern.compile("\\\\(\n|\\\\|'|\"|a|b|f|n|r|t|v|([0-7]{1,3})|x([0-9a-fA-F]{1,2})" + "|N(\\{.*?\\})|u([0-9a-fA-F]{4})|U([0-9a-fA-F]{8}))"); - // -> 1 -> 2 <--> 3 <- -> 4 <--> 5 <- -> 6 <-<- - - private enum EscapeRegexGroup { - WHOLE_MATCH, - ESCAPED_SUBSTRING, - OCTAL, - HEXADECIMAL, - UNICODE_NAMED, - UNICODE_16BIT, - UNICODE_32BIT - } - - private static final Map escapeMap = initializeEscapeMap(); - private String stringValue; - private List valueTextRanges; - @Nullable - private List> myDecodedFragments; - private final DefaultRegExpPropertiesProvider myPropertiesProvider; - - private static Map initializeEscapeMap() { - Map map = new HashMap<>(); - map.put("\n", "\n"); - map.put("\\", "\\"); - map.put("'", "'"); - map.put("\"", "\""); - map.put("a", "\001"); - map.put("b", "\b"); - map.put("f", "\f"); - map.put("n", "\n"); - map.put("r", "\r"); - map.put("t", "\t"); - map.put("v", "\013"); - return map; - } - - public PyStringLiteralExpressionImpl(ASTNode astNode) { - super(astNode); - myPropertiesProvider = DefaultRegExpPropertiesProvider.getInstance(); - } - - @Override - protected void acceptPyVisitor(PyElementVisitor pyVisitor) { - pyVisitor.visitPyStringLiteralExpression(this); - } - - @Override - public void subtreeChanged() { - super.subtreeChanged(); - stringValue = null; - valueTextRanges = null; - myDecodedFragments = null; - } - - @Override - @Nonnull - public List getStringValueTextRanges() { - if (valueTextRanges == null) { - int elStart = getTextRange().getStartOffset(); - List ranges = new ArrayList<>(); - for (ASTNode node : getStringNodes()) { - TextRange range = getNodeTextRange(node.getText()); - int nodeOffset = node.getStartOffset() - elStart; - ranges.add(TextRange.from(nodeOffset + range.getStartOffset(), range.getLength())); - } - valueTextRanges = Collections.unmodifiableList(ranges); - } - return valueTextRanges; - } - - public static TextRange getNodeTextRange(final String text) { - int startOffset = getPrefixLength(text); - int delimiterLength = 1; - final String afterPrefix = text.substring(startOffset); - if (afterPrefix.startsWith("\"\"\"") || afterPrefix.startsWith("'''")) { - delimiterLength = 3; - } - final String delimiter = text.substring(startOffset, startOffset + delimiterLength); - startOffset += delimiterLength; - int endOffset = text.length(); - if (text.substring(startOffset).endsWith(delimiter)) { - endOffset -= delimiterLength; - } - return new TextRange(startOffset, endOffset); - } - - public static int getPrefixLength(String text) { - return PyStringLiteralUtil.getPrefixEndOffset(text, 0); - } - - private boolean isUnicodeByDefault() { - if (LanguageLevel.forElement(this).isAtLeast(LanguageLevel.PYTHON30)) { - return true; - } - final PsiFile file = getContainingFile(); - if (file instanceof PyFile) { - final PyFile pyFile = (PyFile)file; - return pyFile.hasImportFromFuture(FutureFeature.UNICODE_LITERALS); - } - return false; - } - - @Override - @Nonnull - public List> getDecodedFragments() { - if (myDecodedFragments == null) { - final List> result = new ArrayList<>(); - final int elementStart = getTextRange().getStartOffset(); - final boolean unicodeByDefault = isUnicodeByDefault(); - for (ASTNode node : getStringNodes()) { - final String text = node.getText(); - final TextRange textRange = getNodeTextRange(text); - final int offset = node.getTextRange().getStartOffset() - elementStart + textRange.getStartOffset(); - final String encoded = textRange.substring(text); - final boolean hasRawPrefix = PyStringLiteralUtil.isRawPrefix(PyStringLiteralUtil.getPrefix(text)); - final boolean hasUnicodePrefix = PyStringLiteralUtil.isUnicodePrefix(PyStringLiteralUtil.getPrefix(text)); - result.addAll(getDecodedFragments(encoded, offset, hasRawPrefix, unicodeByDefault || hasUnicodePrefix)); - } - myDecodedFragments = result; - } - return myDecodedFragments; - } - - @Override - public boolean isDocString() { - final List stringNodes = getStringNodes(); - return stringNodes.size() == 1 && stringNodes.get(0).getElementType() == PyTokenTypes.DOCSTRING; - } - - @Nonnull - private static List> getDecodedFragments(@Nonnull String encoded, int offset, boolean raw, boolean unicode) { - final List> result = new ArrayList<>(); - final Matcher escMatcher = PATTERN_ESCAPE.matcher(encoded); - int index = 0; - while (escMatcher.find(index)) { - if (index < escMatcher.start()) { - final TextRange range = TextRange.create(index, escMatcher.start()); - final TextRange offsetRange = range.shiftRight(offset); - result.add(Pair.create(offsetRange, range.substring(encoded))); - } - - final String octal = escapeRegexGroup(escMatcher, EscapeRegexGroup.OCTAL); - final String hex = escapeRegexGroup(escMatcher, EscapeRegexGroup.HEXADECIMAL); - // TODO: Implement unicode character name escapes: EscapeRegexGroup.UNICODE_NAMED - final String unicode16 = escapeRegexGroup(escMatcher, EscapeRegexGroup.UNICODE_16BIT); - final String unicode32 = escapeRegexGroup(escMatcher, EscapeRegexGroup.UNICODE_32BIT); - final String wholeMatch = escapeRegexGroup(escMatcher, EscapeRegexGroup.WHOLE_MATCH); - - final boolean escapedUnicode = raw && unicode || !raw; - - final String str; - if (!raw && octal != null) { - str = new String(new char[]{(char)Integer.parseInt(octal, 8)}); - } - else if (!raw && hex != null) { - str = new String(new char[]{(char)Integer.parseInt(hex, 16)}); - } - else if (escapedUnicode && unicode16 != null) { - str = unicode ? new String(new char[]{(char)Integer.parseInt(unicode16, 16)}) : wholeMatch; - } - else if (escapedUnicode && unicode32 != null) { - String s = wholeMatch; - if (unicode) { - try { - s = new String(Character.toChars((int)Long.parseLong(unicode32, 16))); - } - catch (IllegalArgumentException ignored) { - } + public static final Pattern PATTERN_ESCAPE = + Pattern.compile("\\\\(\n|\\\\|'|\"|a|b|f|n|r|t|v|([0-7]{1,3})|x([0-9a-fA-F]{1,2})" + "|N(\\{.*?\\})|u([0-9a-fA-F]{4})|U([0-9a-fA-F]{8}))"); + // -> 1 -> 2 <--> 3 <- -> 4 <--> 5 <- -> 6 <-<- + + private enum EscapeRegexGroup { + WHOLE_MATCH, + ESCAPED_SUBSTRING, + OCTAL, + HEXADECIMAL, + UNICODE_NAMED, + UNICODE_16BIT, + UNICODE_32BIT + } + + private static final Map escapeMap = initializeEscapeMap(); + private String stringValue; + private List valueTextRanges; + @Nullable + private List> myDecodedFragments; + private final DefaultRegExpPropertiesProvider myPropertiesProvider; + + private static Map initializeEscapeMap() { + Map map = new HashMap<>(); + map.put("\n", "\n"); + map.put("\\", "\\"); + map.put("'", "'"); + map.put("\"", "\""); + map.put("a", "\001"); + map.put("b", "\b"); + map.put("f", "\f"); + map.put("n", "\n"); + map.put("r", "\r"); + map.put("t", "\t"); + map.put("v", "\013"); + return map; + } + + public PyStringLiteralExpressionImpl(ASTNode astNode) { + super(astNode); + myPropertiesProvider = DefaultRegExpPropertiesProvider.getInstance(); + } + + @Override + protected void acceptPyVisitor(PyElementVisitor pyVisitor) { + pyVisitor.visitPyStringLiteralExpression(this); + } + + @Override + public void subtreeChanged() { + super.subtreeChanged(); + stringValue = null; + valueTextRanges = null; + myDecodedFragments = null; + } + + @Nonnull + @Override + @RequiredReadAction + public List getStringValueTextRanges() { + if (valueTextRanges == null) { + int elStart = getTextRange().getStartOffset(); + List ranges = new ArrayList<>(); + for (ASTNode node : getStringNodes()) { + TextRange range = getNodeTextRange(node.getText()); + int nodeOffset = node.getStartOffset() - elStart; + ranges.add(TextRange.from(nodeOffset + range.getStartOffset(), range.getLength())); + } + valueTextRanges = Collections.unmodifiableList(ranges); } - str = s; - } - else if (raw) { - str = wholeMatch; - } - else { - final String toReplace = escapeRegexGroup(escMatcher, EscapeRegexGroup.ESCAPED_SUBSTRING); - str = escapeMap.get(toReplace); - } - - if (str != null) { - final TextRange wholeMatchRange = TextRange.create(escMatcher.start(), escMatcher.end()); - result.add(Pair.create(wholeMatchRange.shiftRight(offset), str)); - } - - index = escMatcher.end(); - } - final TextRange range = TextRange.create(index, encoded.length()); - final TextRange offRange = range.shiftRight(offset); - result.add(Pair.create(offRange, range.substring(encoded))); - return result; - } - - @Nullable - private static String escapeRegexGroup(@Nonnull Matcher matcher, EscapeRegexGroup group) { - return matcher.group(group.ordinal()); - } - - @Override - @Nonnull - public List getStringNodes() { - return Arrays.asList(getNode().getChildren(PyTokenTypes.STRING_NODES)); - } - - @Override - public String getStringValue() { - //ASTNode child = getNode().getFirstChildNode(); - //assert child != null; - if (stringValue == null) { - final StringBuilder out = new StringBuilder(); - for (Pair fragment : getDecodedFragments()) { - out.append(fragment.getSecond()); - } - stringValue = out.toString(); - } - return stringValue; - } - - @Override - public TextRange getStringValueTextRange() { - List allRanges = getStringValueTextRanges(); - if (allRanges.size() == 1) { - return allRanges.get(0); - } - if (allRanges.size() > 1) { - return allRanges.get(0).union(allRanges.get(allRanges.size() - 1)); - } - return new TextRange(0, getTextLength()); - } - - @Override - public String toString() { - return super.toString() + ": " + getStringValue(); - } - - @Override - public boolean isValidHost() { - return true; - } - - @Override - public PyType getType(@Nonnull TypeEvalContext context, @Nonnull TypeEvalContext.Key key) { - final List nodes = getStringNodes(); - if (nodes.size() > 0) { - String text = getStringNodes().get(0).getText(); - - PyFile file = PsiTreeUtil.getParentOfType(this, PyFile.class); - if (file != null) { - IElementType type = PythonHighlightingLexer.convertStringType(getStringNodes().get(0).getElementType(), - text, - LanguageLevel.forElement(this), - file.hasImportFromFuture(FutureFeature - .UNICODE_LITERALS)); - if (PyTokenTypes.UNICODE_NODES.contains(type)) { - return PyBuiltinCache.getInstance(this).getUnicodeType(LanguageLevel.forElement(this)); + return valueTextRanges; + } + + public static TextRange getNodeTextRange(String text) { + int startOffset = getPrefixLength(text); + int delimiterLength = 1; + String afterPrefix = text.substring(startOffset); + if (afterPrefix.startsWith("\"\"\"") || afterPrefix.startsWith("'''")) { + delimiterLength = 3; + } + String delimiter = text.substring(startOffset, startOffset + delimiterLength); + startOffset += delimiterLength; + int endOffset = text.length(); + if (text.substring(startOffset).endsWith(delimiter)) { + endOffset -= delimiterLength; } - } - } - return PyBuiltinCache.getInstance(this).getBytesType(LanguageLevel.forElement(this)); - } - - @Override - @Nonnull - public PsiReference[] getReferences() { - return ReferenceProvidersRegistry.getReferencesFromProviders(this, PsiReferenceService.Hints.NO_HINTS); - } - - @Override - public ItemPresentation getPresentation() { - return new ItemPresentation() { - @Nullable - @Override - public String getPresentableText() { - return getStringValue(); - } - - @Nullable - @Override - public String getLocationString() { - return "(" + PyElementPresentation.getPackageForFile(getContainingFile()) + ")"; - } - - @Nullable - @Override - public Image getIcon() { - return AllIcons.Nodes.Variable; - } - }; - } - - @Override - public PsiLanguageInjectionHost updateText(@Nonnull String text) { - return ElementManipulators.handleContentChange(this, text); - } - - @Override - @Nonnull - public LiteralTextEscaper createLiteralTextEscaper() { - return new StringLiteralTextEscaper(this); - } - - private static class StringLiteralTextEscaper extends LiteralTextEscaper { - private final PyStringLiteralExpressionImpl myHost; - - protected StringLiteralTextEscaper(@Nonnull PyStringLiteralExpressionImpl host) { - super(host); - myHost = host; + return new TextRange(startOffset, endOffset); + } + + public static int getPrefixLength(String text) { + return PyStringLiteralUtil.getPrefixEndOffset(text, 0); } + @SuppressWarnings("SimplifiableIfStatement") + private boolean isUnicodeByDefault() { + if (LanguageLevel.forElement(this).isAtLeast(LanguageLevel.PYTHON30)) { + return true; + } + return getContainingFile() instanceof PyFile pyFile && pyFile.hasImportFromFuture(FutureFeature.UNICODE_LITERALS); + } + + @Nonnull @Override - public boolean decode(@Nonnull final TextRange rangeInsideHost, @Nonnull final StringBuilder outChars) { - for (Pair fragment : myHost.getDecodedFragments()) { - final TextRange encodedTextRange = fragment.getFirst(); - final TextRange intersection = encodedTextRange.intersection(rangeInsideHost); - if (intersection != null && !intersection.isEmpty()) { - final String value = fragment.getSecond(); - final String intersectedValue; - if (value.length() == 1 || value.length() == intersection.getLength()) { - intersectedValue = value; - } - else { - final int start = Math.max(0, rangeInsideHost.getStartOffset() - encodedTextRange.getStartOffset()); - final int end = Math.min(value.length(), start + intersection.getLength()); - intersectedValue = value.substring(start, end); - } - outChars.append(intersectedValue); + @RequiredReadAction + public List> getDecodedFragments() { + if (myDecodedFragments == null) { + List> result = new ArrayList<>(); + int elementStart = getTextRange().getStartOffset(); + boolean unicodeByDefault = isUnicodeByDefault(); + for (ASTNode node : getStringNodes()) { + String text = node.getText(); + TextRange textRange = getNodeTextRange(text); + int offset = node.getTextRange().getStartOffset() - elementStart + textRange.getStartOffset(); + String encoded = textRange.substring(text); + boolean hasRawPrefix = PyStringLiteralUtil.isRawPrefix(PyStringLiteralUtil.getPrefix(text)); + boolean hasUnicodePrefix = PyStringLiteralUtil.isUnicodePrefix(PyStringLiteralUtil.getPrefix(text)); + result.addAll(getDecodedFragments(encoded, offset, hasRawPrefix, unicodeByDefault || hasUnicodePrefix)); + } + myDecodedFragments = result; } - } - return true; + return myDecodedFragments; } @Override - public int getOffsetInHost(final int offsetInDecoded, @Nonnull final TextRange rangeInsideHost) { - int offset = 0; - int endOffset = -1; - for (Pair fragment : myHost.getDecodedFragments()) { - final TextRange encodedTextRange = fragment.getFirst(); - final TextRange intersection = encodedTextRange.intersection(rangeInsideHost); - if (intersection != null && !intersection.isEmpty()) { - final String value = fragment.getSecond(); - final int valueLength = value.length(); - final int intersectionLength = intersection.getLength(); - if (valueLength == 0) { - return -1; - } - else if (valueLength == 1) { - if (offset == offsetInDecoded) { - return intersection.getStartOffset(); + @RequiredReadAction + public boolean isDocString() { + List stringNodes = getStringNodes(); + return stringNodes.size() == 1 && stringNodes.get(0).getElementType() == PyTokenTypes.DOCSTRING; + } + + @Nonnull + private static List> getDecodedFragments(@Nonnull String encoded, int offset, boolean raw, boolean unicode) { + List> result = new ArrayList<>(); + Matcher escMatcher = PATTERN_ESCAPE.matcher(encoded); + int index = 0; + while (escMatcher.find(index)) { + if (index < escMatcher.start()) { + TextRange range = TextRange.create(index, escMatcher.start()); + TextRange offsetRange = range.shiftRight(offset); + result.add(Pair.create(offsetRange, range.substring(encoded))); + } + + String octal = escapeRegexGroup(escMatcher, EscapeRegexGroup.OCTAL); + String hex = escapeRegexGroup(escMatcher, EscapeRegexGroup.HEXADECIMAL); + // TODO: Implement unicode character name escapes: EscapeRegexGroup.UNICODE_NAMED + String unicode16 = escapeRegexGroup(escMatcher, EscapeRegexGroup.UNICODE_16BIT); + String unicode32 = escapeRegexGroup(escMatcher, EscapeRegexGroup.UNICODE_32BIT); + String wholeMatch = escapeRegexGroup(escMatcher, EscapeRegexGroup.WHOLE_MATCH); + + boolean escapedUnicode = raw && unicode || !raw; + + String str; + if (!raw && octal != null) { + str = new String(new char[]{(char) Integer.parseInt(octal, 8)}); + } + else if (!raw && hex != null) { + str = new String(new char[]{(char) Integer.parseInt(hex, 16)}); } - offset++; - } - else { - if (offset + intersectionLength >= offsetInDecoded) { - final int delta = offsetInDecoded - offset; - return intersection.getStartOffset() + delta; + else if (escapedUnicode && unicode16 != null) { + str = unicode ? new String(new char[]{(char) Integer.parseInt(unicode16, 16)}) : wholeMatch; } - offset += intersectionLength; - } - endOffset = intersection.getEndOffset(); + else if (escapedUnicode && unicode32 != null) { + String s = wholeMatch; + if (unicode) { + try { + s = new String(Character.toChars((int) Long.parseLong(unicode32, 16))); + } + catch (IllegalArgumentException ignored) { + } + } + str = s; + } + else if (raw) { + str = wholeMatch; + } + else { + String toReplace = escapeRegexGroup(escMatcher, EscapeRegexGroup.ESCAPED_SUBSTRING); + str = escapeMap.get(toReplace); + } + + if (str != null) { + TextRange wholeMatchRange = TextRange.create(escMatcher.start(), escMatcher.end()); + result.add(Pair.create(wholeMatchRange.shiftRight(offset), str)); + } + + index = escMatcher.end(); + } + TextRange range = TextRange.create(index, encoded.length()); + TextRange offRange = range.shiftRight(offset); + result.add(Pair.create(offRange, range.substring(encoded))); + return result; + } + + @Nullable + private static String escapeRegexGroup(@Nonnull Matcher matcher, EscapeRegexGroup group) { + return matcher.group(group.ordinal()); + } + + @Nonnull + @Override + @RequiredReadAction + public List getStringNodes() { + return Arrays.asList(getNode().getChildren(PyTokenTypes.STRING_NODES)); + } + + @Override + @RequiredReadAction + public String getStringValue() { + //ASTNode child = getNode().getFirstChildNode(); + //assert child != null; + if (stringValue == null) { + StringBuilder out = new StringBuilder(); + for (Pair fragment : getDecodedFragments()) { + out.append(fragment.getSecond()); + } + stringValue = out.toString(); } - } - // XXX: According to the real use of getOffsetInHost() it should return the correct host offset for the offset in decoded at the - // end of the range inside host, not -1 - if (offset == offsetInDecoded) { - return endOffset; - } - return -1; + return stringValue; } @Override - public boolean isOneLine() { - return false; - } - } - - @Override - public int valueOffsetToTextOffset(int valueOffset) { - return createLiteralTextEscaper().getOffsetInHost(valueOffset, getStringValueTextRange()); - } - - @Nonnull - @Override - public Class getHostClass() { - return getClass(); - } - - @Override - public boolean characterNeedsEscaping(char c) { - if (c == '#') { - return isVerboseInjection(); - } - return c == ']' || c == '}' || c == '\"' || c == '\''; - } - - private boolean isVerboseInjection() { - List> files = InjectedLanguageManager.getInstance(getProject()).getInjectedPsiFiles(this); - if (files != null) { - for (Pair file : files) { - Language language = file.getFirst().getLanguage(); - if (language == PythonVerboseRegexpLanguage.INSTANCE) { - return true; + @RequiredReadAction + public TextRange getStringValueTextRange() { + List allRanges = getStringValueTextRanges(); + if (allRanges.size() == 1) { + return allRanges.get(0); } - } - } - return false; - } - - @Override - public boolean supportsPerl5EmbeddedComments() { - return true; - } - - @Override - public boolean supportsPossessiveQuantifiers() { - return false; - } - - @Override - public boolean supportsPythonConditionalRefs() { - return true; - } - - @Override - public boolean supportsNamedGroupSyntax(RegExpGroup group) { - return group.getType() == RegExpGroup.Type.PYTHON_NAMED_GROUP; - } - - @Override - public boolean supportsNamedGroupRefSyntax(RegExpNamedGroupRef ref) { - return ref.isPythonNamedGroupRef(); - } - - @Override - public boolean supportsExtendedHexCharacter(RegExpChar regExpChar) { - return false; - } - - @Override - public boolean isValidCategory(@Nonnull String category) { - return myPropertiesProvider.isValidCategory(category); - } - - @Nonnull - @Override - public String[][] getAllKnownProperties() { - return myPropertiesProvider.getAllKnownProperties(); - } - - @Nullable - @Override - public String getPropertyDescription(@Nullable String name) { - return myPropertiesProvider.getPropertyDescription(name); - } - - @Nonnull - @Override - public String[][] getKnownCharacterClasses() { - return myPropertiesProvider.getKnownCharacterClasses(); - } + if (allRanges.size() > 1) { + return allRanges.get(0).union(allRanges.get(allRanges.size() - 1)); + } + return new TextRange(0, getTextLength()); + } + + @Override + @RequiredReadAction + public String toString() { + return super.toString() + ": " + getStringValue(); + } + + @Override + public boolean isValidHost() { + return true; + } + + @Override + @RequiredReadAction + public PyType getType(@Nonnull TypeEvalContext context, @Nonnull TypeEvalContext.Key key) { + List nodes = getStringNodes(); + if (nodes.size() > 0) { + String text = getStringNodes().get(0).getText(); + + PyFile file = PsiTreeUtil.getParentOfType(this, PyFile.class); + if (file != null) { + IElementType type = PythonHighlightingLexer.convertStringType( + getStringNodes().get(0).getElementType(), + text, + LanguageLevel.forElement(this), + file.hasImportFromFuture(FutureFeature + .UNICODE_LITERALS) + ); + if (PyTokenTypes.UNICODE_NODES.contains(type)) { + return PyBuiltinCache.getInstance(this).getUnicodeType(LanguageLevel.forElement(this)); + } + } + } + return PyBuiltinCache.getInstance(this).getBytesType(LanguageLevel.forElement(this)); + } + + @Override + @Nonnull + public PsiReference[] getReferences() { + return ReferenceProvidersRegistry.getReferencesFromProviders(this, PsiReferenceService.Hints.NO_HINTS); + } + + @Override + public ItemPresentation getPresentation() { + return new ItemPresentation() { + @Nullable + @Override + @RequiredReadAction + public String getPresentableText() { + return getStringValue(); + } + + @Nonnull + @Override + public String getLocationString() { + return "(" + PyElementPresentation.getPackageForFile(getContainingFile()) + ")"; + } + + @Nullable + @Override + public Image getIcon() { + return PlatformIconGroup.nodesVariable(); + } + }; + } + + @Override + public PsiLanguageInjectionHost updateText(@Nonnull String text) { + return ElementManipulators.handleContentChange(this, text); + } + + @Override + @Nonnull + public LiteralTextEscaper createLiteralTextEscaper() { + return new StringLiteralTextEscaper(this); + } + + private static class StringLiteralTextEscaper extends LiteralTextEscaper { + private final PyStringLiteralExpressionImpl myHost; + + protected StringLiteralTextEscaper(@Nonnull PyStringLiteralExpressionImpl host) { + super(host); + myHost = host; + } + + @Override + @RequiredReadAction + public boolean decode(@Nonnull TextRange rangeInsideHost, @Nonnull StringBuilder outChars) { + for (Pair fragment : myHost.getDecodedFragments()) { + TextRange encodedTextRange = fragment.getFirst(); + TextRange intersection = encodedTextRange.intersection(rangeInsideHost); + if (intersection != null && !intersection.isEmpty()) { + String value = fragment.getSecond(); + String intersectedValue; + if (value.length() == 1 || value.length() == intersection.getLength()) { + intersectedValue = value; + } + else { + int start = Math.max(0, rangeInsideHost.getStartOffset() - encodedTextRange.getStartOffset()); + int end = Math.min(value.length(), start + intersection.getLength()); + intersectedValue = value.substring(start, end); + } + outChars.append(intersectedValue); + } + } + return true; + } + + @Override + @RequiredReadAction + public int getOffsetInHost(int offsetInDecoded, @Nonnull TextRange rangeInsideHost) { + int offset = 0; + int endOffset = -1; + for (Pair fragment : myHost.getDecodedFragments()) { + TextRange encodedTextRange = fragment.getFirst(); + TextRange intersection = encodedTextRange.intersection(rangeInsideHost); + if (intersection != null && !intersection.isEmpty()) { + String value = fragment.getSecond(); + int valueLength = value.length(); + int intersectionLength = intersection.getLength(); + if (valueLength == 0) { + return -1; + } + else if (valueLength == 1) { + if (offset == offsetInDecoded) { + return intersection.getStartOffset(); + } + offset++; + } + else { + if (offset + intersectionLength >= offsetInDecoded) { + int delta = offsetInDecoded - offset; + return intersection.getStartOffset() + delta; + } + offset += intersectionLength; + } + endOffset = intersection.getEndOffset(); + } + } + // XXX: According to the real use of getOffsetInHost() it should return the correct host offset for the offset in decoded at the + // end of the range inside host, not -1 + if (offset == offsetInDecoded) { + return endOffset; + } + return -1; + } + + @Override + public boolean isOneLine() { + return false; + } + } + + @Override + @RequiredReadAction + public int valueOffsetToTextOffset(int valueOffset) { + return createLiteralTextEscaper().getOffsetInHost(valueOffset, getStringValueTextRange()); + } + + @Nonnull + @Override + public Class getHostClass() { + return getClass(); + } + + @Override + @RequiredReadAction + public boolean characterNeedsEscaping(char c) { + if (c == '#') { + return isVerboseInjection(); + } + return c == ']' || c == '}' || c == '\"' || c == '\''; + } + + @RequiredReadAction + private boolean isVerboseInjection() { + List> files = InjectedLanguageManager.getInstance(getProject()).getInjectedPsiFiles(this); + if (files != null) { + for (Pair file : files) { + Language language = file.getFirst().getLanguage(); + if (language == PythonVerboseRegexpLanguage.INSTANCE) { + return true; + } + } + } + return false; + } + + @Override + public boolean supportsPerl5EmbeddedComments() { + return true; + } + + @Override + public boolean supportsPossessiveQuantifiers() { + return false; + } + + @Override + public boolean supportsPythonConditionalRefs() { + return true; + } + + @Override + public boolean supportsNamedGroupSyntax(RegExpGroup group) { + return group.getType() == RegExpGroup.Type.PYTHON_NAMED_GROUP; + } + + @Override + public boolean supportsNamedGroupRefSyntax(RegExpNamedGroupRef ref) { + return ref.isPythonNamedGroupRef(); + } + + @Override + public boolean supportsExtendedHexCharacter(RegExpChar regExpChar) { + return false; + } + + @Override + public boolean isValidCategory(@Nonnull String category) { + return myPropertiesProvider.isValidCategory(category); + } + + @Nonnull + @Override + public String[][] getAllKnownProperties() { + return myPropertiesProvider.getAllKnownProperties(); + } + + @Nullable + @Override + public String getPropertyDescription(@Nullable String name) { + return myPropertiesProvider.getPropertyDescription(name); + } + + @Nonnull + @Override + public String[][] getKnownCharacterClasses() { + return myPropertiesProvider.getKnownCharacterClasses(); + } } diff --git a/python-impl/src/main/java/com/jetbrains/python/impl/psi/impl/references/PyQualifiedReference.java b/python-impl/src/main/java/com/jetbrains/python/impl/psi/impl/references/PyQualifiedReference.java index 38228940..ded3f494 100644 --- a/python-impl/src/main/java/com/jetbrains/python/impl/psi/impl/references/PyQualifiedReference.java +++ b/python-impl/src/main/java/com/jetbrains/python/impl/psi/impl/references/PyQualifiedReference.java @@ -41,7 +41,7 @@ import com.jetbrains.python.psi.types.PyClassType; import com.jetbrains.python.psi.types.PyType; import com.jetbrains.python.psi.types.TypeEvalContext; -import consulo.application.AllIcons; +import consulo.annotation.access.RequiredReadAction; import consulo.language.editor.completion.AutoCompletionPolicy; import consulo.language.editor.completion.CompletionUtilCore; import consulo.language.editor.completion.lookup.LookupElement; @@ -52,10 +52,10 @@ import consulo.language.psi.util.QualifiedName; import consulo.language.util.ProcessingContext; import consulo.logging.Logger; +import consulo.platform.base.icon.PlatformIconGroup; import consulo.project.Project; import consulo.project.content.scope.ProjectScopes; import consulo.util.collection.ArrayUtil; -import consulo.util.lang.Comparing; import consulo.util.lang.StringUtil; import consulo.virtualFileSystem.VirtualFile; import jakarta.annotation.Nonnull; @@ -75,16 +75,17 @@ public PyQualifiedReference(PyQualifiedExpression element, PyResolveContext cont @Nonnull @Override + @RequiredReadAction protected List resolveInner() { PyPsiUtils.assertValid(myElement); ResolveResultList ret = new ResolveResultList(); - final String referencedName = myElement.getReferencedName(); + String referencedName = myElement.getReferencedName(); if (referencedName == null) { return ret; } - final PyExpression qualifier = myElement.getQualifier(); + PyExpression qualifier = myElement.getQualifier(); PyPsiUtils.assertValid(qualifier); if (qualifier == null) { return ret; @@ -93,8 +94,8 @@ protected List resolveInner() { // regular attributes PyType qualifierType = myContext.getTypeEvalContext().getType(qualifier); // is it a class-private name qualified by a different class? - if (PyUtil.isClassPrivateName(referencedName) && qualifierType instanceof PyClassType) { - if (isOtherClassQualifying(qualifier, (PyClassType) qualifierType)) { + if (PyUtil.isClassPrivateName(referencedName) && qualifierType instanceof PyClassType classType) { + if (isOtherClassQualifying(qualifier, classType)) { return Collections.emptyList(); } } @@ -103,7 +104,7 @@ protected List resolveInner() { qualifierType.assertValid("qualifier: " + qualifier); // resolve within the type proper AccessDirection ctx = AccessDirection.of(myElement); - final List membersOfQualifier = qualifierType.resolveMember(referencedName, qualifier, ctx, myContext); + List membersOfQualifier = qualifierType.resolveMember(referencedName, qualifier, ctx, myContext); if (membersOfQualifier == null) { return ret; // qualifier is positive that such name cannot exist in it } @@ -111,14 +112,15 @@ protected List resolveInner() { } // look for assignment of this attribute in containing function - if (qualifier instanceof PyQualifiedExpression && ret.isEmpty()) { - if (addAssignedAttributes(ret, referencedName, (PyQualifiedExpression) qualifier)) { + if (qualifier instanceof PyQualifiedExpression qualifiedExpr && ret.isEmpty()) { + if (addAssignedAttributes(ret, referencedName, qualifiedExpr)) { return ret; } } - if ((PyTypeChecker.isUnknown(qualifierType) || (qualifierType instanceof PyStructuralType && ((PyStructuralType) qualifierType).isInferredFromUsages())) && - myContext.allowImplicits() && canQualifyAnImplicitName(qualifier)) { + if ((PyTypeChecker.isUnknown(qualifierType) + || (qualifierType instanceof PyStructuralType structuralType && structuralType.isInferredFromUsages())) + && myContext.allowImplicits() && canQualifyAnImplicitName(qualifier)) { addImplicitResolveResults(referencedName, ret); } @@ -130,13 +132,13 @@ protected List resolveInner() { } private static boolean isOtherClassQualifying(PyExpression qualifier, PyClassType qualifierType) { - final List match = PyUtil.searchForWrappingMethod(qualifier, true); + List match = PyUtil.searchForWrappingMethod(qualifier, true); if (match == null) { return true; } if (match.size() > 1) { - final PyClass ourClass = qualifierType.getPyClass(); - final PsiElement theirClass = CompletionUtilCore.getOriginalOrSelf(match.get(match.size() - 1)); + PyClass ourClass = qualifierType.getPyClass(); + PsiElement theirClass = CompletionUtilCore.getOriginalOrSelf(match.get(match.size() - 1)); if (ourClass != theirClass) { return true; } @@ -145,46 +147,44 @@ private static boolean isOtherClassQualifying(PyExpression qualifier, PyClassTyp } private void addImplicitResolveResults(String referencedName, ResolveResultList ret) { - final Project project = myElement.getProject(); - final GlobalSearchScope scope = PyProjectScopeBuilder.excludeSdkTestsScope(project); - final Collection functions = PyFunctionNameIndex.find(referencedName, project, scope); - final PsiFile containingFile = myElement.getContainingFile(); - final List imports; - if (containingFile instanceof PyFile) { - imports = collectImports((PyFile) containingFile); + Project project = myElement.getProject(); + GlobalSearchScope scope = PyProjectScopeBuilder.excludeSdkTestsScope(project); + Collection functions = PyFunctionNameIndex.find(referencedName, project, scope); + List imports; + if (myElement.getContainingFile() instanceof PyFile containingFile) { + imports = collectImports(containingFile); } else { imports = Collections.emptyList(); } for (Object function : functions) { - if (!(function instanceof PyFunction)) { + if (!(function instanceof PyFunction pyFunction)) { break; } - PyFunction pyFunction = (PyFunction) function; if (pyFunction.getContainingClass() != null) { ret.add(new ImplicitResolveResult(pyFunction, getImplicitResultRate(pyFunction, imports))); } } - final Collection attributes = PyInstanceAttributeIndex.find(referencedName, project, scope); + Collection attributes = PyInstanceAttributeIndex.find(referencedName, project, scope); for (Object attribute : attributes) { - if (!(attribute instanceof PyTargetExpression)) { + if (!(attribute instanceof PyTargetExpression targetExpr)) { break; } - ret.add(new ImplicitResolveResult((PyTargetExpression) attribute, getImplicitResultRate((PyTargetExpression) attribute, imports))); + ret.add(new ImplicitResolveResult(targetExpr, getImplicitResultRate(targetExpr, imports))); } } private static List collectImports(PyFile containingFile) { List imports = new ArrayList<>(); for (PyFromImportStatement anImport : containingFile.getFromImports()) { - final QualifiedName source = anImport.getImportSourceQName(); + QualifiedName source = anImport.getImportSourceQName(); if (source != null) { imports.add(source); } } for (PyImportElement importElement : containingFile.getImportTargets()) { - final QualifiedName qName = importElement.getImportedQName(); + QualifiedName qName = importElement.getImportedQName(); if (qName != null) { imports.add(qName.removeLastComponent()); } @@ -198,12 +198,12 @@ private int getImplicitResultRate(PyElement target, List imports) rate += 200; } else { - final VirtualFile vFile = target.getContainingFile().getVirtualFile(); + VirtualFile vFile = target.getContainingFile().getVirtualFile(); if (vFile != null) { if (ProjectScopes.getProjectScope(myElement.getProject()).contains(vFile)) { rate += 80; } - final QualifiedName qName = QualifiedNameFinder.findShortestImportableQName(myElement, vFile); + QualifiedName qName = QualifiedNameFinder.findShortestImportableQName(myElement, vFile); if (qName != null && imports.contains(qName)) { rate += 70; } @@ -214,20 +214,19 @@ private int getImplicitResultRate(PyElement target, List imports) rate += 50; } } - else { - if (!(target instanceof PyFunction)) { - rate += 50; - } + else if (!(target instanceof PyFunction)) { + rate += 50; } return rate; } + @RequiredReadAction private static boolean canQualifyAnImplicitName(@Nonnull PyExpression qualifier) { - if (qualifier instanceof PyCallExpression) { - final PyExpression callee = ((PyCallExpression) qualifier).getCallee(); - if (callee instanceof PyReferenceExpression && PyNames.SUPER.equals(callee.getName())) { - final PsiElement target = ((PyReferenceExpression) callee).getReference().resolve(); - if (target != null && PyBuiltinCache.getInstance(qualifier).isBuiltin(target)) { + if (qualifier instanceof PyCallExpression call) { + PyExpression callee = call.getCallee(); + if (callee instanceof PyReferenceExpression calleeRef && PyNames.SUPER.equals(callee.getName())) { + PsiElement target = calleeRef.getReference().resolve(); + if (target != null && PyBuiltinCache.getInstance(call).isBuiltin(target)) { return false; // super() of unresolved type } } @@ -235,8 +234,8 @@ private static boolean canQualifyAnImplicitName(@Nonnull PyExpression qualifier) return true; } - private static boolean addAssignedAttributes(ResolveResultList ret, String referencedName, @Nonnull final PyQualifiedExpression qualifier) { - final QualifiedName qName = qualifier.asQualifiedName(); + private static boolean addAssignedAttributes(ResolveResultList ret, String referencedName, @Nonnull PyQualifiedExpression qualifier) { + QualifiedName qName = qualifier.asQualifiedName(); if (qName == null) { return false; } @@ -249,27 +248,25 @@ private static boolean addAssignedAttributes(ResolveResultList ret, String refer return false; } + @RequiredReadAction private void addDocReference(ResolveResultList ret, PyExpression qualifier, PyType qualifierType) { - PsiElement docstring = null; - if (qualifierType instanceof PyClassType) { - PyClass qualClass = ((PyClassType) qualifierType).getPyClass(); - docstring = qualClass.getDocStringExpression(); - } - else if (qualifierType instanceof PyModuleType) { - PyFile qualModule = ((PyModuleType) qualifierType).getModule(); - docstring = qualModule.getDocStringExpression(); - } - else if (qualifier instanceof PyReferenceExpression) { - PsiElement qual_object = ((PyReferenceExpression) qualifier).getReference(myContext).resolve(); - if (qual_object instanceof PyDocStringOwner) { - docstring = ((PyDocStringOwner) qual_object).getDocStringExpression(); - } + PsiElement docString = null; + if (qualifierType instanceof PyClassType classType) { + docString = classType.getPyClass().getDocStringExpression(); + } + else if (qualifierType instanceof PyModuleType moduleType) { + docString = moduleType.getModule().getDocStringExpression(); + } + else if (qualifier instanceof PyReferenceExpression refExpr + && refExpr.getReference(myContext).resolve() instanceof PyDocStringOwner docStringOwner) { + docString = docStringOwner.getDocStringExpression(); } - ret.poke(docstring, RatedResolveResult.RATE_HIGH); + ret.poke(docString, RatedResolveResult.RATE_HIGH); } @Nonnull @Override + @RequiredReadAction public Object[] getVariants() { PyExpression qualifier = myElement.getQualifier(); if (qualifier != null) { @@ -278,42 +275,41 @@ public Object[] getVariants() { if (qualifier == null) { return EMPTY_ARRAY; } - final PyQualifiedExpression element = CompletionUtilCore.getOriginalOrSelf(myElement); + PyQualifiedExpression element = CompletionUtilCore.getOriginalOrSelf(myElement); PyType qualifierType = TypeEvalContext.codeCompletion(element.getProject(), element.getContainingFile()).getType(qualifier); ProcessingContext ctx = new ProcessingContext(); - final Set namesAlready = new HashSet<>(); + Set namesAlready = new HashSet<>(); ctx.put(PyType.CTX_NAMES, namesAlready); - final Collection variants = new ArrayList<>(); + Collection variants = new ArrayList<>(); if (qualifierType != null) { Collections.addAll(variants, getVariantFromHasAttr(qualifier)); - if (qualifierType instanceof PyStructuralType && ((PyStructuralType) qualifierType).isInferredFromUsages()) { - final PyClassType guessedType = guessClassTypeByName(); + if (qualifierType instanceof PyStructuralType structuralType && structuralType.isInferredFromUsages()) { + PyClassType guessedType = guessClassTypeByName(); if (guessedType != null) { Collections.addAll(variants, getTypeCompletionVariants(myElement, guessedType)); } } - if (qualifier instanceof PyQualifiedExpression) { - final PyQualifiedExpression qualifierExpression = (PyQualifiedExpression) qualifier; - final QualifiedName qualifiedName = qualifierExpression.asQualifiedName(); + if (qualifier instanceof PyQualifiedExpression qualifierExpr) { + QualifiedName qualifiedName = qualifierExpr.asQualifiedName(); if (qualifiedName == null) { return variants.toArray(); } - final Collection attrs = collectAssignedAttributes(qualifiedName, qualifier); + Collection attrs = collectAssignedAttributes(qualifiedName, qualifierExpr); for (PyExpression ex : attrs) { - final String name = ex.getName(); + String name = ex.getName(); if (name != null && name.endsWith(CompletionUtilCore.DUMMY_IDENTIFIER_TRIMMED)) { continue; } if (ex instanceof PsiNamedElement && qualifierType instanceof PyClassType && name != null) { - variants.add(LookupElementBuilder.createWithSmartPointer(name, ex).withTypeText(qualifierType.getName()).withIcon(AllIcons.Nodes.Field)); + variants.add(LookupElementBuilder.createWithSmartPointer(name, ex) + .withTypeText(qualifierType.getName()) + .withIcon(PlatformIconGroup.nodesField())); } - if (ex instanceof PyReferenceExpression) { - PyReferenceExpression refExpr = (PyReferenceExpression) ex; + if (ex instanceof PyReferenceExpression refExpr) { namesAlready.add(refExpr.getReferencedName()); } - else if (ex instanceof PyTargetExpression) { - PyTargetExpression targetExpr = (PyTargetExpression) ex; + else if (ex instanceof PyTargetExpression targetExpr) { namesAlready.add(targetExpr.getName()); } } @@ -325,7 +321,7 @@ else if (ex instanceof PyTargetExpression) { } } else { - final PyClassType guessedType = guessClassTypeByName(); + PyClassType guessedType = guessClassTypeByName(); if (guessedType != null) { Collections.addAll(variants, getTypeCompletionVariants(myElement, guessedType)); } @@ -336,18 +332,18 @@ else if (ex instanceof PyTargetExpression) { } } + @RequiredReadAction private Object[] getVariantFromHasAttr(PyExpression qualifier) { Collection variants = new ArrayList<>(); PyIfStatement ifStatement = PsiTreeUtil.getParentOfType(myElement, PyIfStatement.class); while (ifStatement != null) { - PyExpression condition = ifStatement.getIfPart().getCondition(); - if (condition instanceof PyCallExpression && ((PyCallExpression) condition).isCalleeText(PyNames.HAS_ATTR)) { - PyCallExpression call = (PyCallExpression) condition; - if (call.getArguments().length > 1 && call.getArguments()[0].getText().equals(qualifier.getText())) { - PyStringLiteralExpression string = call.getArgument(1, PyStringLiteralExpression.class); - if (string != null && StringUtil.isJavaIdentifier(string.getStringValue())) { - variants.add(string.getStringValue()); - } + if (ifStatement.getIfPart().getCondition() instanceof PyCallExpression call + && call.isCalleeText(PyNames.HAS_ATTR) + && call.getArguments().length > 1 + && call.getArguments()[0].getText().equals(qualifier.getText())) { + PyStringLiteralExpression string = call.getArgument(1, PyStringLiteralExpression.class); + if (string != null && StringUtil.isJavaIdentifier(string.getStringValue())) { + variants.add(string.getStringValue()); } } ifStatement = PsiTreeUtil.getParentOfType(ifStatement, PyIfStatement.class); @@ -356,11 +352,10 @@ private Object[] getVariantFromHasAttr(PyExpression qualifier) { } @Nullable + @RequiredReadAction private PyClassType guessClassTypeByName() { - final PyExpression qualifierElement = myElement.getQualifier(); - if (qualifierElement instanceof PyReferenceExpression) { - PyReferenceExpression qualifier = (PyReferenceExpression) qualifierElement; - final String className = qualifier.getReferencedName(); + if (myElement.getQualifier() instanceof PyReferenceExpression qualifier) { + String className = qualifier.getReferencedName(); if (className != null) { Collection classes = PyClassNameIndexInsensitive.find(className, getElement().getProject()); classes = filterByImports(classes, myElement.getContainingFile()); @@ -372,6 +367,7 @@ private PyClassType guessClassTypeByName() { return null; } + @RequiredReadAction private static Collection filterByImports(Collection classes, PsiFile containingFile) { if (classes.size() <= 1) { return classes; @@ -382,7 +378,7 @@ private static Collection filterByImports(Collection classes, result.add(pyClass); } else { - final PsiElement exportedClass = ((PyFile) containingFile).getElementNamed(pyClass.getName()); + PsiElement exportedClass = ((PyFile) containingFile).getElementNamed(pyClass.getName()); if (exportedClass == pyClass) { result.add(pyClass); } @@ -395,22 +391,25 @@ private Object[] collectSeenMembers(final String text) { final Set members = new HashSet<>(); myElement.getContainingFile().accept(new PyRecursiveElementVisitor() { @Override + @RequiredReadAction public void visitPyReferenceExpression(PyReferenceExpression node) { super.visitPyReferenceExpression(node); visitPyQualifiedExpression(node); } @Override + @RequiredReadAction public void visitPyTargetExpression(PyTargetExpression node) { super.visitPyTargetExpression(node); visitPyQualifiedExpression(node); } + @RequiredReadAction private void visitPyQualifiedExpression(PyQualifiedExpression node) { if (node != myElement) { - final PyExpression qualifier = node.getQualifier(); + PyExpression qualifier = node.getQualifier(); if (qualifier != null && qualifier.getText().equals(text)) { - final String refName = node.getReferencedName(); + String refName = node.getReferencedName(); if (refName != null) { members.add(refName); } @@ -426,20 +425,20 @@ private void visitPyQualifiedExpression(PyQualifiedExpression node) { } @Nonnull - public static Collection collectAssignedAttributes(@Nonnull final QualifiedName qualifierQName, @Nonnull final PsiElement anchor) { - final Set names = new HashSet<>(); - final List results = new ArrayList<>(); + public static Collection collectAssignedAttributes(@Nonnull QualifiedName qualifierQName, @Nonnull PsiElement anchor) { + Set names = new HashSet<>(); + List results = new ArrayList<>(); for (ScopeOwner owner = ScopeUtil.getScopeOwner(anchor); owner != null; owner = ScopeUtil.getScopeOwner(owner)) { - final Scope scope = ControlFlowCache.getScope(owner); - for (final PyTargetExpression target : scope.getTargetExpressions()) { - final QualifiedName targetQName = target.asQualifiedName(); - if (targetQName != null) { - if (targetQName.getComponentCount() == qualifierQName.getComponentCount() + 1 && targetQName.matchesPrefix(qualifierQName)) { - final String name = target.getName(); - if (!names.contains(name)) { - names.add(name); - results.add(target); - } + Scope scope = ControlFlowCache.getScope(owner); + for (PyTargetExpression target : scope.getTargetExpressions()) { + QualifiedName targetQName = target.asQualifiedName(); + if (targetQName != null + && targetQName.getComponentCount() == qualifierQName.getComponentCount() + 1 + && targetQName.matchesPrefix(qualifierQName)) { + String name = target.getName(); + if (!names.contains(name)) { + names.add(name); + results.add(target); } } } @@ -448,28 +447,32 @@ public static Collection collectAssignedAttributes(@Nonnull final } @Override + @RequiredReadAction public boolean isReferenceTo(PsiElement element) { // performance: a qualified reference can never resolve to a local variable or parameter if (isLocalScope(element)) { return false; } - final String referencedName = myElement.getReferencedName(); + String referencedName = myElement.getReferencedName(); PyResolveContext resolveContext = myContext.withoutImplicits(); // Guess type eval context origin for switching to local dataflow and return type analysis if (resolveContext.getTypeEvalContext().getOrigin() == null) { - final PsiFile containingFile = myElement.getContainingFile(); - if (containingFile instanceof StubBasedPsiElement) { - assert ((StubBasedPsiElement) containingFile).getStub() == null : "Stub origin for type eval context in isReferenceTo()"; + PsiFile containingFile = myElement.getContainingFile(); + if (containingFile instanceof StubBasedPsiElement stubBasedPsiElement) { + assert stubBasedPsiElement.getStub() == null : "Stub origin for type eval context in isReferenceTo()"; } - final TypeEvalContext context = TypeEvalContext.codeAnalysis(containingFile.getProject(), containingFile); + TypeEvalContext context = TypeEvalContext.codeAnalysis(containingFile.getProject(), containingFile); resolveContext = resolveContext.withTypeEvalContext(context); } - if (element instanceof PyFunction && Comparing.equal(referencedName, ((PyFunction) element).getName()) && - ((PyFunction) element).getContainingClass() != null && !PyNames.INIT.equals(referencedName)) { - final PyExpression qualifier = myElement.getQualifier(); + if (element instanceof PyFunction function + && Objects.equals(referencedName, function.getName()) + && function.getContainingClass() != null + && !PyNames.INIT.equals(referencedName)) { + PyExpression qualifier = myElement.getQualifier(); if (qualifier != null) { - final PyType qualifierType = resolveContext.getTypeEvalContext().getType(qualifier); - if (qualifierType == null || (qualifierType instanceof PyStructuralType && ((PyStructuralType) qualifierType).isInferredFromUsages())) { + PyType qualifierType = resolveContext.getTypeEvalContext().getType(qualifier); + if (qualifierType == null + || (qualifierType instanceof PyStructuralType structuralType && structuralType.isInferredFromUsages())) { return true; } } @@ -490,20 +493,25 @@ protected PyQualifiedReference copyWithResolveContext(PyResolveContext context) return new PyQualifiedReference(myElement, context); } + @RequiredReadAction private boolean isResolvedToResult(PsiElement element, PsiElement resolveResult) { if (resolveResult instanceof PyImportedModule) { resolveResult = resolveResult.getNavigationElement(); } - if (element instanceof PsiDirectory && resolveResult instanceof PyFile && - PyNames.INIT_DOT_PY.equals(((PyFile) resolveResult).getName()) && ((PyFile) resolveResult).getContainingDirectory() == element) { + if (element instanceof PsiDirectory + && resolveResult instanceof PyFile file + && PyNames.INIT_DOT_PY.equals(file.getName()) && file.getContainingDirectory() == element) { return true; } if (resolveResult == element) { return true; } - if (resolveResult instanceof PyTargetExpression && PyUtil.isAttribute((PyTargetExpression) resolveResult) && - element instanceof PyTargetExpression && PyUtil.isAttribute((PyTargetExpression) element) && Comparing.equal(((PyTargetExpression) resolveResult).getReferencedName(), ( - (PyTargetExpression) element).getReferencedName())) { + if (resolveResult instanceof PyTargetExpression targetExpr + && PyUtil.isAttribute(targetExpr) + && element instanceof PyTargetExpression elemTargetExpr + && PyUtil.isAttribute(elemTargetExpr) + && Objects.equals(targetExpr.getReferencedName(), elemTargetExpr.getReferencedName())) { + PyClass aClass = PsiTreeUtil.getParentOfType(resolveResult, PyClass.class); PyClass bClass = PsiTreeUtil.getParentOfType(element, PyClass.class); @@ -512,18 +520,15 @@ private boolean isResolvedToResult(PsiElement element, PsiElement resolveResult) } } - if (resolvesToWrapper(element, resolveResult)) { - return true; - } - return false; + return resolvesToWrapper(element, resolveResult); } + @SuppressWarnings("SimplifiableIfStatement") private static boolean isLocalScope(PsiElement element) { if (element instanceof PyParameter) { return true; } - if (element instanceof PyTargetExpression) { - final PyTargetExpression target = (PyTargetExpression) element; + if (element instanceof PyTargetExpression target) { return !target.isQualified() && ScopeUtil.getScopeOwner(target) instanceof PyFunction; } return false; diff --git a/python-impl/src/main/java/com/jetbrains/python/impl/psi/impl/references/PyReferenceImpl.java b/python-impl/src/main/java/com/jetbrains/python/impl/psi/impl/references/PyReferenceImpl.java index 1a3dd234..548e832f 100644 --- a/python-impl/src/main/java/com/jetbrains/python/impl/psi/impl/references/PyReferenceImpl.java +++ b/python-impl/src/main/java/com/jetbrains/python/impl/psi/impl/references/PyReferenceImpl.java @@ -35,7 +35,8 @@ import com.jetbrains.python.psi.resolve.RatedResolveResult; import com.jetbrains.python.psi.types.PyType; import com.jetbrains.python.psi.types.TypeEvalContext; -import consulo.application.AllIcons; +import consulo.annotation.access.RequiredReadAction; +import consulo.annotation.access.RequiredWriteAction; import consulo.document.util.TextRange; import consulo.language.ast.ASTNode; import consulo.language.controlFlow.Instruction; @@ -48,12 +49,12 @@ import consulo.language.psi.util.PsiTreeUtil; import consulo.language.util.IncorrectOperationException; import consulo.language.util.ProcessingContext; -import consulo.util.lang.Comparing; +import consulo.platform.base.icon.PlatformIconGroup; import consulo.util.lang.Pair; import consulo.util.lang.StringUtil; - import jakarta.annotation.Nonnull; import jakarta.annotation.Nullable; + import java.util.*; import java.util.concurrent.atomic.AtomicInteger; @@ -61,752 +62,764 @@ * @author yole */ public class PyReferenceImpl implements PsiReferenceEx, PsiPolyVariantReference { - protected final PyQualifiedExpression myElement; - protected final PyResolveContext myContext; - - public PyReferenceImpl(PyQualifiedExpression element, @Nonnull PyResolveContext context) { - myElement = element; - myContext = context; - } - - @Override - public TextRange getRangeInElement() { - final ASTNode nameElement = myElement.getNameElement(); - final TextRange range = nameElement != null ? nameElement.getTextRange() : myElement.getNode().getTextRange(); - return range.shiftRight(-myElement.getNode().getStartOffset()); - } - - @Override - public PsiElement getElement() { - return myElement; - } - - /** - * Resolves reference to the most obvious point. - * Imported module names: to module file (or directory for a qualifier). - * Other identifiers: to most recent definition before this reference. - * This implementation is cached. - * - * @see #resolveInner(). - */ - @Override - @Nullable - public PsiElement resolve() { - final ResolveResult[] results = multiResolve(false); - return results.length >= 1 && !(results[0] instanceof ImplicitResolveResult) ? results[0].getElement() : null; - } - - // it is *not* final so that it can be changed in debug time. if set to false, caching is off - @SuppressWarnings("FieldCanBeLocal") - private static boolean USE_CACHE = true; - - /** - * Resolves reference to possible referred elements. - * First element is always what resolve() would return. - * Imported module names: to module file, or {directory, '__init__.py}' for a qualifier. - * todo Local identifiers: a list of definitions in the most recent compound statement - * (e.g. if X: a = 1; else: a = 2 has two definitions of a.). - * todo Identifiers not found locally: similar definitions in imported files and builtins. - * - * @see PsiPolyVariantReference#multiResolve(boolean) - */ - @Override - @Nonnull - public ResolveResult[] multiResolve(final boolean incompleteCode) { - if (USE_CACHE) { - final ResolveCache cache = ResolveCache.getInstance(getElement().getProject()); - return cache.resolveWithCaching(this, CachingResolver.INSTANCE, false, incompleteCode); - } - else { - return multiResolveInner(); - } - } - - // sorts and modifies results of resolveInner - - @Nonnull - private ResolveResult[] multiResolveInner() { - final String referencedName = myElement.getReferencedName(); - if (referencedName == null) { - return ResolveResult.EMPTY_ARRAY; - } - - List targets = resolveInner(); - if (targets.size() == 0) { - return ResolveResult.EMPTY_ARRAY; - } - - // change class results to constructor results if there are any - if (myElement.getParent() instanceof PyCallExpression) { // we're a call - ListIterator it = targets.listIterator(); - while (it.hasNext()) { - final RatedResolveResult rrr = it.next(); - final PsiElement elt = rrr.getElement(); - if (elt instanceof PyClass) { - PyClass cls = (PyClass)elt; - PyFunction init = cls.findMethodByName(PyNames.INIT, false, null); - if (init != null) { - // replace - it.set(rrr.replace(init)); - } - else { // init not found; maybe it's ancestor's - for (PyClass ancestor : cls.getAncestorClasses(myContext.getTypeEvalContext())) { - init = ancestor.findMethodByName(PyNames.INIT, false, null); - if (init != null) { - // add to results as low priority - it.add(new RatedResolveResult(RatedResolveResult.RATE_LOW, init)); - break; - } - } - } - } - } - } - - // put everything in a sorting container - List ret = RatedResolveResult.sorted(targets); - return ret.toArray(new ResolveResult[ret.size()]); - } - - @Nonnull - private static ResolveResultList resolveToLatestDefs(@Nonnull List instructions, - @Nonnull PsiElement element, - @Nonnull String name, - @Nonnull TypeEvalContext context) { - final ResolveResultList ret = new ResolveResultList(); - for (Instruction instruction : instructions) { - final PsiElement definition = instruction.getElement(); - // TODO: This check may slow down resolving, but it is the current solution to the comprehension scopes problem - if (isInnerComprehension(element, definition)) { - continue; - } - if (definition instanceof PyImportedNameDefiner && !(definition instanceof PsiNamedElement)) { - final PyImportedNameDefiner definer = (PyImportedNameDefiner)definition; - final List resolvedResults = definer.multiResolveName(name); - for (RatedResolveResult result : resolvedResults) { - final PsiElement resolved = result.getElement(); - ret.add(new ImportedResolveResult(resolved, getRate(resolved, context), definer)); - } - if (resolvedResults.isEmpty()) { - ret.add(new ImportedResolveResult(null, RatedResolveResult.RATE_NORMAL, definer)); + protected final PyQualifiedExpression myElement; + protected final PyResolveContext myContext; + + public PyReferenceImpl(PyQualifiedExpression element, @Nonnull PyResolveContext context) { + myElement = element; + myContext = context; + } + + @Nonnull + @Override + @RequiredReadAction + public TextRange getRangeInElement() { + ASTNode nameElement = myElement.getNameElement(); + TextRange range = nameElement != null ? nameElement.getTextRange() : myElement.getNode().getTextRange(); + return range.shiftRight(-myElement.getNode().getStartOffset()); + } + + @Override + @RequiredReadAction + public PsiElement getElement() { + return myElement; + } + + /** + * Resolves reference to the most obvious point. + * Imported module names: to module file (or directory for a qualifier). + * Other identifiers: to most recent definition before this reference. + * This implementation is cached. + * + * @see #resolveInner(). + */ + @Nullable + @Override + @RequiredReadAction + public PsiElement resolve() { + ResolveResult[] results = multiResolve(false); + return results.length >= 1 && !(results[0] instanceof ImplicitResolveResult) ? results[0].getElement() : null; + } + + // it is *not* final so that it can be changed in debug time. if set to false, caching is off + @SuppressWarnings("FieldCanBeLocal") + private static boolean USE_CACHE = true; + + /** + * Resolves reference to possible referred elements. + * First element is always what resolve() would return. + * Imported module names: to module file, or {directory, '__init__.py}' for a qualifier. + * todo Local identifiers: a list of definitions in the most recent compound statement + * (e.g. if X: a = 1; else: a = 2 has two definitions of a.). + * todo Identifiers not found locally: similar definitions in imported files and built-ins. + * + * @see PsiPolyVariantReference#multiResolve(boolean) + */ + @Nonnull + @Override + @RequiredReadAction + public ResolveResult[] multiResolve(boolean incompleteCode) { + if (USE_CACHE) { + ResolveCache cache = ResolveCache.getInstance(getElement().getProject()); + return cache.resolveWithCaching(this, CachingResolver.INSTANCE, false, incompleteCode); } else { - // TODO this kind of resolve contract is quite stupid - ret.poke(definer, RatedResolveResult.RATE_LOW); - } - } - else { - ret.poke(definition, getRate(definition, context)); - } - } - final ResolveResultList results = new ResolveResultList(); - for (RatedResolveResult r : ret) { - final PsiElement e = r.getElement(); - if (e == element) { - continue; - } - if (element instanceof PyTargetExpression && e != null && PyPsiUtils.isBefore(element, e)) { - continue; - } - else { - results.add(r); - } - } - - return results; - } - - private static boolean isInnerComprehension(PsiElement referenceElement, PsiElement definition) { - final PyComprehensionElement definitionComprehension = PsiTreeUtil.getParentOfType(definition, PyComprehensionElement.class); - if (definitionComprehension != null && PyUtil.isOwnScopeComprehension(definitionComprehension)) { - final PyComprehensionElement elementComprehension = PsiTreeUtil.getParentOfType(referenceElement, PyComprehensionElement.class); - if (elementComprehension == null || !PsiTreeUtil.isAncestor(definitionComprehension, elementComprehension, false)) { - return true; - } - } - return false; - } - - private static boolean isInOwnScopeComprehension(PsiElement uexpr) { - PyComprehensionElement comprehensionElement = PsiTreeUtil.getParentOfType(uexpr, PyComprehensionElement.class); - return comprehensionElement != null && PyUtil.isOwnScopeComprehension(comprehensionElement); - } - - /** - * Does actual resolution of resolve(). - * - * @return resolution result. - * @see #resolve() - */ - @Nonnull - protected List resolveInner() { - final ResolveResultList ret = new ResolveResultList(); - - final String referencedName = myElement.getReferencedName(); - if (referencedName == null) { - return ret; - } - - if (myElement instanceof PyTargetExpression) { - if (PsiTreeUtil.getParentOfType(myElement, PyComprehensionElement.class) != null) { - ret.poke(myElement, getRate(myElement, myContext.getTypeEvalContext())); - return ret; - } + return multiResolveInner(); + } } - // resolve implicit __class__ inside class function - if (myElement instanceof PyReferenceExpression && - PyNames.__CLASS__.equals(referencedName) && - LanguageLevel.forElement(myElement).isAtLeast(LanguageLevel.PYTHON30)) { - final PyFunction containingFunction = PsiTreeUtil.getParentOfType(myElement, PyFunction.class); + // sorts and modifies results of resolveInner - if (containingFunction != null) { - final PyClass containingClass = containingFunction.getContainingClass(); + @Nonnull + private ResolveResult[] multiResolveInner() { + String referencedName = myElement.getReferencedName(); + if (referencedName == null) { + return ResolveResult.EMPTY_ARRAY; + } - if (containingClass != null) { - final PyResolveProcessor processor = new PyResolveProcessor(referencedName); - PyResolveUtil.scopeCrawlUp(processor, myElement, referencedName, containingFunction); + List targets = resolveInner(); + if (targets.size() == 0) { + return ResolveResult.EMPTY_ARRAY; + } - if (processor.getElements().isEmpty()) { - ret.add(new RatedResolveResult(RatedResolveResult.RATE_NORMAL, containingClass)); - return ret; - } - } - } - } - - // here we have an unqualified expr. it may be defined: - // ...in current file - final PyResolveProcessor processor = new PyResolveProcessor(referencedName); - - // Use real context here to enable correct completion and resolve in case of PyExpressionCodeFragment - final PsiElement realContext = PyPsiUtils.getRealContext(myElement); - - final PsiElement roof = findResolveRoof(referencedName, realContext); - PyResolveUtil.scopeCrawlUp(processor, myElement, referencedName, roof); - - final List resultsFromProcessor = getResultsFromProcessor(referencedName, processor, realContext, roof); - - // resolve to module __doc__ - if (resultsFromProcessor.isEmpty() && referencedName.equals(PyNames.DOC)) { - ret.addAll(Optional.ofNullable(PyBuiltinCache.getInstance(myElement).getObjectType()) - .map(type -> type.resolveMember(referencedName, myElement, AccessDirection.of(myElement), myContext)) - .orElse(Collections.emptyList())); - - return ret; - } - - return resultsFromProcessor; - } - - protected List getResultsFromProcessor(@Nonnull String referencedName, - @Nonnull PyResolveProcessor processor, - @Nullable PsiElement realContext, - @Nullable PsiElement resolveRoof) { - boolean unreachableLocalDeclaration = false; - boolean resolveInParentScope = false; - final ResolveResultList resultList = new ResolveResultList(); - final ScopeOwner referenceOwner = ScopeUtil.getScopeOwner(realContext); - final TypeEvalContext typeEvalContext = myContext.getTypeEvalContext(); - ScopeOwner resolvedOwner = processor.getOwner(); - - if (resolvedOwner != null && !processor.getResults().isEmpty()) { - final Collection resolvedElements = processor.getElements(); - final Scope resolvedScope = ControlFlowCache.getScope(resolvedOwner); - - if (!resolvedScope.isGlobal(referencedName)) { - if (resolvedOwner == referenceOwner) { - final List instructions = PyDefUseUtil.getLatestDefs(resolvedOwner, referencedName, realContext, false, true); - // TODO: Use the results from the processor as a cache for resolving to latest defs - final ResolveResultList latestDefs = resolveToLatestDefs(instructions, realContext, referencedName, typeEvalContext); - if (!latestDefs.isEmpty()) { - return latestDefs; - } - else if (resolvedOwner instanceof PyClass || instructions.isEmpty() && allInOwnScopeComprehensions(resolvedElements)) { - resolveInParentScope = true; - } - else { - unreachableLocalDeclaration = true; - } - } - else if (referenceOwner != null) { - final Scope referenceScope = ControlFlowCache.getScope(referenceOwner); - if (referenceScope.containsDeclaration(referencedName)) { - unreachableLocalDeclaration = true; - } - } - } - } - - // TODO: Try resolve to latest defs for outer scopes starting from the last element in CFG (=> no need for a special rate for globals) - - if (!unreachableLocalDeclaration) { - if (resolveInParentScope) { - processor = new PyResolveProcessor(referencedName); - resolvedOwner = ScopeUtil.getScopeOwner(resolvedOwner); - if (resolvedOwner != null) { - PyResolveUtil.scopeCrawlUp(processor, resolvedOwner, referencedName, resolveRoof); - } - } - - for (Map.Entry entry : processor.getResults().entrySet()) { - final PsiElement resolved = entry.getKey(); - final PyImportedNameDefiner definer = entry.getValue(); - if (resolved != null) { - if (typeEvalContext.maySwitchToAST(resolved) && isInnerComprehension(realContext, resolved)) { - continue; - } - if (resolved == referenceOwner && referenceOwner instanceof PyClass) { - continue; - } - if (definer == null) { - resultList.poke(resolved, getRate(resolved, typeEvalContext)); - } - else { - resultList.poke(definer, getRate(definer, typeEvalContext)); - resultList.add(new ImportedResolveResult(resolved, getRate(resolved, typeEvalContext), definer)); - } - } - else if (definer != null) { - resultList.add(new ImportedResolveResult(null, RatedResolveResult.RATE_LOW, definer)); - } - } - - if (!resultList.isEmpty()) { - return resultList; - } - } - - return resolveByReferenceResolveProviders(); - } - - private static boolean allInOwnScopeComprehensions(@Nonnull Collection elements) { - for (PsiElement element : elements) { - if (!isInOwnScopeComprehension(element)) { + // change class results to constructor results if there are any + if (myElement.getParent() instanceof PyCallExpression) { // we're a call + ListIterator it = targets.listIterator(); + while (it.hasNext()) { + RatedResolveResult rrr = it.next(); + if (rrr.getElement() instanceof PyClass cls) { + PyFunction init = cls.findMethodByName(PyNames.INIT, false, null); + if (init != null) { + // replace + it.set(rrr.replace(init)); + } + else { // init not found; maybe it's ancestor's + for (PyClass ancestor : cls.getAncestorClasses(myContext.getTypeEvalContext())) { + init = ancestor.findMethodByName(PyNames.INIT, false, null); + if (init != null) { + // add to results as low priority + it.add(new RatedResolveResult(RatedResolveResult.RATE_LOW, init)); + break; + } + } + } + } + } + } + + // put everything in a sorting container + List ret = RatedResolveResult.sorted(targets); + return ret.toArray(new ResolveResult[ret.size()]); + } + + @Nonnull + private static ResolveResultList resolveToLatestDefs( + @Nonnull List instructions, + @Nonnull PsiElement element, + @Nonnull String name, + @Nonnull TypeEvalContext context + ) { + ResolveResultList ret = new ResolveResultList(); + for (Instruction instruction : instructions) { + PsiElement definition = instruction.getElement(); + // TODO: This check may slow down resolving, but it is the current solution to the comprehension scopes problem + if (isInnerComprehension(element, definition)) { + continue; + } + if (definition instanceof PyImportedNameDefiner definer && !(definer instanceof PsiNamedElement)) { + List resolvedResults = definer.multiResolveName(name); + for (RatedResolveResult result : resolvedResults) { + PsiElement resolved = result.getElement(); + ret.add(new ImportedResolveResult(resolved, getRate(resolved, context), definer)); + } + if (resolvedResults.isEmpty()) { + ret.add(new ImportedResolveResult(null, RatedResolveResult.RATE_NORMAL, definer)); + } + else { + // TODO this kind of resolve contract is quite stupid + ret.poke(definer, RatedResolveResult.RATE_LOW); + } + } + else { + ret.poke(definition, getRate(definition, context)); + } + } + ResolveResultList results = new ResolveResultList(); + for (RatedResolveResult r : ret) { + PsiElement e = r.getElement(); + if (e == element) { + continue; + } + if (element instanceof PyTargetExpression && e != null && PyPsiUtils.isBefore(element, e)) { + continue; + } + else { + results.add(r); + } + } + + return results; + } + + private static boolean isInnerComprehension(PsiElement referenceElement, PsiElement definition) { + PyComprehensionElement definitionComprehension = PsiTreeUtil.getParentOfType(definition, PyComprehensionElement.class); + if (definitionComprehension != null && PyUtil.isOwnScopeComprehension(definitionComprehension)) { + PyComprehensionElement elementComprehension = PsiTreeUtil.getParentOfType(referenceElement, PyComprehensionElement.class); + if (elementComprehension == null || !PsiTreeUtil.isAncestor(definitionComprehension, elementComprehension, false)) { + return true; + } + } return false; - } - } - return true; - } - - @Nonnull - private ResolveResultList resolveByReferenceResolveProviders() { - final ResolveResultList results = new ResolveResultList(); - for (PyReferenceResolveProvider provider : PyReferenceResolveProvider.EP_NAME.getExtensionList()) { - results.addAll(provider.resolveName(myElement)); - } - return results; - } - - private PsiElement findResolveRoof(String referencedName, PsiElement realContext) { - if (PyUtil.isClassPrivateName(referencedName)) { - // a class-private name; limited by either class or this file - PsiElement one = myElement; - do { - one = ScopeUtil.getScopeOwner(one); - } - while (one instanceof PyFunction); - if (one instanceof PyClass) { - PyArgumentList superClassExpressionList = ((PyClass)one).getSuperClassExpressionList(); - if (superClassExpressionList == null || !PsiTreeUtil.isAncestor(superClassExpressionList, myElement, false)) { - return one; - } - } - } - - if (myElement instanceof PyTargetExpression) { - final ScopeOwner scopeOwner = PsiTreeUtil.getParentOfType(myElement, ScopeOwner.class); - final Scope scope; - if (scopeOwner != null) { - scope = ControlFlowCache.getScope(scopeOwner); - final String name = myElement.getName(); - if (scope.isNonlocal(name)) { - final ScopeOwner nonlocalOwner = ScopeUtil.getDeclarationScopeOwner(myElement, referencedName); - if (nonlocalOwner != null && !(nonlocalOwner instanceof PyFile)) { - return nonlocalOwner; - } - } - if (!scope.isGlobal(name)) { - return scopeOwner; - } - } - } - return realContext.getContainingFile(); - } - - public static int getRate(PsiElement elt, @Nonnull TypeEvalContext context) { - int rate; - if (elt instanceof PyTargetExpression && context.maySwitchToAST(elt)) { - final PsiElement parent = elt.getParent(); - if (parent instanceof PyGlobalStatement || parent instanceof PyNonlocalStatement) { - rate = RatedResolveResult.RATE_LOW; - } - else { - rate = RatedResolveResult.RATE_NORMAL; - } - } - else if (elt instanceof PyImportedNameDefiner || elt instanceof PyReferenceExpression) { - rate = RatedResolveResult.RATE_LOW; - } - else if (elt instanceof PyFile) { - rate = RatedResolveResult.RATE_HIGH; - } - else { - rate = RatedResolveResult.RATE_NORMAL; - } - return rate; - } - - @Override - @Nonnull - public String getCanonicalText() { - return getRangeInElement().substring(getElement().getText()); - } - - @Override - public PsiElement handleElementRename(String newElementName) throws IncorrectOperationException { - ASTNode nameElement = myElement.getNameElement(); - newElementName = StringUtil.trimEnd(newElementName, PyNames.DOT_PY); - if (nameElement != null && PyNames.isIdentifier(newElementName)) { - final ASTNode newNameElement = PyUtil.createNewName(myElement, newElementName); - myElement.getNode().replaceChild(nameElement, newNameElement); - } - return myElement; - } - - @Override - @Nullable - public PsiElement bindToElement(@Nonnull PsiElement element) throws IncorrectOperationException { - return null; - } - - @Override - public boolean isReferenceTo(PsiElement element) { - if (element instanceof PsiFileSystemItem) { - // may be import via alias, so don't check if names match, do simple resolve check instead - PsiElement resolveResult = resolve(); - if (resolveResult instanceof PyImportedModule) { - resolveResult = resolveResult.getNavigationElement(); - } - if (element instanceof PsiDirectory) { - if (resolveResult instanceof PyFile) { - final PyFile file = (PyFile)resolveResult; - if (PyUtil.isPackage(file) && file.getContainingDirectory() == element) { - return true; - } + } + + private static boolean isInOwnScopeComprehension(PsiElement uExpr) { + PyComprehensionElement comprehensionElement = PsiTreeUtil.getParentOfType(uExpr, PyComprehensionElement.class); + return comprehensionElement != null && PyUtil.isOwnScopeComprehension(comprehensionElement); + } + + /** + * Does actual resolution of resolve(). + * + * @return resolution result. + * @see #resolve() + */ + @Nonnull + protected List resolveInner() { + ResolveResultList ret = new ResolveResultList(); + + String referencedName = myElement.getReferencedName(); + if (referencedName == null) { + return ret; } - else if (resolveResult instanceof PsiDirectory) { - final PsiDirectory directory = (PsiDirectory)resolveResult; - if (PyUtil.isPackage(directory, null) && directory == element) { - return true; - } - } - } - return resolveResult == element; - } - if (element instanceof PsiNamedElement) { - final String elementName = ((PsiNamedElement)element).getName(); - if ((Comparing.equal(myElement.getReferencedName(), elementName) || PyNames.INIT.equals(elementName))) { - if (!haveQualifiers(element)) { - final ScopeOwner ourScopeOwner = ScopeUtil.getScopeOwner(getElement()); - final ScopeOwner theirScopeOwner = ScopeUtil.getScopeOwner(element); - if (element instanceof PyParameter || element instanceof PyTargetExpression) { - // Check if the reference is in the same or inner scope of the element scope, not shadowed by an intermediate declaration - if (resolvesToSameLocal(element, elementName, ourScopeOwner, theirScopeOwner)) { - return true; - } - } - - final PsiElement resolveResult = resolve(); - if (resolveResult == element) { - return true; - } - // we shadow their name or they shadow ours (PY-6241) - if (resolveResult instanceof PsiNamedElement && resolveResult instanceof ScopeOwner && element instanceof ScopeOwner && - theirScopeOwner == ScopeUtil.getScopeOwner(resolveResult)) { - return true; - } + if (myElement instanceof PyTargetExpression) { + if (PsiTreeUtil.getParentOfType(myElement, PyComprehensionElement.class) != null) { + ret.poke(myElement, getRate(myElement, myContext.getTypeEvalContext())); + return ret; + } + } - if (!haveQualifiers(element) && ourScopeOwner != null && theirScopeOwner != null) { - if (resolvesToSameGlobal(element, elementName, ourScopeOwner, theirScopeOwner, resolveResult)) { - return true; + // resolve implicit __class__ inside class function + if (myElement instanceof PyReferenceExpression + && PyNames.__CLASS__.equals(referencedName) + && LanguageLevel.forElement(myElement).isAtLeast(LanguageLevel.PYTHON30)) { + PyFunction containingFunction = PsiTreeUtil.getParentOfType(myElement, PyFunction.class); + + if (containingFunction != null) { + PyClass containingClass = containingFunction.getContainingClass(); + + if (containingClass != null) { + PyResolveProcessor processor = new PyResolveProcessor(referencedName); + PyResolveUtil.scopeCrawlUp(processor, myElement, referencedName, containingFunction); + + if (processor.getElements().isEmpty()) { + ret.add(new RatedResolveResult(RatedResolveResult.RATE_NORMAL, containingClass)); + return ret; + } + } } - } + } - if (resolvesToWrapper(element, resolveResult)) { - return true; - } + // here we have an unqualified expr. it may be defined: + // ...in current file + PyResolveProcessor processor = new PyResolveProcessor(referencedName); + + // Use real context here to enable correct completion and resolve in case of PyExpressionCodeFragment + PsiElement realContext = PyPsiUtils.getRealContext(myElement); + + PsiElement roof = findResolveRoof(referencedName, realContext); + PyResolveUtil.scopeCrawlUp(processor, myElement, referencedName, roof); + + List resultsFromProcessor = getResultsFromProcessor(referencedName, processor, realContext, roof); + + // resolve to module __doc__ + if (resultsFromProcessor.isEmpty() && referencedName.equals(PyNames.DOC)) { + ret.addAll(Optional.ofNullable(PyBuiltinCache.getInstance(myElement).getObjectType()) + .map(type -> type.resolveMember(referencedName, myElement, AccessDirection.of(myElement), myContext)) + .orElse(Collections.emptyList())); + + return ret; } - if (element instanceof PyExpression) { - final PyExpression expr = (PyExpression)element; - if (PyUtil.isClassAttribute(myElement) && (PyUtil.isClassAttribute(expr) || PyUtil.isInstanceAttribute(expr))) { - final PyClass c1 = PsiTreeUtil.getParentOfType(element, PyClass.class); - final PyClass c2 = PsiTreeUtil.getParentOfType(myElement, PyClass.class); - final TypeEvalContext context = myContext.getTypeEvalContext(); - if (c1 != null && c2 != null && (c1.isSubclass(c2, context) || c2.isSubclass(c1, context))) { - return true; + + return resultsFromProcessor; + } + + protected List getResultsFromProcessor( + @Nonnull String referencedName, + @Nonnull PyResolveProcessor processor, + @Nullable PsiElement realContext, + @Nullable PsiElement resolveRoof + ) { + boolean unreachableLocalDeclaration = false; + boolean resolveInParentScope = false; + ResolveResultList resultList = new ResolveResultList(); + ScopeOwner referenceOwner = ScopeUtil.getScopeOwner(realContext); + TypeEvalContext typeEvalContext = myContext.getTypeEvalContext(); + ScopeOwner resolvedOwner = processor.getOwner(); + + if (resolvedOwner != null && !processor.getResults().isEmpty()) { + Collection resolvedElements = processor.getElements(); + Scope resolvedScope = ControlFlowCache.getScope(resolvedOwner); + + if (!resolvedScope.isGlobal(referencedName)) { + if (resolvedOwner == referenceOwner) { + List instructions = + PyDefUseUtil.getLatestDefs(resolvedOwner, referencedName, realContext, false, true); + // TODO: Use the results from the processor as a cache for resolving to latest defs + ResolveResultList latestDefs = resolveToLatestDefs(instructions, realContext, referencedName, typeEvalContext); + if (!latestDefs.isEmpty()) { + return latestDefs; + } + else if (resolvedOwner instanceof PyClass || instructions.isEmpty() && allInOwnScopeComprehensions(resolvedElements)) { + resolveInParentScope = true; + } + else { + unreachableLocalDeclaration = true; + } + } + else if (referenceOwner != null) { + Scope referenceScope = ControlFlowCache.getScope(referenceOwner); + if (referenceScope.containsDeclaration(referencedName)) { + unreachableLocalDeclaration = true; + } + } + } + } + + // TODO: Try resolve to latest defs for outer scopes starting from the last element in CFG (=> no need for a special rate for globals) + + if (!unreachableLocalDeclaration) { + if (resolveInParentScope) { + processor = new PyResolveProcessor(referencedName); + resolvedOwner = ScopeUtil.getScopeOwner(resolvedOwner); + if (resolvedOwner != null) { + PyResolveUtil.scopeCrawlUp(processor, resolvedOwner, referencedName, resolveRoof); + } + } + + for (Map.Entry entry : processor.getResults().entrySet()) { + PsiElement resolved = entry.getKey(); + PyImportedNameDefiner definer = entry.getValue(); + if (resolved != null) { + if (typeEvalContext.maySwitchToAST(resolved) && isInnerComprehension(realContext, resolved)) { + continue; + } + if (resolved == referenceOwner && referenceOwner instanceof PyClass) { + continue; + } + if (definer == null) { + resultList.poke(resolved, getRate(resolved, typeEvalContext)); + } + else { + resultList.poke(definer, getRate(definer, typeEvalContext)); + resultList.add(new ImportedResolveResult(resolved, getRate(resolved, typeEvalContext), definer)); + } + } + else if (definer != null) { + resultList.add(new ImportedResolveResult(null, RatedResolveResult.RATE_LOW, definer)); + } + } + + if (!resultList.isEmpty()) { + return resultList; } - } } - } + + return resolveByReferenceResolveProviders(); } - return false; - } - private boolean resolvesToSameLocal(PsiElement element, String elementName, ScopeOwner ourScopeOwner, ScopeOwner theirScopeOwner) { - final PsiElement ourContainer = findContainer(getElement()); - final PsiElement theirContainer = findContainer(element); - if (ourContainer != null) { - if (ourContainer == theirContainer) { - return true; - } - if (PsiTreeUtil.isAncestor(theirContainer, ourContainer, true)) { - if (ourContainer instanceof PyComprehensionElement && containsDeclaration((PyComprehensionElement)ourContainer, elementName)) { - return false; + private static boolean allInOwnScopeComprehensions(@Nonnull Collection elements) { + for (PsiElement element : elements) { + if (!isInOwnScopeComprehension(element)) { + return false; + } } + return true; + } - ScopeOwner owner = ourScopeOwner; - while (owner != theirScopeOwner && owner != null) { - if (ControlFlowCache.getScope(owner).containsDeclaration(elementName)) { - return false; - } - owner = ScopeUtil.getScopeOwner(owner); + @Nonnull + private ResolveResultList resolveByReferenceResolveProviders() { + ResolveResultList results = new ResolveResultList(); + myElement.getApplication().getExtensionPoint(PyReferenceResolveProvider.class) + .forEach(provider -> results.addAll(provider.resolveName(myElement))); + return results; + } + + private PsiElement findResolveRoof(String referencedName, PsiElement realContext) { + if (PyUtil.isClassPrivateName(referencedName)) { + // a class-private name; limited by either class or this file + PsiElement one = myElement; + do { + one = ScopeUtil.getScopeOwner(one); + } + while (one instanceof PyFunction); + if (one instanceof PyClass pyClass) { + PyArgumentList superClassExpressionList = pyClass.getSuperClassExpressionList(); + if (superClassExpressionList == null || !PsiTreeUtil.isAncestor(superClassExpressionList, myElement, false)) { + return one; + } + } } - return true; - } + if (myElement instanceof PyTargetExpression) { + ScopeOwner scopeOwner = PsiTreeUtil.getParentOfType(myElement, ScopeOwner.class); + Scope scope; + if (scopeOwner != null) { + scope = ControlFlowCache.getScope(scopeOwner); + String name = myElement.getName(); + if (scope.isNonlocal(name)) { + ScopeOwner nonLocalOwner = ScopeUtil.getDeclarationScopeOwner(myElement, referencedName); + if (nonLocalOwner != null && !(nonLocalOwner instanceof PyFile)) { + return nonLocalOwner; + } + } + if (!scope.isGlobal(name)) { + return scopeOwner; + } + } + } + return realContext.getContainingFile(); } - return false; - } - @Nullable - private static PsiElement findContainer(@Nonnull PsiElement element) { - final PyElement parent = PsiTreeUtil.getParentOfType(element, ScopeOwner.class, PyComprehensionElement.class); - if (parent instanceof PyListCompExpression && LanguageLevel.forElement(element).isOlderThan(LanguageLevel.PYTHON30)) { - return findContainer(parent); + public static int getRate(PsiElement elt, @Nonnull TypeEvalContext context) { + int rate; + if (elt instanceof PyTargetExpression && context.maySwitchToAST(elt)) { + PsiElement parent = elt.getParent(); + if (parent instanceof PyGlobalStatement || parent instanceof PyNonlocalStatement) { + rate = RatedResolveResult.RATE_LOW; + } + else { + rate = RatedResolveResult.RATE_NORMAL; + } + } + else if (elt instanceof PyImportedNameDefiner || elt instanceof PyReferenceExpression) { + rate = RatedResolveResult.RATE_LOW; + } + else if (elt instanceof PyFile) { + rate = RatedResolveResult.RATE_HIGH; + } + else { + rate = RatedResolveResult.RATE_NORMAL; + } + return rate; } - return parent; - } - private static boolean containsDeclaration(@Nonnull PyComprehensionElement comprehensionElement, @Nonnull String variableName) { - for (PyComprehensionForComponent forComponent : comprehensionElement.getForComponents()) { - final PyExpression iteratorVariable = forComponent.getIteratorVariable(); + @Nonnull + @Override + @RequiredReadAction + public String getCanonicalText() { + return getRangeInElement().substring(getElement().getText()); + } - if (iteratorVariable instanceof PyTupleExpression) { - for (PyExpression variable : (PyTupleExpression)iteratorVariable) { - if (variable instanceof PyTargetExpression && variableName.equals(variable.getName())) { - return true; - } + @Override + @RequiredWriteAction + public PsiElement handleElementRename(String newElementName) throws IncorrectOperationException { + ASTNode nameElement = myElement.getNameElement(); + newElementName = StringUtil.trimEnd(newElementName, PyNames.DOT_PY); + if (nameElement != null && PyNames.isIdentifier(newElementName)) { + ASTNode newNameElement = PyUtil.createNewName(myElement, newElementName); + myElement.getNode().replaceChild(nameElement, newNameElement); } - } - else if (iteratorVariable instanceof PyTargetExpression && variableName.equals(iteratorVariable.getName())) { - return true; - } - } - - return false; - } - - private boolean resolvesToSameGlobal(PsiElement element, - String elementName, - ScopeOwner ourScopeOwner, - ScopeOwner theirScopeOwner, - PsiElement resolveResult) { - // Handle situations when there is no top-level declaration for globals and transitive resolve doesn't help - final PsiFile ourFile = getElement().getContainingFile(); - final PsiFile theirFile = element.getContainingFile(); - if (ourFile == theirFile) { - final boolean ourIsGlobal = ControlFlowCache.getScope(ourScopeOwner).isGlobal(elementName); - final boolean theirIsGlobal = ControlFlowCache.getScope(theirScopeOwner).isGlobal(elementName); - if (ourIsGlobal && theirIsGlobal) { - return true; - } - } - if (ScopeUtil.getScopeOwner(resolveResult) == ourFile && ControlFlowCache.getScope(theirScopeOwner).isGlobal(elementName)) { - return true; - } - return false; - } - - protected boolean resolvesToWrapper(PsiElement element, PsiElement resolveResult) { - if (element instanceof PyFunction && ((PyFunction)element).getContainingClass() != null && resolveResult instanceof PyTargetExpression) { - final PyExpression assignedValue = ((PyTargetExpression)resolveResult).findAssignedValue(); - if (assignedValue instanceof PyCallExpression) { - final PyCallExpression call = (PyCallExpression)assignedValue; - final Pair functionPair = PyCallExpressionHelper.interpretAsModifierWrappingCall(call, myElement); - if (functionPair != null && functionPair.second == element) { - return true; - } - } - } - return false; - } - - private boolean haveQualifiers(PsiElement element) { - if (myElement.isQualified()) { - return true; - } - if (element instanceof PyQualifiedExpression && ((PyQualifiedExpression)element).isQualified()) { - return true; - } - return false; - } - - @Override - @Nonnull - public Object[] getVariants() { - final List ret = new ArrayList<>(); - - // Use real context here to enable correct completion and resolve in case of PyExpressionCodeFragment!!! - final PsiElement originalElement = CompletionUtilCore.getOriginalElement(myElement); - final PyQualifiedExpression element = - originalElement instanceof PyQualifiedExpression ? (PyQualifiedExpression)originalElement : myElement; - final PsiElement realContext = PyPsiUtils.getRealContext(element); - - // include our own names - final int underscores = PyUtil.getInitialUnderscores(element.getName()); - final CompletionVariantsProcessor processor = new CompletionVariantsProcessor(element); - final ScopeOwner owner = realContext instanceof ScopeOwner ? (ScopeOwner)realContext : ScopeUtil.getScopeOwner(realContext); - if (owner != null) { - PyResolveUtil.scopeCrawlUp(processor, owner, null, null); - } - - // This method is probably called for completion, so use appropriate context here - // in a call, include function's arg names - KeywordArgumentCompletionUtil.collectFunctionArgNames(element, - ret, - TypeEvalContext.codeCompletion(element.getProject(), - element.getContainingFile())); - - // include builtin names - final PyFile builtinsFile = PyBuiltinCache.getInstance(element).getBuiltinsFile(); - if (builtinsFile != null) { - PyResolveUtil.scopeCrawlUp(processor, builtinsFile, null, null); - } - - if (underscores >= 2) { - // if we're a normal module, add module's attrs - PsiFile f = realContext.getContainingFile(); - if (f instanceof PyFile) { - for (String name : PyModuleType.getPossibleInstanceMembers()) { - ret.add(LookupElementBuilder.create(name).withIcon(AllIcons.Nodes.Field)); - } - } - } - - ret.addAll(getOriginalElements(processor)); - return ret.toArray(); - } - - /** - * Throws away fake elements used for completion internally. - */ - protected List getOriginalElements(@Nonnull CompletionVariantsProcessor processor) { - final List ret = new ArrayList<>(); - for (LookupElement item : processor.getResultList()) { - final PsiElement e = item.getPsiElement(); - if (e != null) { - final PsiElement original = CompletionUtilCore.getOriginalElement(e); - if (original == null) { - continue; - } - } - ret.add(item); - } - return ret; - } - - @Override - public boolean isSoft() { - return false; - } - - @Override - public HighlightSeverity getUnresolvedHighlightSeverity(TypeEvalContext context) { - if (isBuiltInConstant()) { - return null; - } - final PyExpression qualifier = myElement.getQualifier(); - if (qualifier == null) { - return HighlightSeverity.ERROR; - } - if (context.getType(qualifier) != null) { - return HighlightSeverity.WARNING; - } - return null; - } + return myElement; + } - private boolean isBuiltInConstant() { - // TODO: generalize - String name = myElement.getReferencedName(); - return PyNames.NONE.equals(name) || "True".equals(name) || "False".equals(name); - } - - @Override - @Nullable - public String getUnresolvedDescription() { - return null; - } - - - // our very own caching resolver - - private static class CachingResolver implements ResolveCache.PolyVariantResolver { - public static CachingResolver INSTANCE = new CachingResolver(); - private ThreadLocal myNesting = new ThreadLocal() { - @Override - protected AtomicInteger initialValue() { - return new AtomicInteger(); - } - }; - - private static final int MAX_NESTING_LEVEL = 30; + @Nullable + @Override + @RequiredWriteAction + public PsiElement bindToElement(@Nonnull PsiElement element) throws IncorrectOperationException { + return null; + } @Override + @RequiredReadAction + public boolean isReferenceTo(PsiElement element) { + if (element instanceof PsiFileSystemItem) { + // may be import via alias, so don't check if names match, do simple resolve check instead + PsiElement resolveResult = resolve(); + if (resolveResult instanceof PyImportedModule importedModule) { + resolveResult = importedModule.getNavigationElement(); + } + if (element instanceof PsiDirectory) { + if (resolveResult instanceof PyFile file) { + if (PyUtil.isPackage(file) && file.getContainingDirectory() == element) { + return true; + } + } + else if (resolveResult instanceof PsiDirectory directory) { + if (PyUtil.isPackage(directory, null) && directory == element) { + return true; + } + } + } + return resolveResult == element; + } + if (element instanceof PsiNamedElement namedElem) { + String elementName = namedElem.getName(); + if ((Objects.equals(myElement.getReferencedName(), elementName) || PyNames.INIT.equals(elementName))) { + if (!haveQualifiers(namedElem)) { + ScopeOwner ourScopeOwner = ScopeUtil.getScopeOwner(getElement()); + ScopeOwner theirScopeOwner = ScopeUtil.getScopeOwner(namedElem); + if (namedElem instanceof PyParameter || namedElem instanceof PyTargetExpression) { + // Check if the reference is in the same or inner scope of the element scope, not shadowed by an intermediate declaration + if (resolvesToSameLocal(namedElem, elementName, ourScopeOwner, theirScopeOwner)) { + return true; + } + } + + PsiElement resolveResult = resolve(); + if (resolveResult == namedElem) { + return true; + } + + // we shadow their name or they shadow ours (PY-6241) + if (resolveResult instanceof PsiNamedElement + && resolveResult instanceof ScopeOwner + && namedElem instanceof ScopeOwner + && theirScopeOwner == ScopeUtil.getScopeOwner(resolveResult)) { + return true; + } + + if (!haveQualifiers(namedElem) && ourScopeOwner != null && theirScopeOwner != null) { + if (resolvesToSameGlobal(namedElem, elementName, ourScopeOwner, theirScopeOwner, resolveResult)) { + return true; + } + } + + if (resolvesToWrapper(namedElem, resolveResult)) { + return true; + } + } + if (namedElem instanceof PyExpression expr + && PyUtil.isClassAttribute(myElement) + && (PyUtil.isClassAttribute(expr) || PyUtil.isInstanceAttribute(expr))) { + PyClass c1 = PsiTreeUtil.getParentOfType(namedElem, PyClass.class); + PyClass c2 = PsiTreeUtil.getParentOfType(myElement, PyClass.class); + TypeEvalContext context = myContext.getTypeEvalContext(); + if (c1 != null && c2 != null && (c1.isSubclass(c2, context) || c2.isSubclass(c1, context))) { + return true; + } + } + } + } + return false; + } + + @RequiredReadAction + private boolean resolvesToSameLocal(PsiElement element, String elementName, ScopeOwner ourScopeOwner, ScopeOwner theirScopeOwner) { + PsiElement ourContainer = findContainer(getElement()); + PsiElement theirContainer = findContainer(element); + if (ourContainer != null) { + if (ourContainer == theirContainer) { + return true; + } + if (PsiTreeUtil.isAncestor(theirContainer, ourContainer, true)) { + if (ourContainer instanceof PyComprehensionElement comprehensionElem + && containsDeclaration(comprehensionElem, elementName)) { + return false; + } + + ScopeOwner owner = ourScopeOwner; + while (owner != theirScopeOwner && owner != null) { + if (ControlFlowCache.getScope(owner).containsDeclaration(elementName)) { + return false; + } + owner = ScopeUtil.getScopeOwner(owner); + } + + return true; + } + } + return false; + } + + @Nullable + private static PsiElement findContainer(@Nonnull PsiElement element) { + PyElement parent = PsiTreeUtil.getParentOfType(element, ScopeOwner.class, PyComprehensionElement.class); + if (parent instanceof PyListCompExpression listCompExpr && LanguageLevel.forElement(element).isOlderThan(LanguageLevel.PYTHON30)) { + return findContainer(listCompExpr); + } + return parent; + } + + private static boolean containsDeclaration(@Nonnull PyComprehensionElement comprehensionElement, @Nonnull String variableName) { + for (PyComprehensionForComponent forComponent : comprehensionElement.getForComponents()) { + PyExpression iteratorVariable = forComponent.getIteratorVariable(); + + if (iteratorVariable instanceof PyTupleExpression tuple) { + for (PyExpression variable : tuple) { + if (variable instanceof PyTargetExpression && variableName.equals(variable.getName())) { + return true; + } + } + } + else if (iteratorVariable instanceof PyTargetExpression && variableName.equals(iteratorVariable.getName())) { + return true; + } + } + + return false; + } + + @RequiredReadAction + private boolean resolvesToSameGlobal( + PsiElement element, + String elementName, + ScopeOwner ourScopeOwner, + ScopeOwner theirScopeOwner, + PsiElement resolveResult + ) { + // Handle situations when there is no top-level declaration for globals and transitive resolve doesn't help + PsiFile ourFile = getElement().getContainingFile(); + PsiFile theirFile = element.getContainingFile(); + if (ourFile == theirFile) { + boolean ourIsGlobal = ControlFlowCache.getScope(ourScopeOwner).isGlobal(elementName); + boolean theirIsGlobal = ControlFlowCache.getScope(theirScopeOwner).isGlobal(elementName); + if (ourIsGlobal && theirIsGlobal) { + return true; + } + } + return ScopeUtil.getScopeOwner(resolveResult) == ourFile + && ControlFlowCache.getScope(theirScopeOwner).isGlobal(elementName); + } + + protected boolean resolvesToWrapper(PsiElement element, PsiElement resolveResult) { + if (element instanceof PyFunction function + && function.getContainingClass() != null + && resolveResult instanceof PyTargetExpression targetExpr + && targetExpr.findAssignedValue() instanceof PyCallExpression call) { + + Pair functionPair = PyCallExpressionHelper.interpretAsModifierWrappingCall(call, myElement); + if (functionPair != null && functionPair.second == element) { + return true; + } + } + return false; + } + + @SuppressWarnings("RedundantIfStatement") + private boolean haveQualifiers(PsiElement element) { + if (myElement.isQualified()) { + return true; + } + if (element instanceof PyQualifiedExpression qualifiedExpr && qualifiedExpr.isQualified()) { + return true; + } + return false; + } + @Nonnull - public ResolveResult[] resolve(@Nonnull final PyReferenceImpl ref, final boolean incompleteCode) { - if (myNesting.get().getAndIncrement() >= MAX_NESTING_LEVEL) { - System.out.println("Stack overflow pending"); - } - try { - return ref.multiResolveInner(); - } - finally { - myNesting.get().getAndDecrement(); - } + @Override + @RequiredReadAction + public Object[] getVariants() { + List ret = new ArrayList<>(); + + // Use real context here to enable correct completion and resolve in case of PyExpressionCodeFragment!!! + PsiElement originalElement = CompletionUtilCore.getOriginalElement(myElement); + PyQualifiedExpression element = originalElement instanceof PyQualifiedExpression qualifiedExpr ? qualifiedExpr : myElement; + PsiElement realContext = PyPsiUtils.getRealContext(element); + + // include our own names + int underscores = PyUtil.getInitialUnderscores(element.getName()); + CompletionVariantsProcessor processor = new CompletionVariantsProcessor(element); + ScopeOwner owner = realContext instanceof ScopeOwner scopeOwner ? scopeOwner : ScopeUtil.getScopeOwner(realContext); + if (owner != null) { + PyResolveUtil.scopeCrawlUp(processor, owner, null, null); + } + + // This method is probably called for completion, so use appropriate context here + // in a call, include function's arg names + KeywordArgumentCompletionUtil.collectFunctionArgNames( + element, + ret, + TypeEvalContext.codeCompletion( + element.getProject(), + element.getContainingFile() + ) + ); + + // include builtin names + PyFile builtInsFile = PyBuiltinCache.getInstance(element).getBuiltinsFile(); + if (builtInsFile != null) { + PyResolveUtil.scopeCrawlUp(processor, builtInsFile, null, null); + } + + if (underscores >= 2) { + // if we're a normal module, add module's attrs + PsiFile f = realContext.getContainingFile(); + if (f instanceof PyFile) { + for (String name : PyModuleType.getPossibleInstanceMembers()) { + ret.add(LookupElementBuilder.create(name).withIcon(PlatformIconGroup.nodesField())); + } + } + } + + ret.addAll(getOriginalElements(processor)); + return ret.toArray(); + } + + /** + * Throws away fake elements used for completion internally. + */ + protected List getOriginalElements(@Nonnull CompletionVariantsProcessor processor) { + List ret = new ArrayList<>(); + for (LookupElement item : processor.getResultList()) { + PsiElement e = item.getPsiElement(); + if (e != null) { + PsiElement original = CompletionUtilCore.getOriginalElement(e); + if (original == null) { + continue; + } + } + ret.add(item); + } + return ret; } - } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; + @Override + @RequiredReadAction + public boolean isSoft() { + return false; } - if (o == null || getClass() != o.getClass()) { - return false; + + @Override + public HighlightSeverity getUnresolvedHighlightSeverity(TypeEvalContext context) { + if (isBuiltInConstant()) { + return null; + } + PyExpression qualifier = myElement.getQualifier(); + if (qualifier == null) { + return HighlightSeverity.ERROR; + } + if (context.getType(qualifier) != null) { + return HighlightSeverity.WARNING; + } + return null; } - PyReferenceImpl that = (PyReferenceImpl)o; + private boolean isBuiltInConstant() { + // TODO: generalize + String name = myElement.getReferencedName(); + return PyNames.NONE.equals(name) || "True".equals(name) || "False".equals(name); + } - if (!myElement.equals(that.myElement)) { - return false; + @Override + @Nullable + public String getUnresolvedDescription() { + return null; } - if (!myContext.equals(that.myContext)) { - return false; + + + // our very own caching resolver + + private static class CachingResolver implements ResolveCache.PolyVariantResolver { + public static CachingResolver INSTANCE = new CachingResolver(); + private ThreadLocal myNesting = new ThreadLocal<>() { + @Override + protected AtomicInteger initialValue() { + return new AtomicInteger(); + } + }; + + private static final int MAX_NESTING_LEVEL = 30; + + @Override + @Nonnull + public ResolveResult[] resolve(@Nonnull PyReferenceImpl ref, boolean incompleteCode) { + if (myNesting.get().getAndIncrement() >= MAX_NESTING_LEVEL) { + System.out.println("Stack overflow pending"); + } + try { + return ref.multiResolveInner(); + } + finally { + myNesting.get().getAndDecrement(); + } + } } - return true; - } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } - @Override - public int hashCode() { - return myElement.hashCode(); - } + PyReferenceImpl that = (PyReferenceImpl) o; - protected static Object[] getTypeCompletionVariants(PyExpression pyExpression, PyType type) { - ProcessingContext context = new ProcessingContext(); - context.put(PyType.CTX_NAMES, new HashSet<>()); - return type.getCompletionVariants(pyExpression.getName(), pyExpression, context); - } + return myElement.equals(that.myElement) + && myContext.equals(that.myContext); + } + + @Override + public int hashCode() { + return myElement.hashCode(); + } + + protected static Object[] getTypeCompletionVariants(PyExpression pyExpression, PyType type) { + ProcessingContext context = new ProcessingContext(); + context.put(PyType.CTX_NAMES, new HashSet<>()); + return type.getCompletionVariants(pyExpression.getName(), pyExpression, context); + } } diff --git a/python-impl/src/main/java/com/jetbrains/python/impl/psi/types/PyStructuralType.java b/python-impl/src/main/java/com/jetbrains/python/impl/psi/types/PyStructuralType.java index e5595f11..8dbdd514 100644 --- a/python-impl/src/main/java/com/jetbrains/python/impl/psi/types/PyStructuralType.java +++ b/python-impl/src/main/java/com/jetbrains/python/impl/psi/types/PyStructuralType.java @@ -20,14 +20,15 @@ import com.jetbrains.python.psi.resolve.PyResolveContext; import com.jetbrains.python.psi.resolve.RatedResolveResult; import com.jetbrains.python.psi.types.PyType; -import consulo.application.AllIcons; import consulo.language.editor.completion.lookup.LookupElementBuilder; import consulo.language.psi.PsiElement; import consulo.language.util.ProcessingContext; +import consulo.platform.base.icon.PlatformIconGroup; import consulo.util.lang.StringUtil; import jakarta.annotation.Nonnull; import jakarta.annotation.Nullable; + import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -37,60 +38,62 @@ * @author vlan */ public class PyStructuralType implements PyType { - @Nonnull - private final Set myAttributes; - private final boolean myInferredFromUsages; + @Nonnull + private final Set myAttributes; + private final boolean myInferredFromUsages; - public PyStructuralType(@Nonnull Set attributes, boolean inferredFromUsages) { - myAttributes = attributes; - myInferredFromUsages = inferredFromUsages; - } + public PyStructuralType(@Nonnull Set attributes, boolean inferredFromUsages) { + myAttributes = attributes; + myInferredFromUsages = inferredFromUsages; + } - @Nullable - @Override - public List resolveMember(@Nonnull String name, - @Nullable PyExpression location, - @Nonnull AccessDirection direction, - @Nonnull PyResolveContext resolveContext) { - return Collections.emptyList(); - } + @Nullable + @Override + public List resolveMember( + @Nonnull String name, + @Nullable PyExpression location, + @Nonnull AccessDirection direction, + @Nonnull PyResolveContext resolveContext + ) { + return Collections.emptyList(); + } - @Override - public Object[] getCompletionVariants(String completionPrefix, PsiElement location, ProcessingContext context) { - final List variants = new ArrayList<>(); - for (String attribute : myAttributes) { - if (!attribute.equals(completionPrefix)) { - variants.add(LookupElementBuilder.create(attribute).withIcon(AllIcons.Nodes.Field)); - } + @Override + public Object[] getCompletionVariants(String completionPrefix, PsiElement location, ProcessingContext context) { + List variants = new ArrayList<>(); + for (String attribute : myAttributes) { + if (!attribute.equals(completionPrefix)) { + variants.add(LookupElementBuilder.create(attribute).withIcon(PlatformIconGroup.nodesField())); + } + } + return variants.toArray(); } - return variants.toArray(); - } - @Nullable - @Override - public String getName() { - return "{" + StringUtil.join(myAttributes, ", ") + "}"; - } + @Nullable + @Override + public String getName() { + return "{" + StringUtil.join(myAttributes, ", ") + "}"; + } - @Override - public boolean isBuiltin() { - return false; - } + @Override + public boolean isBuiltin() { + return false; + } - @Override - public void assertValid(String message) { - } + @Override + public void assertValid(String message) { + } - @Override - public String toString() { - return "PyStructuralType(" + StringUtil.join(myAttributes, ", ") + ")"; - } + @Override + public String toString() { + return "PyStructuralType(" + StringUtil.join(myAttributes, ", ") + ")"; + } - public boolean isInferredFromUsages() { - return myInferredFromUsages; - } + public boolean isInferredFromUsages() { + return myInferredFromUsages; + } - public Set getAttributeNames() { - return myAttributes; - } + public Set getAttributeNames() { + return myAttributes; + } } diff --git a/python-impl/src/main/java/com/jetbrains/python/impl/refactoring/classes/ui/PyMemberSelectionTable.java b/python-impl/src/main/java/com/jetbrains/python/impl/refactoring/classes/ui/PyMemberSelectionTable.java index 06bd5e0d..f9fb3e90 100644 --- a/python-impl/src/main/java/com/jetbrains/python/impl/refactoring/classes/ui/PyMemberSelectionTable.java +++ b/python-impl/src/main/java/com/jetbrains/python/impl/refactoring/classes/ui/PyMemberSelectionTable.java @@ -18,11 +18,10 @@ import com.jetbrains.python.impl.refactoring.classes.membersManager.PyMemberInfo; import com.jetbrains.python.psi.PyElement; import com.jetbrains.python.psi.PyFunction; -import consulo.application.AllIcons; -import consulo.language.editor.refactoring.RefactoringBundle; import consulo.language.editor.refactoring.classMember.MemberInfoModel; +import consulo.language.editor.refactoring.localize.RefactoringLocalize; import consulo.language.editor.refactoring.ui.AbstractMemberSelectionTable; -import consulo.language.psi.PsiElement; +import consulo.platform.base.icon.PlatformIconGroup; import consulo.ui.image.Image; import jakarta.annotation.Nonnull; import jakarta.annotation.Nullable; @@ -32,40 +31,36 @@ /** * @author Dennis.Ushakov */ -public class PyMemberSelectionTable extends AbstractMemberSelectionTable> -{ - private static final String ABSTRACT_TITLE = RefactoringBundle.message("make.abstract"); - private final boolean mySupportAbstract; +public class PyMemberSelectionTable extends AbstractMemberSelectionTable> { + private final boolean mySupportAbstract; - public PyMemberSelectionTable(@Nonnull final List> memberInfos, @Nullable final MemberInfoModel> model, final boolean supportAbstract) - { - super(memberInfos, model, (supportAbstract ? ABSTRACT_TITLE : null)); - mySupportAbstract = supportAbstract; - } + public PyMemberSelectionTable( + @Nonnull List> memberInfos, + @Nullable MemberInfoModel> model, + boolean supportAbstract + ) { + super(memberInfos, model, (supportAbstract ? RefactoringLocalize.makeAbstract().get() : null)); + mySupportAbstract = supportAbstract; + } - @Nullable - @Override - protected Object getAbstractColumnValue(final PyMemberInfo memberInfo) - { - //TODO: Too many logic, move to presenters - return (mySupportAbstract && memberInfo.isChecked() && myMemberInfoModel.isAbstractEnabled(memberInfo)) ? memberInfo.isToAbstract() : null; - } + @Nullable + @Override + protected Object getAbstractColumnValue(PyMemberInfo memberInfo) { + //TODO: Too many logic, move to presenters + return (mySupportAbstract && memberInfo.isChecked() && myMemberInfoModel.isAbstractEnabled(memberInfo)) ? memberInfo.isToAbstract() : null; + } - @Override - protected boolean isAbstractColumnEditable(final int rowIndex) - { - return mySupportAbstract && myMemberInfoModel.isAbstractEnabled(myMemberInfos.get(rowIndex)); - } + @Override + protected boolean isAbstractColumnEditable(int rowIndex) { + return mySupportAbstract && myMemberInfoModel.isAbstractEnabled(myMemberInfos.get(rowIndex)); + } - @Override - protected Image getOverrideIcon(PyMemberInfo memberInfo) - { - final PsiElement member = memberInfo.getMember(); - Image overrideIcon = EMPTY_OVERRIDE_ICON; - if(member instanceof PyFunction && memberInfo.getOverrides() != null && memberInfo.getOverrides()) - { - overrideIcon = AllIcons.General.OverridingMethod; - } - return overrideIcon; - } + @Override + protected Image getOverrideIcon(PyMemberInfo memberInfo) { + Image overrideIcon = EMPTY_OVERRIDE_ICON; + if (memberInfo.getMember() instanceof PyFunction && memberInfo.getOverrides() != null && memberInfo.getOverrides()) { + overrideIcon = PlatformIconGroup.gutterOverridingmethod(); + } + return overrideIcon; + } } diff --git a/python-impl/src/main/java/com/jetbrains/python/impl/structureView/PyFieldsFilter.java b/python-impl/src/main/java/com/jetbrains/python/impl/structureView/PyFieldsFilter.java index beb2f4a4..c176d460 100644 --- a/python-impl/src/main/java/com/jetbrains/python/impl/structureView/PyFieldsFilter.java +++ b/python-impl/src/main/java/com/jetbrains/python/impl/structureView/PyFieldsFilter.java @@ -13,52 +13,47 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.jetbrains.python.impl.structureView; -import consulo.application.AllIcons; import consulo.fileEditor.structureView.tree.ActionPresentation; import consulo.fileEditor.structureView.tree.ActionPresentationData; import consulo.fileEditor.structureView.tree.Filter; import consulo.fileEditor.structureView.tree.TreeElement; -import consulo.ide.IdeBundle; +import consulo.ide.localize.IdeLocalize; +import consulo.platform.base.icon.PlatformIconGroup; import jakarta.annotation.Nonnull; /** * @author vlan */ public class PyFieldsFilter implements Filter { - private static final String ID = "SHOW_FIELDS"; + private static final String ID = "SHOW_FIELDS"; - @Override - public boolean isReverted() { - return true; - } + @Override + public boolean isReverted() { + return true; + } - @Override - public boolean isVisible(TreeElement treeNode) { - if (treeNode instanceof PyStructureViewElement) { - final PyStructureViewElement sve = (PyStructureViewElement)treeNode; - return !sve.isField(); + @Override + public boolean isVisible(TreeElement treeNode) { + return !(treeNode instanceof PyStructureViewElement sve && sve.isField()); } - return true; - } - @Nonnull - @Override - public String getName() { - return ID; - } + @Nonnull + @Override + public String getName() { + return ID; + } - @Override - public String toString() { - return getName(); - } + @Override + public String toString() { + return getName(); + } - @Nonnull - @Override - public ActionPresentation getPresentation() { - return new ActionPresentationData(IdeBundle.message("action.structureview.show.fields"), null, AllIcons.Nodes.Field); - } + @Nonnull + @Override + public ActionPresentation getPresentation() { + return new ActionPresentationData(IdeLocalize.actionStructureviewShowFields().get(), null, PlatformIconGroup.nodesField()); + } } diff --git a/python-impl/src/main/java/com/jetbrains/python/impl/structureView/PyInheritedMembersFilter.java b/python-impl/src/main/java/com/jetbrains/python/impl/structureView/PyInheritedMembersFilter.java index 5f956384..27aba855 100644 --- a/python-impl/src/main/java/com/jetbrains/python/impl/structureView/PyInheritedMembersFilter.java +++ b/python-impl/src/main/java/com/jetbrains/python/impl/structureView/PyInheritedMembersFilter.java @@ -13,66 +13,62 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.jetbrains.python.impl.structureView; -import consulo.application.AllIcons; import consulo.fileEditor.structureView.tree.ActionPresentation; import consulo.fileEditor.structureView.tree.ActionPresentationData; import consulo.fileEditor.structureView.tree.FileStructureFilter; import consulo.fileEditor.structureView.tree.TreeElement; -import consulo.ide.IdeBundle; +import consulo.ide.localize.IdeLocalize; +import consulo.platform.base.icon.PlatformIconGroup; import consulo.ui.ex.action.Shortcut; import consulo.ui.ex.keymap.KeymapManager; - import jakarta.annotation.Nonnull; /** * @author vlan */ public class PyInheritedMembersFilter implements FileStructureFilter { - private static final String ID = "SHOW_INHERITED"; + private static final String ID = "SHOW_INHERITED"; - @Override - public boolean isReverted() { - return true; - } + @Override + public boolean isReverted() { + return true; + } - @Override - public boolean isVisible(TreeElement treeNode) { - if (treeNode instanceof PyStructureViewElement) { - final PyStructureViewElement sve = (PyStructureViewElement)treeNode; - return !sve.isInherited(); + @Override + public boolean isVisible(TreeElement treeNode) { + return !(treeNode instanceof PyStructureViewElement sve && sve.isInherited()); } - return true; - } - @Nonnull - @Override - public String getName() { - return ID; - } + @Nonnull + @Override + public String getName() { + return ID; + } - @Override - public String toString() { - return getName(); - } + @Override + public String toString() { + return getName(); + } - @Nonnull - @Override - public ActionPresentation getPresentation() { - return new ActionPresentationData(IdeBundle.message("action.structureview.show.inherited"), - null, - AllIcons.Hierarchy.Supertypes); - } + @Nonnull + @Override + public ActionPresentation getPresentation() { + return new ActionPresentationData( + IdeLocalize.actionStructureviewShowInherited().get(), + null, + PlatformIconGroup.hierarchySupertypes() + ); + } - @Override - public String getCheckBoxText() { - return IdeBundle.message("file.structure.toggle.show.inherited"); - } + @Override + public String getCheckBoxText() { + return IdeLocalize.fileStructureToggleShowInherited().get(); + } - @Override - public Shortcut[] getShortcut() { - return KeymapManager.getInstance().getActiveKeymap().getShortcuts("FileStructurePopup"); - } + @Override + public Shortcut[] getShortcut() { + return KeymapManager.getInstance().getActiveKeymap().getShortcuts("FileStructurePopup"); + } } diff --git a/python-psi-api/src/main/java/com/jetbrains/python/codeInsight/PyCustomMember.java b/python-psi-api/src/main/java/com/jetbrains/python/codeInsight/PyCustomMember.java index 75a1d039..350633f0 100644 --- a/python-psi-api/src/main/java/com/jetbrains/python/codeInsight/PyCustomMember.java +++ b/python-psi-api/src/main/java/com/jetbrains/python/codeInsight/PyCustomMember.java @@ -22,7 +22,7 @@ import com.jetbrains.python.psi.types.PyClassType; import com.jetbrains.python.psi.types.PyType; import com.jetbrains.python.psi.types.TypeEvalContext; -import consulo.application.AllIcons; +import consulo.annotation.access.RequiredReadAction; import consulo.application.util.CachedValueProvider; import consulo.application.util.CachedValuesManager; import consulo.application.util.ParameterizedCachedValue; @@ -32,12 +32,14 @@ import consulo.language.psi.PsiElement; import consulo.language.psi.PsiModificationTracker; import consulo.language.psi.PsiReference; +import consulo.platform.base.icon.PlatformIconGroup; import consulo.ui.image.Image; import consulo.util.dataholder.Key; import consulo.util.dataholder.UserDataHolderBase; import jakarta.annotation.Nonnull; import jakarta.annotation.Nullable; + import java.util.function.Function; /** @@ -47,255 +49,254 @@ * @author Dennis.Ushakov */ public class PyCustomMember extends UserDataHolderBase { - private static final Key> RESOLVE = Key.create("resolve"); - private final String myName; - private final boolean myResolveToInstance; - private final Function myTypeCallback; - @Nullable - private final String myTypeName; - - private final PsiElement myTarget; - private PyPsiPath myPsiPath; - - boolean myFunction = false; - - /** - * Force resolving to {@link MyInstanceElement} even if element is function - */ - private boolean myAlwaysResolveToCustomElement; - private Image myIcon = AllIcons.Nodes.Method; - private PyCustomMemberTypeInfo myCustomTypeInfo; - - public PyCustomMember(@Nonnull final String name, @Nullable final String type, final boolean resolveToInstance) { - myName = name; - myResolveToInstance = resolveToInstance; - myTypeName = type; - - myTarget = null; - myTypeCallback = null; - } - - public PyCustomMember(@Nonnull final String name) { - myName = name; - myResolveToInstance = false; - myTypeName = null; - - myTarget = null; - myTypeCallback = null; - } - - public PyCustomMember(@Nonnull final String name, @Nullable final String type, final Function typeCallback) { - myName = name; - - myResolveToInstance = false; - myTypeName = type; - - myTarget = null; - myTypeCallback = typeCallback; - } - - public PyCustomMember(@Nonnull final String name, @Nullable final PsiElement target, @Nullable String typeName) { - myName = name; - myTarget = target; - myResolveToInstance = false; - myTypeName = typeName; - myTypeCallback = null; - } - - public PyCustomMember(@Nonnull final String name, @Nullable final PsiElement target) { - this(name, target, null); - } - - public PyCustomMember resolvesTo(String moduleQName) { - myPsiPath = new PyPsiPath.ToFile(moduleQName); - return this; - } - - public PyCustomMember resolvesToClass(String classQName) { - myPsiPath = new PyPsiPath.ToClassQName(classQName); - return this; - } - - /** - * Force resolving to {@link MyInstanceElement} even if element is function - */ - @Nonnull - public final PyCustomMember alwaysResolveToCustomElement() { - myAlwaysResolveToCustomElement = true; - return this; - } - - public PyCustomMember toClass(String name) { - myPsiPath = new PyPsiPath.ToClass(myPsiPath, name); - return this; - } - - public PyCustomMember toFunction(String name) { - myPsiPath = new PyPsiPath.ToFunction(myPsiPath, name); - return this; - } - - public PyCustomMember toFunctionRecursive(String name) { - myPsiPath = new PyPsiPath.ToFunctionRecursive(myPsiPath, name); - return this; - } - - public PyCustomMember toClassAttribute(String name) { - myPsiPath = new PyPsiPath.ToClassAttribute(myPsiPath, name); - return this; - } - - public PyCustomMember toCall(String name, String... args) { - myPsiPath = new PyPsiPath.ToCall(myPsiPath, name, args); - return this; - } - - public PyCustomMember toAssignment(String assignee) { - myPsiPath = new PyPsiPath.ToAssignment(myPsiPath, assignee); - return this; - } - - public PyCustomMember toPsiElement(final PsiElement psiElement) { - myPsiPath = new PyPsiPath() { - - @Override - public PsiElement resolve(PsiElement module) { - return psiElement; - } - }; - return this; - } - - public String getName() { - return myName; - } - - public Image getIcon() { - if (myTarget != null) { - return IconDescriptorUpdaters.getIcon(myTarget, 0); + private static final Key> RESOLVE = Key.create("resolve"); + private final String myName; + private final boolean myResolveToInstance; + private final Function myTypeCallback; + @Nullable + private final String myTypeName; + + private final PsiElement myTarget; + private PyPsiPath myPsiPath; + + boolean myFunction = false; + + /** + * Force resolving to {@link MyInstanceElement} even if element is function + */ + private boolean myAlwaysResolveToCustomElement; + private Image myIcon = PlatformIconGroup.nodesMethod(); + private PyCustomMemberTypeInfo myCustomTypeInfo; + + public PyCustomMember(@Nonnull String name, @Nullable String type, boolean resolveToInstance) { + myName = name; + myResolveToInstance = resolveToInstance; + myTypeName = type; + + myTarget = null; + myTypeCallback = null; } - return myIcon; - } - @Nullable - public PsiElement resolve(@Nonnull final PsiElement context) { + public PyCustomMember(@Nonnull String name) { + myName = name; + myResolveToInstance = false; + myTypeName = null; - if (myTarget != null) { - return myTarget; + myTarget = null; + myTypeCallback = null; } - PyClass targetClass = null; - if (myTypeName != null) { + public PyCustomMember(@Nonnull String name, @Nullable String type, Function typeCallback) { + myName = name; - final ParameterizedCachedValueProvider provider = new ParameterizedCachedValueProvider() { - @Nullable - @Override - public CachedValueProvider.Result compute(final PsiElement param) { - final PyClass result = PyPsiFacade.getInstance(param.getProject()).createClassByQName(myTypeName, param); - return CachedValueProvider.Result.create(result, PsiModificationTracker.MODIFICATION_COUNT); + myResolveToInstance = false; + myTypeName = type; + + myTarget = null; + myTypeCallback = typeCallback; + } + + public PyCustomMember(@Nonnull String name, @Nullable PsiElement target, @Nullable String typeName) { + myName = name; + myTarget = target; + myResolveToInstance = false; + myTypeName = typeName; + myTypeCallback = null; + } + + public PyCustomMember(@Nonnull String name, @Nullable PsiElement target) { + this(name, target, null); + } + + public PyCustomMember resolvesTo(String moduleQName) { + myPsiPath = new PyPsiPath.ToFile(moduleQName); + return this; + } + + public PyCustomMember resolvesToClass(String classQName) { + myPsiPath = new PyPsiPath.ToClassQName(classQName); + return this; + } + + /** + * Force resolving to {@link MyInstanceElement} even if element is function + */ + @Nonnull + public final PyCustomMember alwaysResolveToCustomElement() { + myAlwaysResolveToCustomElement = true; + return this; + } + + public PyCustomMember toClass(String name) { + myPsiPath = new PyPsiPath.ToClass(myPsiPath, name); + return this; + } + + public PyCustomMember toFunction(String name) { + myPsiPath = new PyPsiPath.ToFunction(myPsiPath, name); + return this; + } + + public PyCustomMember toFunctionRecursive(String name) { + myPsiPath = new PyPsiPath.ToFunctionRecursive(myPsiPath, name); + return this; + } + + public PyCustomMember toClassAttribute(String name) { + myPsiPath = new PyPsiPath.ToClassAttribute(myPsiPath, name); + return this; + } + + public PyCustomMember toCall(String name, String... args) { + myPsiPath = new PyPsiPath.ToCall(myPsiPath, name, args); + return this; + } + + public PyCustomMember toAssignment(String assignee) { + myPsiPath = new PyPsiPath.ToAssignment(myPsiPath, assignee); + return this; + } + + public PyCustomMember toPsiElement(final PsiElement psiElement) { + myPsiPath = new PyPsiPath() { + + @Override + public PsiElement resolve(PsiElement module) { + return psiElement; + } + }; + return this; + } + + public String getName() { + return myName; + } + + @RequiredReadAction + public Image getIcon() { + if (myTarget != null) { + return IconDescriptorUpdaters.getIcon(myTarget, 0); } - }; - targetClass = - CachedValuesManager.getManager(context.getProject()).getParameterizedCachedValue(this, RESOLVE, provider, false, context); + return myIcon; } - final PsiElement resolveTarget = findResolveTarget(context); - if (resolveTarget instanceof PyFunction && !myAlwaysResolveToCustomElement) { - return resolveTarget; + + @Nullable + public PsiElement resolve(@Nonnull PsiElement context) { + + if (myTarget != null) { + return myTarget; + } + + PyClass targetClass = null; + if (myTypeName != null) { + + ParameterizedCachedValueProvider provider = new ParameterizedCachedValueProvider<>() { + @Nullable + @Override + public CachedValueProvider.Result compute(PsiElement param) { + PyClass result = PyPsiFacade.getInstance(param.getProject()).createClassByQName(myTypeName, param); + return CachedValueProvider.Result.create(result, PsiModificationTracker.MODIFICATION_COUNT); + } + }; + targetClass = + CachedValuesManager.getManager(context.getProject()).getParameterizedCachedValue(this, RESOLVE, provider, false, context); + } + PsiElement resolveTarget = findResolveTarget(context); + if (resolveTarget instanceof PyFunction && !myAlwaysResolveToCustomElement) { + return resolveTarget; + } + if (resolveTarget != null || targetClass != null) { + return new MyInstanceElement(targetClass, context, resolveTarget); + } + return null; } - if (resolveTarget != null || targetClass != null) { - return new MyInstanceElement(targetClass, context, resolveTarget); + + @Nullable + private PsiElement findResolveTarget(@Nonnull PsiElement context) { + if (myPsiPath != null) { + return myPsiPath.resolve(context); + } + return null; } - return null; - } - @Nullable - private PsiElement findResolveTarget(@Nonnull PsiElement context) { - if (myPsiPath != null) { - return myPsiPath.resolve(context); + @Nullable + public String getShortType() { + if (myTypeName == null) { + return null; + } + int pos = myTypeName.lastIndexOf('.'); + return myTypeName.substring(pos + 1); } - return null; - } - @Nullable - public String getShortType() { - if (myTypeName == null) { - return null; + public PyCustomMember asFunction() { + myFunction = true; + return this; } - int pos = myTypeName.lastIndexOf('.'); - return myTypeName.substring(pos + 1); - } - - public PyCustomMember asFunction() { - myFunction = true; - return this; - } - - public boolean isFunction() { - return myFunction; - } - - /** - * Checks if some reference points to this element - * - * @param reference reference to check - * @return true if reference points to it - */ - public final boolean isReferenceToMe(@Nonnull final PsiReference reference) { - final PsiElement element = reference.resolve(); - if (!(element instanceof MyInstanceElement)) { - return false; + + public boolean isFunction() { + return myFunction; } - return ((MyInstanceElement)element).getThis().equals(this); - } - - /** - * @param icon icon to use (will be used method icon otherwise) - */ - public PyCustomMember withIcon(@Nonnull final Image icon) { - myIcon = icon; - return this; - } - - /** - * Adds custom info to type if class has {@link #myTypeName} set. - * Info could be later obtained by key. - * - * @param customInfo custom info to add - */ - public PyCustomMember withCustomTypeInfo(@Nonnull final PyCustomMemberTypeInfo customInfo) { - if (myTypeName != null) { - throw new IllegalArgumentException("Cant add custom type info if no type provided"); + + /** + * Checks if some reference points to this element + * + * @param reference reference to check + * @return true if reference points to it + */ + @RequiredReadAction + public final boolean isReferenceToMe(@Nonnull PsiReference reference) { + return reference.resolve() instanceof MyInstanceElement instanceElem && instanceElem.getThis().equals(this); } - myCustomTypeInfo = customInfo; - return this; - } - - private class MyInstanceElement extends ASTWrapperPsiElement implements PyTypedElement { - private final PyClass myClass; - private final PsiElement myContext; - - public MyInstanceElement(PyClass clazz, PsiElement context, PsiElement resolveTarget) { - super(resolveTarget != null ? resolveTarget.getNode() : clazz.getNode()); - myClass = clazz; - myContext = context; + + /** + * @param icon icon to use (will be used method icon otherwise) + */ + public PyCustomMember withIcon(@Nonnull Image icon) { + myIcon = icon; + return this; } - private PyCustomMember getThis() { - return PyCustomMember.this; + /** + * Adds custom info to type if class has {@link #myTypeName} set. + * Info could be later obtained by key. + * + * @param customInfo custom info to add + */ + public PyCustomMember withCustomTypeInfo(@Nonnull PyCustomMemberTypeInfo customInfo) { + if (myTypeName != null) { + throw new IllegalArgumentException("Cant add custom type info if no type provided"); + } + myCustomTypeInfo = customInfo; + return this; } - public PyType getType(@Nonnull TypeEvalContext context, @Nonnull TypeEvalContext.Key key) { - if (myTypeCallback != null) { - return myTypeCallback.apply(myContext); - } - else if (myClass != null) { - final PyClassType type = PyPsiFacade.getInstance(getProject()).createClassType(myClass, !myResolveToInstance); - if (myCustomTypeInfo != null) { - myCustomTypeInfo.fill(type); + private class MyInstanceElement extends ASTWrapperPsiElement implements PyTypedElement { + private final PyClass myClass; + private final PsiElement myContext; + + public MyInstanceElement(PyClass clazz, PsiElement context, PsiElement resolveTarget) { + super(resolveTarget != null ? resolveTarget.getNode() : clazz.getNode()); + myClass = clazz; + myContext = context; + } + + private PyCustomMember getThis() { + return PyCustomMember.this; + } + + @Override + public PyType getType(@Nonnull TypeEvalContext context, @Nonnull TypeEvalContext.Key key) { + if (myTypeCallback != null) { + return myTypeCallback.apply(myContext); + } + else if (myClass != null) { + PyClassType type = PyPsiFacade.getInstance(getProject()).createClassType(myClass, !myResolveToInstance); + if (myCustomTypeInfo != null) { + myCustomTypeInfo.fill(type); + } + return type; + } + return null; } - return type; - } - return null; } - } } \ No newline at end of file diff --git a/python-psi-api/src/main/java/com/jetbrains/python/psi/PyFile.java b/python-psi-api/src/main/java/com/jetbrains/python/psi/PyFile.java index 38b7ead9..21101e24 100644 --- a/python-psi-api/src/main/java/com/jetbrains/python/psi/PyFile.java +++ b/python-psi-api/src/main/java/com/jetbrains/python/psi/PyFile.java @@ -20,106 +20,112 @@ import consulo.annotation.access.RequiredReadAction; import consulo.language.psi.PsiElement; import consulo.language.psi.PsiFile; -import consulo.language.version.LanguageVersion; import consulo.python.language.PythonLanguageVersion; import jakarta.annotation.Nonnull; import jakarta.annotation.Nullable; + import java.util.List; public interface PyFile extends PyElement, PsiFile, PyDocStringOwner, ScopeOwner { - List getStatements(); + @RequiredReadAction + List getStatements(); + + @RequiredReadAction + List getTopLevelClasses(); - List getTopLevelClasses(); + @Nonnull + @RequiredReadAction + List getTopLevelFunctions(); - @Nonnull - List getTopLevelFunctions(); + @RequiredReadAction + List getTopLevelAttributes(); - List getTopLevelAttributes(); + @Nullable + @RequiredReadAction + PyFunction findTopLevelFunction(String name); - @Nullable - PyFunction findTopLevelFunction(String name); + @Nullable + @RequiredReadAction + PyClass findTopLevelClass(String name); - @Nullable - PyClass findTopLevelClass(String name); + @Nullable + @RequiredReadAction + PyTargetExpression findTopLevelAttribute(String name); - @Nullable - PyTargetExpression findTopLevelAttribute(String name); + @Nonnull + @RequiredReadAction + default LanguageLevel getLanguageLevel() { + if (getLanguageVersion() instanceof PythonLanguageVersion pythonLanguageVersion) { + return pythonLanguageVersion.getLanguageLevel(); + } - @RequiredReadAction - @Nonnull - default LanguageLevel getLanguageLevel() { - LanguageVersion languageVersion = getLanguageVersion(); - if (languageVersion instanceof PythonLanguageVersion pythonLanguageVersion) { - return pythonLanguageVersion.getLanguageLevel(); + return LanguageLevel.getDefault(); } - return LanguageLevel.getDefault(); - } - - /** - * Return the list of all 'from ... import' statements in the top-level scope of the file. - * - * @return the list of 'from ... import' statements. - */ - @Nonnull - List getFromImports(); - - /** - * Return an exported PSI element defined in the file with the given name. - */ - @Nullable - PsiElement findExportedName(String name); - - /** - * Iterate over exported PSI elements defined in the file. - */ - @Nonnull - Iterable iterateNames(); - - /** - * Return the resolved exported elements. - */ - @Nonnull - List multiResolveName(@Nonnull String name); - - /** - * @deprecated Use {@link #multiResolveName(String)} instead. - */ - @Deprecated - @Nullable - PsiElement getElementNamed(String name); - - /** - * Returns the list of import elements in all 'import xxx' statements in the top-level scope of the file. - * - * @return the list of import targets. - */ - @Nonnull - List getImportTargets(); - - /** - * Returns the list of names in the __all__ declaration, or null if there is no such declaration in the module. - * - * @return the list of names or null. - */ - @Nullable - List getDunderAll(); - - /** - * Return true if the file contains a 'from __future__ import ...' statement with given feature. - */ - boolean hasImportFromFuture(FutureFeature feature); - - /** - * If the function raises a DeprecationWarning or a PendingDeprecationWarning, returns the explanation text provided for the warning. - * - * @return the deprecation message or null if the function is not deprecated. - */ - String getDeprecationMessage(); - - /** - * Returns the sequential list of import statements in the beginning of the file. - */ - List getImportBlock(); + /** + * Return the list of all 'from ... import' statements in the top-level scope of the file. + * + * @return the list of 'from ... import' statements. + */ + @Nonnull + List getFromImports(); + + /** + * Return an exported PSI element defined in the file with the given name. + */ + @Nullable + PsiElement findExportedName(String name); + + /** + * Iterate over exported PSI elements defined in the file. + */ + @Nonnull + Iterable iterateNames(); + + /** + * Return the resolved exported elements. + */ + @Nonnull + List multiResolveName(@Nonnull String name); + + /** + * @deprecated Use {@link #multiResolveName(String)} instead. + */ + @Deprecated + @Nullable + PsiElement getElementNamed(String name); + + /** + * Returns the list of import elements in all 'import xxx' statements in the top-level scope of the file. + * + * @return the list of import targets. + */ + @Nonnull + List getImportTargets(); + + /** + * Returns the list of names in the __all__ declaration, or null if there is no such declaration in the module. + * + * @return the list of names or null. + */ + @Nullable + List getDunderAll(); + + /** + * Return true if the file contains a 'from __future__ import ...' statement with given feature. + */ + boolean hasImportFromFuture(FutureFeature feature); + + /** + * If the function raises a DeprecationWarning or a PendingDeprecationWarning, returns the explanation text provided for the warning. + * + * @return the deprecation message or null if the function is not deprecated. + */ + String getDeprecationMessage(); + + /** + * Returns the sequential list of import statements in the beginning of the file. + */ + List getImportBlock(); } From 7e50ba94d2862bb20aecc1bb2fab75f95d3f5c49 Mon Sep 17 00:00:00 2001 From: UNV Date: Sat, 8 Nov 2025 17:45:10 +0300 Subject: [PATCH 2/2] Removed unnecessary PyBoundFunction.getNameIdentifier(). --- .../jetbrains/python/impl/psi/impl/PyBoundFunction.java | 8 -------- 1 file changed, 8 deletions(-) diff --git a/python-impl/src/main/java/com/jetbrains/python/impl/psi/impl/PyBoundFunction.java b/python-impl/src/main/java/com/jetbrains/python/impl/psi/impl/PyBoundFunction.java index dd4aae8c..c1954ea9 100644 --- a/python-impl/src/main/java/com/jetbrains/python/impl/psi/impl/PyBoundFunction.java +++ b/python-impl/src/main/java/com/jetbrains/python/impl/psi/impl/PyBoundFunction.java @@ -16,8 +16,6 @@ package com.jetbrains.python.impl.psi.impl; import com.jetbrains.python.psi.PyFunction; -import consulo.annotation.access.RequiredReadAction; -import consulo.language.psi.PsiElement; /** * @author yole @@ -26,10 +24,4 @@ public class PyBoundFunction extends PyFunctionImpl { public PyBoundFunction(PyFunction function) { super(function.getNode()); } - - @RequiredReadAction - @Override - public PsiElement getNameIdentifier() { - return super.getNameIdentifier(); - } }