Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Jupyter compile-time DF type not recognized #401

Merged
merged 7 commits into from
Jun 21, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,96 @@ internal class Integration(

val version = options["v"]

/**
 * Applies [argument] to [codeWithConverter] and runs the resulting code on this kernel host.
 *
 * @return the name of the execution result when code was run and [codeWithConverter] has a
 * converter; `null` when the generated code is blank or there is no converter.
 */
private fun KotlinKernelHost.execute(codeWithConverter: CodeWithConverter, argument: String): VariableName? {
    val generated = codeWithConverter.with(argument)
    // Nothing was generated, so there is nothing to run.
    if (generated.isBlank()) return null

    // Non-blank code is always executed; the converter flag only decides what is returned.
    val executed = execute(generated)
    return when {
        codeWithConverter.hasConverter -> executed.name
        else -> null
    }
}

/**
 * Runs [codeWithConverter] against an expression that refers to [property] cast to [type].
 *
 * The expression has the form `(name as type)`; `!!` is appended to the property name when the
 * property's return type is marked nullable.
 *
 * @return whatever the `(CodeWithConverter, String)` overload of `execute` returns for that expression.
 */
private fun KotlinKernelHost.execute(
    codeWithConverter: CodeWithConverter,
    property: KProperty<*>,
    type: String,
): VariableName? {
    val nonNullAssertion = if (property.returnType.isMarkedNullable) "!!" else ""
    val castExpression = "(" + property.name + nonNullAssertion + " as " + type + ")"
    return execute(codeWithConverter, castExpression)
}

/**
 * Generates data schema code for the URL of [importDataSchema], runs it on this kernel host and
 * reports the outcome to the user through a `DISPLAY(...)` call.
 *
 * The generated schema is named after [property] with a `DataSchema` suffix.
 *
 * @return the schema name on success, or `null` when code generation reported an error.
 */
private fun KotlinKernelHost.updateImportDataSchemaVariable(
    importDataSchema: ImportDataSchema,
    property: KProperty<*>,
): VariableName? {
    // The message ends up inside a regular "..." literal of the executed snippet. Escape every
    // character that would terminate the literal, start a string template or an escape sequence,
    // or break the line; otherwise reporting itself could fail (or evaluate `$name` templates).
    fun display(message: String) {
        val literal = buildString {
            for (ch in message) {
                when (ch) {
                    '\\' -> append("\\\\")
                    '"' -> append("\\\"")
                    '$' -> append("\\\$")
                    '\n' -> append("\\n")
                    '\r' -> append("\\r")
                    else -> append(ch)
                }
            }
        }
        execute("DISPLAY(\"$literal\")")
    }

    val formats = supportedFormats.filterIsInstance<SupportedCodeGenerationFormat>()
    val name = property.name + "DataSchema"
    return when (
        val codeGenResult = CodeGenerator.urlCodeGenReader(importDataSchema.url, name, formats, true)
    ) {
        is CodeGenerationReadResult.Success -> {
            val readDfMethod = codeGenResult.getReadDfMethod(importDataSchema.url.toExternalForm())
            // Additional imports go first, followed by the generated schema code.
            val code = readDfMethod.additionalImports.joinToString("\n") +
                "\n" +
                codeGenResult.code

            execute(code)
            display("Data schema successfully imported as ${property.name}: $name")

            name
        }

        is CodeGenerationReadResult.Error -> {
            display("Failed to read data schema from ${importDataSchema.url}: ${codeGenResult.reason}")
            null
        }
    }
}

/**
 * Processes [df] for [property] with [codeGen] and executes the result, referring to the
 * variable through a cast to `AnyFrame`.
 */
private fun KotlinKernelHost.updateAnyFrameVariable(
    df: AnyFrame,
    property: KProperty<*>,
    codeGen: ReplCodeGenerator,
): VariableName? {
    val generated = codeGen.process(df, property)
    return execute(generated, property, "AnyFrame")
}

/**
 * Processes [row] for [property] with [codeGen] and executes the result, referring to the
 * variable through a cast to `AnyRow`.
 */
private fun KotlinKernelHost.updateAnyRowVariable(
    row: AnyRow,
    property: KProperty<*>,
    codeGen: ReplCodeGenerator,
): VariableName? {
    val generated = codeGen.process(row, property)
    return execute(generated, property, "AnyRow")
}

/**
 * Converts [col] to a dataframe, processes it for [property] with [codeGen] and executes the
 * result, referring to the variable through a cast to `ColumnGroup<*>`.
 */
private fun KotlinKernelHost.updateColumnGroupVariable(
    col: ColumnGroup<*>,
    property: KProperty<*>,
    codeGen: ReplCodeGenerator,
): VariableName? {
    val frame = col.asDataFrame()
    val generated = codeGen.process(frame, property)
    return execute(generated, property, "ColumnGroup<*>")
}

/**
 * Handles a column variable: only column groups are processed, any other column yields `null`.
 *
 * For a column group the generated converter is wrapped so that it receives the variable
 * expression followed by `.asColumnGroup()`; the variable itself is referred to through a cast
 * to `AnyCol`.
 */
private fun KotlinKernelHost.updateAnyColVariable(
    col: AnyCol,
    property: KProperty<*>,
    codeGen: ReplCodeGenerator,
): VariableName? {
    if (!col.isColumnGroup()) return null

    val processed = codeGen.process(col.asColumnGroup().asDataFrame(), property)
    // Keep the declarations, but feed the original converter an expression viewed as a column group.
    val groupAware = CodeWithConverter(processed.declarations) { expression ->
        processed.converter("$expression.asColumnGroup()")
    }
    return execute(codeWithConverter = groupAware, property = property, type = "AnyCol")
}

override fun Builder.onLoaded() {
if (version != null) {
dependencies(
Expand Down Expand Up @@ -152,65 +242,17 @@ internal class Integration(
import("org.jetbrains.kotlinx.dataframe.dataTypes.*")
import("org.jetbrains.kotlinx.dataframe.impl.codeGen.urlCodeGenReader")

// Applies `argument` to `codeWithConverter` and runs the resulting code; returns the name of the
// execution result only when code was run and a converter is present, otherwise null.
fun KotlinKernelHost.execute(codeWithConverter: CodeWithConverter, argument: String): VariableName? {
    val generated = codeWithConverter.with(argument)
    // Blank code means there is nothing to run.
    if (generated.isBlank()) return null

    val executed = execute(generated)
    return when {
        codeWithConverter.hasConverter -> executed.name
        else -> null
    }
}

// Runs `codeWithConverter` against the property's name, with `!!` appended when the property's
// return type is marked nullable.
fun KotlinKernelHost.execute(codeWithConverter: CodeWithConverter, property: KProperty<*>): VariableName? {
    val nonNullSuffix = if (property.returnType.isMarkedNullable) "!!" else ""
    return execute(codeWithConverter, "${property.name}$nonNullSuffix")
}

updateVariable<ImportDataSchema> { importDataSchema, property ->
val formats = supportedFormats.filterIsInstance<SupportedCodeGenerationFormat>()
val name = property.name + "DataSchema"
when (val codeGenResult = CodeGenerator.urlCodeGenReader(importDataSchema.url, name, formats, true)) {
is CodeGenerationReadResult.Success -> {
val readDfMethod = codeGenResult.getReadDfMethod(importDataSchema.url.toExternalForm())
val code = readDfMethod.additionalImports.joinToString("\n") +
"\n" +
codeGenResult.code

execute(code)
execute("""DISPLAY("Data schema successfully imported as ${property.name}: $name")""")

name
}

is CodeGenerationReadResult.Error -> {
execute("""DISPLAY("Failed to read data schema from ${importDataSchema.url}: ${codeGenResult.reason}")""")
null
}
// One handler for every variable: dispatch on the runtime type of the value. Branches are tried
// top to bottom, so the order matters for values that match more than one of these types.
// Values of any other type are left untouched (null).
updateVariable<Any> { value, prop ->
    when (value) {
        is AnyCol -> updateAnyColVariable(value, prop, codeGen)
        is ColumnGroup<*> -> updateColumnGroupVariable(value, prop, codeGen)
        is AnyRow -> updateAnyRowVariable(value, prop, codeGen)
        is AnyFrame -> updateAnyFrameVariable(value, prop, codeGen)
        is ImportDataSchema -> updateImportDataSchemaVariable(value, prop)
        else -> null
    }
}

// Dataframes: process the value as is and execute the generated code for the property.
updateVariable<AnyFrame> { frame, prop ->
    val generated = codeGen.process(frame, prop)
    execute(generated, prop)
}

// Rows: same flow as for dataframes.
updateVariable<AnyRow> { dataRow, prop ->
    val generated = codeGen.process(dataRow, prop)
    execute(generated, prop)
}

// Column groups: converted to a dataframe before processing.
updateVariable<ColumnGroup<*>> { group, prop ->
    val generated = codeGen.process(group.asDataFrame(), prop)
    execute(generated, prop)
}

// Plain columns: only column groups are processed; any other column yields null.
updateVariable<AnyCol> { column, prop ->
    when {
        column.isColumnGroup() -> {
            val processed = codeGen.process(column.asColumnGroup().asDataFrame(), prop)
            // Keep the declarations, but pass the converter an expression viewed as a column group.
            val groupAware = CodeWithConverter(processed.declarations) { expression ->
                processed.converter("$expression.asColumnGroup()")
            }
            execute(groupAware, prop)
        }

        else -> null
    }
}

fun KotlinKernelHost.addDataSchemas(classes: List<KClass<*>>) {
val code = classes.joinToString("\n") {
codeGen.process(it)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.jetbrains.kotlinx.dataframe.jupyter

import org.intellij.lang.annotations.Language
import org.jetbrains.kotlinx.jupyter.api.Code
import org.junit.Test

Expand All @@ -11,9 +12,20 @@ class CodeGenerationTests : DataFrameJupyterTest() {
}
}

// Scenario: `create()` is declared to return `Any?` although it builds a dataframe, and the
// snippet then accesses column `a` on the resulting variable.
// `checkCompilation()` is defined outside this view — presumably it asserts that the snippet
// compiles in the notebook; confirm against its definition.
@Test
fun `Type erased dataframe`() {
// @Language (org.intellij.lang.annotations) tags the snippet as Kotlin script ("kts").
@Language("kts")
val a = """
fun create(): Any? = dataFrameOf("a")(1)
val df = create()
df.a
""".checkCompilation()
}

@Test
fun `nullable dataframe`() {
"""
@Language("kts")
val a = """
fun create(): AnyFrame? = dataFrameOf("a")(1)
val df = create()
df.a
Expand All @@ -22,7 +34,8 @@ class CodeGenerationTests : DataFrameJupyterTest() {

@Test
fun `nullable columnGroup`() {
"""
@Language("kts")
val a = """
fun create(): AnyCol? = dataFrameOf("a")(1).asColumnGroup().asDataColumn()
val col = create()
col.a
Expand All @@ -31,7 +44,8 @@ class CodeGenerationTests : DataFrameJupyterTest() {

@Test
fun `nullable dataRow`() {
"""
@Language("kts")
val a = """
fun create(): AnyRow? = dataFrameOf("a")(1).single()
val row = create()
row.a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class JupyterCodegenTests : JupyterReplTestCase() {
""".trimIndent()
)
res1 shouldBe Unit

@Language("kts")
val res2 = execRaw("df") as AnyFrame

res2["value"].type shouldBe typeOf<List<Any?>>()
Expand Down Expand Up @@ -126,6 +128,7 @@ class JupyterCodegenTests : JupyterReplTestCase() {
)
res1 shouldBe Unit

@Language("kts")
val res2 = execRaw("df.`1`")
res2.shouldBeInstanceOf<ValueColumn<*>>()
}
Expand All @@ -141,6 +144,7 @@ class JupyterCodegenTests : JupyterReplTestCase() {
)
res1.shouldBeInstanceOf<MimeTypedResult>()

@Language("kts")
val res2 = exec(
"""listOf(df.`{a}`[0], df.`(b)`[0], df.`{c}`[0])"""
)
Expand All @@ -157,6 +161,8 @@ class JupyterCodegenTests : JupyterReplTestCase() {
""".trimIndent()
)
res1.shouldBeInstanceOf<MimeTypedResult>()

@Language("kts")
val res2 = exec(
"listOf(df.`\$id`[0])"
)
Expand All @@ -174,6 +180,8 @@ class JupyterCodegenTests : JupyterReplTestCase() {
)
res1.shouldBeInstanceOf<MimeTypedResult>()
println(res1.entries.joinToString())

@Language("kts")
val res2 = exec(
"listOf(df.`Day's`[0])"
)
Expand All @@ -193,6 +201,8 @@ class JupyterCodegenTests : JupyterReplTestCase() {
)
res1.shouldBeInstanceOf<MimeTypedResult>()
println(res1.entries.joinToString())

@Language("kts")
val res2 = exec(
"listOf(df.`Test `[0])"
)
Expand All @@ -212,6 +222,8 @@ class JupyterCodegenTests : JupyterReplTestCase() {
)
res1.shouldBeInstanceOf<MimeTypedResult>()
println(res1.entries.joinToString())

@Language("kts")
val res2 = exec(
"listOf(df.`Test `[0])"
)
Expand All @@ -220,6 +232,7 @@ class JupyterCodegenTests : JupyterReplTestCase() {

@Test
fun `generic interface`() {
@Language("kts")
val res1 = exec(
"""
@DataSchema
Expand All @@ -229,6 +242,8 @@ class JupyterCodegenTests : JupyterReplTestCase() {
""".trimIndent()
)
res1.shouldBeInstanceOf<Unit>()

@Language("kts")
val res2 = exec(
"""
val <T> ColumnsContainer<Generic<T>>.test1: DataColumn<T> get() = field
Expand All @@ -240,6 +255,7 @@ class JupyterCodegenTests : JupyterReplTestCase() {

@Test
fun `generic interface with upper bound`() {
@Language("kts")
val res1 = exec(
"""
@DataSchema
Expand All @@ -249,6 +265,8 @@ class JupyterCodegenTests : JupyterReplTestCase() {
""".trimIndent()
)
res1.shouldBeInstanceOf<Unit>()

@Language("kts")
val res2 = exec(
"""
val <T : String> ColumnsContainer<Generic<T>>.test1: DataColumn<T> get() = field
Expand All @@ -260,6 +278,7 @@ class JupyterCodegenTests : JupyterReplTestCase() {

@Test
fun `generic interface with variance and user type in type parameters`() {
@Language("kts")
val res1 = exec(
"""
interface UpperBound
Expand All @@ -271,6 +290,8 @@ class JupyterCodegenTests : JupyterReplTestCase() {
""".trimIndent()
)
res1.shouldBeInstanceOf<Unit>()

@Language("kts")
val res2 = exec(
"""
val <T : UpperBound> ColumnsContainer<Generic<T>>.test1: DataColumn<T> get() = field
Expand All @@ -282,7 +303,8 @@ class JupyterCodegenTests : JupyterReplTestCase() {

@Test
fun `generate a new marker when dataframe marker is not a data schema so that columns are accessible with extensions`() {
exec(
@Language("kts")
val a = exec(
"""
enum class State {
Idle, Productive, Maintenance
Expand All @@ -305,7 +327,8 @@ class JupyterCodegenTests : JupyterReplTestCase() {
""".trimIndent()
)
shouldNotThrowAny {
exec(
@Language("kts")
val b = exec(
"""
events.toolId
events.state
Expand Down
Loading