Skip to content

Commit

Permalink
Merge pull request #286 from koxudaxi/improve_handling_pydantic_version
Browse files Browse the repository at this point in the history
Improve handling pydantic version
  • Loading branch information
koxudaxi authored May 3, 2021
2 parents e6fe821 + 59fcf10 commit a2bb9cd
Show file tree
Hide file tree
Showing 8 changed files with 102 additions and 31 deletions.
3 changes: 3 additions & 0 deletions resources/META-INF/plugin.xml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
<h2>version 0.3.1</h2>
<p>Features</p>
<ul>
<li>Improve handling pydantic version [#286]</li>
<li>Support config parameters on class kwargs [#285]</li>
</ul>
<h2>version 0.3.0</h2>
Expand Down Expand Up @@ -327,6 +328,8 @@
id="pydanticTypedValidatorMethodHandler" order="before pyMethodNameTypedHandler"/>
<projectService
serviceImplementation="com.koxudaxi.pydantic.PydanticConfigService"/>
<projectService
serviceImplementation="com.koxudaxi.pydantic.PydanticVersionService"/>

<projectConfigurable groupId="tools" instance="com.koxudaxi.pydantic.PydanticConfigurable"/>
<postStartupActivity implementation="com.koxudaxi.pydantic.PydanticInitializer" order="last"/>
Expand Down
21 changes: 1 addition & 20 deletions src/com/koxudaxi/pydantic/Pydantic.kt
Original file line number Diff line number Diff line change
Expand Up @@ -300,25 +300,6 @@ fun getPsiElementByQualifiedName(
return qualifiedName.resolveToElement(QNameResolveContext(contextAnchor, pythonSdk, context))
}

fun getPydanticVersion(project: Project, context: TypeEvalContext): KotlinVersion? {
val version = getPsiElementByQualifiedName(VERSION_QUALIFIED_NAME, project, context) as? PyTargetExpression
?: return null
val versionString =
(version.findAssignedValue()?.lastChild?.firstChild?.nextSibling as? PyStringLiteralExpression)?.stringValue
?: (version.findAssignedValue() as? PyStringLiteralExpressionImpl)?.stringValue ?: return null
return pydanticVersionCache.getOrPut(versionString) {
val versionList = versionString.split(VERSION_SPLIT_PATTERN).map { it.toIntOrNull() ?: 0 }
val pydanticVersion = when {
versionList.size == 1 -> KotlinVersion(versionList[0], 0)
versionList.size == 2 -> KotlinVersion(versionList[0], versionList[1])
versionList.size >= 3 -> KotlinVersion(versionList[0], versionList[1], versionList[2])
else -> null
} ?: KotlinVersion(0, 0)
pydanticVersionCache[versionString] = pydanticVersion
pydanticVersion
}
}

fun isValidField(field: PyTargetExpression, context: TypeEvalContext): Boolean {
if (!isValidFieldName(field.name)) return false

Expand Down Expand Up @@ -369,7 +350,7 @@ fun validateConfig(pyClass: PyClass): List<PsiElement>? {

fun getConfig(pyClass: PyClass, context: TypeEvalContext, setDefault: Boolean, pydanticVersion: KotlinVersion? = null): HashMap<String, Any?> {
val config = hashMapOf<String, Any?>()
val version = pydanticVersion ?: getPydanticVersion(pyClass.project, context)
val version = pydanticVersion ?: PydanticVersionService.getVersion(pyClass.project, context)
pyClass.getAncestorClasses(context)
.reversed()
.filter { isPydanticModel(it, false) }
Expand Down
2 changes: 1 addition & 1 deletion src/com/koxudaxi/pydantic/PydanticCompletionContributor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class PydanticCompletionContributor : CompletionContributor() {
config: HashMap<String, Any?>,
excludes: HashSet<String>?,
isDataclass: Boolean) {
val pydanticVersion = getPydanticVersion(pyClass.project, typeEvalContext)
val pydanticVersion = PydanticVersionService.getVersion(pyClass.project, typeEvalContext)
getClassVariables(pyClass, typeEvalContext)
.filter { it.name != null }
.filterNot { isUntouchedClass(it.findAssignedValue(), config, typeEvalContext) }
Expand Down
2 changes: 1 addition & 1 deletion src/com/koxudaxi/pydantic/PydanticInspection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class PydanticInspection : PyInspection() {
}

private fun inspectConfig(pyClass: PyClass) {
val pydanticVersion = getPydanticVersion(pyClass.project, myTypeEvalContext)
val pydanticVersion = PydanticVersionService.getVersion(pyClass.project, myTypeEvalContext)
if (pydanticVersion?.isAtLeast(1, 8) != true) return
if (!isPydanticModel(pyClass, false, myTypeEvalContext)) return
validateConfig(pyClass)?.forEach {
Expand Down
26 changes: 20 additions & 6 deletions src/com/koxudaxi/pydantic/PydanticPackageManagerListener.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,37 @@ package com.koxudaxi.pydantic
import com.intellij.openapi.Disposable
import com.intellij.openapi.application.ApplicationManager
import com.intellij.openapi.application.runWriteAction
import com.intellij.openapi.project.ProjectManager
import com.intellij.openapi.projectRoots.Sdk
import com.intellij.openapi.util.Disposer
import com.jetbrains.python.packaging.PyPackageManager
import com.jetbrains.python.sdk.PythonSdkUtil
import com.jetbrains.python.statistics.sdks

class PydanticPackageManagerListener : PyPackageManager.Listener {
private fun clearVersion(sdk: Sdk) {
ProjectManager.getInstance().openProjects
.filter { it.sdks.contains(sdk) }
.forEach { PydanticVersionService.clear(it) }
}

override fun packagesRefreshed(sdk: Sdk) {
ApplicationManager.getApplication().invokeLater {
if (sdk is Disposable && Disposer.isDisposed(sdk)) {
return@invokeLater
}
PythonSdkUtil.findSkeletonsDir(sdk)?.let { skeletons ->
skeletons.findChild("pydantic")?.let { pydanticStub ->
runWriteAction {
try {
pydanticStub.delete(null)
} catch (e: java.io.IOException) {
val skeletons = PythonSdkUtil.findSkeletonsDir(sdk)
val pydanticStub = skeletons?.findChild("pydantic")
if (pydanticStub == null) {
clearVersion(sdk)
} else {
runWriteAction {
try {
pydanticStub.delete(this)
} catch (e: java.io.IOException) {
} finally {
pydanticStub.refresh(true, true) {
clearVersion(sdk)
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions src/com/koxudaxi/pydantic/PydanticTypeProvider.kt
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
private fun getRefTypeFromFieldName(name: String, context: TypeEvalContext, pyClass: PyClass): Ref<PyType>? {
val ellipsis = PyElementGenerator.getInstance(pyClass.project).createEllipsis()

val pydanticVersion = getPydanticVersion(pyClass.project, context)
val pydanticVersion = PydanticVersionService.getVersion(pyClass.project, context)
return getRefTypeFromFieldNameInPyClass(name, pyClass, context, ellipsis, pydanticVersion)
?: pyClass.getAncestorClasses(context)
.filter { isPydanticModel(it, false, context) }
Expand Down Expand Up @@ -208,7 +208,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
): PydanticDynamicModelClassType? {
val project = pyFunction.project
val typed = getInstance(project).currentInitTyped
val pydanticVersion = getPydanticVersion(pyFunction.project, context)
val pydanticVersion = PydanticVersionService.getVersion(pyFunction.project, context)
val collected = linkedMapOf<String, PydanticDynamicModel.Attribute>()
val newVersion = pydanticVersion == null || pydanticVersion.isAtLeast(1, 5)
val modelNameParameterName = if (newVersion) "__model_name" else "model_name"
Expand Down Expand Up @@ -341,7 +341,7 @@ class PydanticTypeProvider : PyTypeProviderBase() {
}
}

val pydanticVersion = getPydanticVersion(pyClass.project, context)
val pydanticVersion = PydanticVersionService.getVersion(pyClass.project, context)
val config = getConfig(pyClass, context, true)
for (currentType in StreamEx.of(clsType).append(pyClass.getAncestorTypes(context))) {
if (currentType !is PyClassType) continue
Expand Down
52 changes: 52 additions & 0 deletions src/com/koxudaxi/pydantic/PydanticVersionService.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package com.koxudaxi.pydantic

import com.intellij.openapi.components.ServiceManager
import com.intellij.openapi.project.Project
import com.jetbrains.python.psi.PyStringLiteralExpression
import com.jetbrains.python.psi.PyTargetExpression
import com.jetbrains.python.psi.impl.PyStringLiteralExpressionImpl
import com.jetbrains.python.psi.types.TypeEvalContext

class PydanticVersionService {
private var version: KotlinVersion? = null

private fun getVersion(project: Project, context: TypeEvalContext): KotlinVersion? {
val version = getPsiElementByQualifiedName(VERSION_QUALIFIED_NAME, project, context) as? PyTargetExpression
?: return null
val versionString =
(version.findAssignedValue()?.lastChild?.firstChild?.nextSibling as? PyStringLiteralExpression)?.stringValue
?: (version.findAssignedValue() as? PyStringLiteralExpressionImpl)?.stringValue ?: return null
return pydanticVersionCache.getOrPut(versionString) {
val versionList = versionString.split(VERSION_SPLIT_PATTERN).map { it.toIntOrNull() ?: 0 }
val pydanticVersion = when {
versionList.size == 1 -> KotlinVersion(versionList[0], 0)
versionList.size == 2 -> KotlinVersion(versionList[0], versionList[1])
versionList.size >= 3 -> KotlinVersion(versionList[0], versionList[1], versionList[2])
else -> null
} ?: KotlinVersion(0, 0)
pydanticVersionCache[versionString] = pydanticVersion
pydanticVersion
}
}

private fun getOrPutVersion(project: Project, context: TypeEvalContext): KotlinVersion? {
if (version != null) return version
return getVersion(project, context).apply { version = this }
}
private fun clear() {
version = null
}

companion object {
fun getVersion(project: Project, context: TypeEvalContext): KotlinVersion? {
return getInstance(project).getOrPutVersion(project, context)
}
fun clear(project: Project) {
return getInstance(project).clear()
}
private fun getInstance(project: Project): PydanticVersionService {
return ServiceManager.getService(project, PydanticVersionService::class.java)
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@ package com.koxudaxi.pydantic

import com.intellij.openapi.application.invokeLater
import com.intellij.openapi.application.runWriteAction
import com.intellij.openapi.components.ServiceManager
import com.intellij.openapi.progress.util.BackgroundTaskUtil
import com.intellij.openapi.vfs.VirtualFile
import com.jetbrains.python.packaging.PyPackageManager
import com.jetbrains.python.psi.types.TypeEvalContext
import com.jetbrains.python.sdk.PythonSdkUtil
import junit.framework.TestCase


open class PydanticPackageManagerListenerTest : PydanticTestCase() {
Expand All @@ -22,4 +25,22 @@ open class PydanticPackageManagerListenerTest : PydanticTestCase() {
assertFalse(pydanticStubDir!!.exists())
}
}

fun testClearVersion() {
val project = myFixture!!.project
val context = TypeEvalContext.userInitiated(project, null)
val sdk = PythonSdkUtil.findPythonSdk(myFixture!!.module)!!

val pydanticVersion = PydanticVersionService.getVersion(project, context)
assertEquals(KotlinVersion(1, 0,1), pydanticVersion)

BackgroundTaskUtil.syncPublisher(project, PyPackageManager.PACKAGE_MANAGER_TOPIC).packagesRefreshed(sdk)
invokeLater {
val privateVersionField = PydanticVersionService::class.java.getDeclaredField("version")
privateVersionField.trySetAccessible()
val pydanticVersionService = ServiceManager.getService(project, PydanticVersionService::class.java)
val actual = privateVersionField.get(pydanticVersionService)
assertNull(actual)
}
}
}

0 comments on commit a2bb9cd

Please sign in to comment.