Skip to content

Commit

Permalink
fix: warnings
Browse files Browse the repository at this point in the history
SIR version (1, 0)
  • Loading branch information
nau committed Jan 6, 2025
1 parent abb166b commit 849a720
Show file tree
Hide file tree
Showing 15 changed files with 31 additions and 152 deletions.
2 changes: 0 additions & 2 deletions jvm/src/test/scala/scalus/examples/PreimageExampleSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@ import scalus.ledger.api.v2.*
import scalus.ledger.api.v1.ToDataInstances.given
import scalus.ledger.api.v2.ToDataInstances.given
import scalus.prelude.List
import scalus.sir.PrettyPrinter
import scalus.sir.RemoveRecursivity.removeRecursivity
import scalus.sir.SIR
import scalus.sir.SIRType
import scalus.uplc.*
import scalus.uplc.Term.*
import scalus.uplc.TermDSL.given
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,11 @@ import dotty.tools.dotc.core.StdNames.nme
import dotty.tools.dotc.core.Symbols.*
import dotty.tools.dotc.core.Types.*
import dotty.tools.dotc.util.SrcPos
import scalus.SIRTypesHelper.{sirTypeInEnvWithErr, unsupportedType, TypingException}
import scalus.SIRTypesHelper.TypingException
import scalus.sir.*

import scala.collection.mutable
import scala.language.implicitConversions
import scala.util.control.NonFatal

enum SirBinding:
case Name(name: String, tp: SIRType)
Expand Down Expand Up @@ -395,8 +394,8 @@ class PatternMatchingCompiler(val compiler: SIRCompiler)(using Context) {
.map(_ => bindingName.fresh().show)
// TODO: extract rhs to a let binding before the match
// so we don't have to repeat it for each case
// also we have no way toknowtype-arameters, so use abstract type-vars (will use FreeUnificator))
val typeArgs = constr.typeParams.map(tp => SIRType.FreeUnificator)
// also we have no way to know type-arguments, so use abstract type-vars (will use FreeUnificator)
val typeArgs = constr.typeParams.map(_ => SIRType.FreeUnificator)
val constrDecl = retrieveConstrDecl(env, constr, srcPos)
expandedCases += constructCase(constrDecl, bindings, typeArgs, rhs)
matchedConstructors += constr // collect all matched constructors
Expand Down
60 changes: 22 additions & 38 deletions scalus-plugin/src/main/scala/scalus/SIRCompiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ import scala.collection.mutable.ListBuffer
import scala.language.implicitConversions
import scala.annotation.unused
import scala.util.control.NonFatal
import scalus.builtin.ByteString

case class FullName(name: String)
object FullName:
Expand Down Expand Up @@ -82,18 +81,21 @@ final class SIRCompiler(mode: scalus.Mode)(using ctx: Context) {
import tpd.*
import SIRCompiler.Env

val SirVersion = (0, 0)
val SirVersion = (1, 0)

private val converter = new SIRConverter
private val builtinsHelper = new BuiltinHelper
private val BigIntSymbol = requiredModule("scala.math.BigInt")
private val BigIntClassSymbol = requiredClass("scala.math.BigInt")
private val ByteStringClassSymbol = requiredClass("scalus.builtin.ByteString")
private val DataClassSymbol = requiredClass("scalus.builtin.Data")
private val PairSymbol = requiredClass("scalus.builtin.Pair")
private val ScalusBuiltinListClassSymbol = requiredClass("scalus.builtin.List")
private val PlatformSpecificClassSymbol = requiredClass("scalus.builtin.PlatformSpecific")
private val StringContextSymbol = requiredModule("scala.StringContext")
private val StringContextApplySymbol = StringContextSymbol.requiredMethod("apply")
private val Tuple2Symbol = requiredClass("scala.Tuple2")
private val NothingSymbol = requiredClass("scala.Nothing")
private val ByteStringModuleSymbol = converter.ByteStringSymbol
private val ByteStringModuleSymbol = requiredModule("scalus.builtin.ByteString")
private val ByteStringSymbolHex = ByteStringModuleSymbol.requiredMethod("hex")
private val ByteStringStringInterpolatorsMethodSymbol =
ByteStringModuleSymbol.requiredMethod("StringInterpolators")
Expand All @@ -117,7 +119,7 @@ final class SIRCompiler(mode: scalus.Mode)(using ctx: Context) {
extension (t: Tree) def isPair: Boolean = t.tpe.isPair

extension (t: Tree) def isLiteral = compileConstant.isDefinedAt(t)
extension (t: Tree) def isData = t.tpe <:< converter.DataClassSymbol.typeRef
extension (t: Tree) def isData = t.tpe <:< DataClassSymbol.typeRef

enum CompileDef:
case Compiling
Expand Down Expand Up @@ -209,24 +211,7 @@ final class SIRCompiler(mode: scalus.Mode)(using ctx: Context) {
val bitSize = fl.bitSize(module)
val enc = EncoderState(bitSize / 8 + 1)
fl.encode(module, enc)
try {
enc.filler()
} catch {
case ex: ArrayIndexOutOfBoundsException =>
println("Catched ArrayIndexOutOfBoundsException during encoding in filler()")
println(
s"module: ${module.defs(0).name}, number of defs: ${module.defs.size}, bitSize: ${bitSize}"
)

val hs0 = scalus.utils.HashConsed.State.empty
val hsc1 = scalus.utils.HashConsedEncoderState.withSize(1000)
for b <- module.defs do {
println(s"def: ${b}")
// val bindingBitSize = scalus.flat.FlatInstantces.BindingFlat.bitSizeHC(b,hs0)
}

throw ex
}
enc.filler()
output.write(enc.buffer)
output.close()
}
Expand Down Expand Up @@ -739,28 +724,27 @@ final class SIRCompiler(mode: scalus.Mode)(using ctx: Context) {
case Constants.UnitTag => scalus.uplc.Constant.Unit
case _ => error(LiteralTypeNotSupported(c, l.srcPos), scalus.uplc.Constant.Unit)
case t @ Apply(bigintApply, List(SkipInline(Literal(c))))
if bigintApply.symbol == converter.BigIntSymbol.requiredMethod(
if bigintApply.symbol == BigIntSymbol.requiredMethod(
"apply",
List(defn.StringClass.typeRef)
) && c.tag == Constants.StringTag =>
scalus.uplc.Constant.Integer(BigInt(c.stringValue))
case t @ Apply(bigintApply, List(SkipInline(Literal(c))))
if bigintApply.symbol == converter.BigIntSymbol.requiredMethod(
if bigintApply.symbol == BigIntSymbol.requiredMethod(
"apply",
List(defn.IntType)
) && c.tag == Constants.IntTag =>
scalus.uplc.Constant.Integer(BigInt(c.intValue))

case Apply(i, List(Literal(c)))
if i.symbol == converter.BigIntSymbol.requiredMethod("int2bigInt") =>
case Apply(i, List(Literal(c))) if i.symbol == BigIntSymbol.requiredMethod("int2bigInt") =>
scalus.uplc.Constant.Integer(BigInt(c.intValue))
case expr if expr.symbol == converter.ByteStringSymbol.requiredMethod("empty") =>
case expr if expr.symbol == ByteStringModuleSymbol.requiredMethod("empty") =>
scalus.uplc.Constant.ByteString(scalus.builtin.ByteString.empty)
case Apply(expr, List(Literal(c)))
if expr.symbol == converter.ByteStringSymbol.requiredMethod("fromHex") =>
if expr.symbol == ByteStringModuleSymbol.requiredMethod("fromHex") =>
scalus.uplc.Constant.ByteString(scalus.builtin.ByteString.fromHex(c.stringValue))
case Apply(expr, List(Literal(c)))
if expr.symbol == converter.ByteStringSymbol.requiredMethod("fromString") =>
if expr.symbol == ByteStringModuleSymbol.requiredMethod("fromString") =>
scalus.uplc.Constant.ByteString(scalus.builtin.ByteString.fromString(c.stringValue))
// hex"deadbeef" as ByteString using Scala 3 StringContext extension
case Apply(
Expand Down Expand Up @@ -796,12 +780,12 @@ final class SIRCompiler(mode: scalus.Mode)(using ctx: Context) {
}

private def typeReprToDefaultUni(tpe: Type, list: Tree): DefaultUni =
if tpe =:= converter.BigIntClassSymbol.typeRef then DefaultUni.Integer
if tpe =:= BigIntClassSymbol.typeRef then DefaultUni.Integer
else if tpe =:= defn.StringClass.typeRef then DefaultUni.String
else if tpe =:= defn.BooleanClass.typeRef then DefaultUni.Bool
else if tpe =:= defn.UnitClass.typeRef then DefaultUni.Unit
else if tpe =:= converter.DataClassSymbol.typeRef then DefaultUni.Data
else if tpe =:= converter.ByteStringClassSymbol.typeRef then DefaultUni.ByteString
else if tpe =:= DataClassSymbol.typeRef then DefaultUni.Data
else if tpe =:= ByteStringClassSymbol.typeRef then DefaultUni.ByteString
else if tpe.isPair then
val List(t1, t2) = tpe.dealias.argInfos
DefaultUni.Pair(typeReprToDefaultUni(t1, list), typeReprToDefaultUni(t2, list))
Expand Down Expand Up @@ -1214,8 +1198,8 @@ final class SIRCompiler(mode: scalus.Mode)(using ctx: Context) {
)
// ByteString equality
case Apply(Select(lhs, op), List(rhs))
if lhs.tpe.widen =:= converter.ByteStringClassSymbol.typeRef && (op == nme.EQ || op == nme.NE) =>
if !(rhs.tpe.widen =:= converter.ByteStringClassSymbol.typeRef) then
if lhs.tpe.widen =:= ByteStringClassSymbol.typeRef && (op == nme.EQ || op == nme.NE) =>
if !(rhs.tpe.widen =:= ByteStringClassSymbol.typeRef) then
report.error(
s"""Equality is only allowed between the same types but here we have ${lhs.tpe.widen.show} == ${rhs.tpe.widen.show}
|Make sure you compare values of the same type""".stripMargin
Expand Down Expand Up @@ -1252,7 +1236,7 @@ final class SIRCompiler(mode: scalus.Mode)(using ctx: Context) {
if op == nme.EQ then eq else SIR.Not(eq)
// Data equality
case Apply(Select(lhs, op), List(rhs))
if lhs.tpe.widen <:< converter.DataClassSymbol.typeRef && (op == nme.EQ || op == nme.NE) =>
if lhs.tpe.widen <:< DataClassSymbol.typeRef && (op == nme.EQ || op == nme.NE) =>
val lhsExpr = compileExpr(env, lhs)
val rhsExpr = compileExpr(env, rhs)
val eq =
Expand All @@ -1274,10 +1258,10 @@ final class SIRCompiler(mode: scalus.Mode)(using ctx: Context) {
builtinsHelper.builtinFun(bi.symbol).get
// BigInt stuff
case Apply(optree @ Select(lhs, op), List(rhs))
if lhs.tpe.widen =:= converter.BigIntClassSymbol.typeRef =>
if lhs.tpe.widen =:= BigIntClassSymbol.typeRef =>
compileBigIntOps(env, lhs, op, rhs, optree)
case Select(expr, op)
if expr.tpe.widen =:= converter.BigIntClassSymbol.typeRef && op == nme.UNARY_- =>
if expr.tpe.widen =:= BigIntClassSymbol.typeRef && op == nme.UNARY_- =>
SIR.Apply(
SIR.Apply(
SIRBuiltins.subtractInteger,
Expand Down
78 changes: 2 additions & 76 deletions scalus-plugin/src/main/scala/scalus/SIRConverter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,96 +2,22 @@ package scalus

import dotty.tools.dotc.*
import dotty.tools.dotc.ast.tpd
import dotty.tools.dotc.core.*
import dotty.tools.dotc.core.Constants.Constant
import dotty.tools.dotc.core.Contexts.*
import dotty.tools.dotc.core.Decorators.*
import dotty.tools.dotc.core.Flags.*
import dotty.tools.dotc.core.NameKinds.UniqueNameKind
import dotty.tools.dotc.core.StdNames.nme
import dotty.tools.dotc.core.Symbols.*
import dotty.tools.dotc.core.Types.MethodType
import dotty.tools.dotc.core.Types.Type
import dotty.tools.dotc.core.Types.TypeRef
import dotty.tools.dotc.core.*
import dotty.tools.dotc.util.Spans
import scalus.builtin.ByteString
import scalus.builtin.Data
import scalus.flat.FlatInstantces
import scalus.sir.Binding
import scalus.sir.ConstrDecl
import scalus.sir.DataDecl
import scalus.sir.Recursivity
import scalus.sir.SIR
import scalus.sir.SIR.Case
import scalus.sir.SIRType
import scalus.uplc.DefaultFun
import scalus.uplc.DefaultUni
import scalus.utils.HashConsed
import scalus.utils.HashConsedEncoderState
import scalus.utils.HashConsedDecoderState
import scalus.utils.HashConsedFlat

import scala.collection.immutable
import scala.language.implicitConversions
import scalus.builtin.BLS12_381_G1_Element
import scalus.builtin.BLS12_381_G2_Element

class SIRConverter(using Context) {
import tpd.*

val ErrorSymbol = requiredModule("scalus.sir.SIR.Error")
val ConstSymbol = requiredModule("scalus.sir.SIR.Const")
val ApplySymbol = requiredModule("scalus.sir.SIR.Apply")
val BigIntSymbol = requiredModule("scala.math.BigInt")
val BigIntClassSymbol = requiredClass("scala.math.BigInt")
val DataConstrSymbol = requiredModule("scalus.builtin.Data.Constr")
val DataMapSymbol = requiredModule("scalus.builtin.Data.Map")
val DataListSymbol = requiredModule("scalus.builtin.Data.List")
val DataISymbol = requiredModule("scalus.builtin.Data.I")
val DataBSymbol = requiredModule("scalus.builtin.Data.B")
val ConstantClassSymbol = requiredClass("scalus.uplc.Constant")
val ConstantIntegerSymbol = requiredModule("scalus.uplc.Constant.Integer")
val ConstantBoolSymbol = requiredModule("scalus.uplc.Constant.Bool")
val ConstantUnitSymbol = requiredModule("scalus.uplc.Constant.Unit")
val ConstantStringSymbol = requiredModule("scalus.uplc.Constant.String")
val ConstantByteStringSymbol = requiredModule("scalus.uplc.Constant.ByteString")
val ConstantDataSymbol = requiredModule("scalus.uplc.Constant.Data")
val ConstantListSymbol = requiredModule("scalus.uplc.Constant.List")
val ConstantPairSymbol = requiredModule("scalus.uplc.Constant.Pair")
val ConstantBLS12_381_G1_ElementSymbol = requiredModule(
"scalus.uplc.Constant.BLS12_381_G1_Element"
)
val ConstantBLS12_381_G2_ElementSymbol = requiredModule(
"scalus.uplc.Constant.BLS12_381_G2_Element"
)
val ByteStringSymbol = requiredModule("scalus.builtin.ByteString")
val ByteStringClassSymbol = requiredClass("scalus.builtin.ByteString")
val BLS12_381_G1_ElementSymbol = requiredModule("scalus.builtin.BLS12_381_G1_Element")
val BLS12_381_G2_ElementSymbol = requiredModule("scalus.builtin.BLS12_381_G2_Element")
val VarSymbol = requiredModule("scalus.sir.SIR.Var")
val ExternalVarSymbol = requiredModule("scalus.sir.SIR.ExternalVar")
val LetSymbol = requiredModule("scalus.sir.SIR.Let")
val LamAbsSymbol = requiredModule("scalus.sir.SIR.LamAbs")
val NamedDeBruijnSymbol = requiredModule("scalus.uplc.NamedDeBruijn")
val BuiltinSymbol = requiredModule("scalus.sir.SIR.Builtin")
val BindingSymbol = requiredModule("scalus.sir.Binding")
val BindingClassSymbol = requiredClass("scalus.sir.Binding")
val MatchSymbol = requiredModule("scalus.sir.SIR.Match")
val CaseSymbol = requiredModule("scalus.sir.Case")
val CaseClassSymbol = requiredClass("scalus.sir.Case")
val ConstrDeclSymbol = requiredModule("scalus.sir.ConstrDecl")
val ConstrDeclClassSymbol = requiredClass("scalus.sir.ConstrDecl")
val ConstrSymbol = requiredModule("scalus.sir.SIR.Constr")
val SIRClassSymbol = requiredClass("scalus.sir.SIR")
val DataClassSymbol = requiredClass("scalus.builtin.Data")
val DataDeclSymbol = requiredModule("scalus.sir.DataDecl")
val DeclSymbol = requiredModule("scalus.sir.SIR.Decl")
val IfThenElseSymbol = requiredModule("scalus.sir.SIR.IfThenElse")
val AndSymbol = requiredModule("scalus.sir.SIR.And")
val OrSymbol = requiredModule("scalus.sir.SIR.Or")
val NotSymbol = requiredModule("scalus.sir.SIR.Not")

def convertViaSerialization(sir: SIR, span: Spans.Span): Tree = {
private def convertViaSerialization(sir: SIR, span: Spans.Span): Tree = {
val bitSize =
scalus.flat.FlatInstantces.SIRHashConsedFlat.bitSizeHC(sir, HashConsed.State.empty)
val encodedState = HashConsedEncoderState.withSize(bitSize)
Expand Down
1 change: 0 additions & 1 deletion scalus-plugin/src/main/scala/scalus/SIRTypesHelper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package scalus

import dotty.tools.dotc.*
import dotty.tools.dotc.core.*
import dotty.tools.dotc.core.Names.*
import dotty.tools.dotc.core.StdNames.*
import dotty.tools.dotc.core.Contexts.Context
import dotty.tools.dotc.core.Types.*
Expand Down
9 changes: 0 additions & 9 deletions shared/src/main/scala/scalus/sir/FlatInstances.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import scalus.flat.EncoderState
import scalus.flat.Flat
import scalus.flat.given
import scalus.sir.{Binding, ConstrDecl, DataDecl, Module, Recursivity, SIR, SIRBuiltins, SIRType, SIRVarStorage, TypeBinding}
import scalus.sir.SIR.Case
import scalus.uplc.CommonFlatInstances.*
import scalus.uplc.CommonFlatInstances.given
import scalus.builtin.Data
Expand Down Expand Up @@ -268,8 +267,6 @@ object FlatInstantces:
val retval = a match
case _: SIRType.Primitive[?] =>
tagWidth
case SIRType.Data =>
tagWidth
case cc: SIRType.CaseClass =>
tagWidth + SIRTypeCaseClassFlat.bitSizeHC(cc, hashConsed)
case scc: SIRType.SumCaseClass =>
Expand All @@ -296,12 +293,6 @@ object FlatInstantces:
case err: SIRType.TypeError =>
tagWidth + summon[Flat[String]].bitSize(err.msg)
case SIRType.TypeNothing => tagWidth
case SIRType.BLS12_381_G1_Element =>
tagWidth
case SIRType.BLS12_381_G2_Element =>
tagWidth
case SIRType.BLS12_381_MlResult =>
tagWidth

if !mute then
println(s"SIRTypeHashConsedFlat.bisSizeHC end ${a.hashCode()} $a =${retval}")
Expand Down
5 changes: 1 addition & 4 deletions shared/src/main/scala/scalus/sir/PrettyPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -262,10 +262,7 @@ object PrettyPrinter:
text(decl.name) + inOptBrackets(
intercalate(text(",") + space, typeParams.map(pretty))
)
case SIRType.Data => text("Data")
case SIRType.BLS12_381_G1_Element => text("BLS12_381_G1_Element")
case SIRType.BLS12_381_G2_Element => text("BLS12_381_G2_Element")
case SIRType.BLS12_381_MlResult => text("BLS12_381_MlResult")
case t => text(t.show)

def pretty(p: Program): Doc =
val (major, minor, patch) = p.version
Expand Down
3 changes: 0 additions & 3 deletions shared/src/main/scala/scalus/sir/SIR.scala
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
package scalus.sir

import scalus.sir.SIRType.{checkAllProxiesFilled, FreeUnificator}
import scalus.uplc.Constant
import scalus.uplc.DefaultFun

import scala.util.hashing.MurmurHash3

case class Module(version: (Int, Int), defs: List[Binding])

case class Binding(name: String, value: SIR) {
Expand Down
8 changes: 0 additions & 8 deletions shared/src/main/scala/scalus/sir/SIRBuiltins.scala
Original file line number Diff line number Diff line change
@@ -1,14 +1,6 @@
package scalus.sir

import scalus.builtin.ByteString
import scalus.builtin.Data
import scalus.builtin.BLS12_381_G1_Element
import scalus.builtin.BLS12_381_G2_Element
import scalus.builtin.BLS12_381_MlResult
import scalus.uplc.DefaultFun
import scalus.uplc.TypeScheme.Type

import scala.util.control.NonFatal

object SIRBuiltins {

Expand Down
2 changes: 1 addition & 1 deletion shared/src/main/scala/scalus/sir/SIRToExpr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import scalus.flat.FlatInstantces.SIRTypeHashConsedFlat
import scalus.flat.FlatInstantces.SIRHashConsedFlat

import scala.quoted.*
import scalus.sir.SIRType.{checkAllProxiesFilled, TypeVar}
import scalus.sir.SIRType.checkAllProxiesFilled
import scalus.utils.*
import scalus.flat.*

Expand Down
1 change: 0 additions & 1 deletion shared/src/main/scala/scalus/sir/SIRType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package scalus.sir
import scalus.uplc.DefaultUni
import scalus.uplc.TypeScheme as UplcTypeScheme
import scalus.sir.SIRType.TypeVar
import scalus.uplc.DefaultUni

import java.util

Expand Down
1 change: 0 additions & 1 deletion shared/src/main/scala/scalus/uplc/FlatInstantces.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import scalus.flat
import scalus.flat.{DecoderState, EncoderState, Flat, Natural, given}
import scalus.uplc.CommonFlatInstances.*
import scalus.uplc.CommonFlatInstances.given
import scalus.utils.*

object FlatInstantces:
val termTagWidth = 4
Expand Down
2 changes: 1 addition & 1 deletion shared/src/main/scala/scalus/utils/HashConsedFlat.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package scalus.utils

import scalus.flat.{*, given}
import scalus.flat.*

class HashConsedEncoderState(val encode: EncoderState, val hashConsed: HashConsed.State) {

Expand Down
Loading

0 comments on commit 849a720

Please sign in to comment.