Skip to content

Commit

Permalink
Use SpecialSig enum to make code more robust
Browse files Browse the repository at this point in the history
  • Loading branch information
pawelprazak committed Mar 12, 2024
1 parent ecebdd5 commit 350500a
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 87 deletions.
30 changes: 25 additions & 5 deletions core/src/main/scala/besom/internal/ProtobufUtil.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package besom.internal

import besom.internal.Constants.{SecretValueName, SpecialSecretSig, SpecialSigKey}
import besom.internal.Constants.*
import com.google.protobuf.struct.*
import com.google.protobuf.struct.Value.Kind
import com.google.protobuf.util.JsonFormat
Expand Down Expand Up @@ -50,15 +50,35 @@ object ProtobufUtil:
case Some(a) => a.asValue
case None => Null

given ToValue[SpecialSig] with
extension (s: SpecialSig) def asValue: Value = s.asString.asValue

given ToValue[SecretValue] with
extension (s: SecretValue)
def asValue: Value = Map(
SpecialSig.Key -> SpecialSig.SecretSig.asValue,
SecretValueName -> s.value
).asValue

extension (v: Value)
def asJsonString: Either[Throwable, String] = Try(printer.print(Value.toJavaProto(v))).toEither
def asJsonStringOrThrow: String = asJsonString.fold(t => throw Exception("Expected a JSON", t), identity)

def struct: Option[Struct] = v.kind.structValue
def asSecret: Value = SecretValue(v).asValue

def asSecret: Value = Map(
SpecialSigKey -> SpecialSecretSig.asValue,
SecretValueName -> v
).asValue
def withSpecialSignature[A](f: (Struct, SpecialSig) => A): Option[A] =
for
struct: Struct <- v.struct
sig: SpecialSig <- struct.specialSignature
yield f(struct, sig)

extension (s: Struct)
def specialSignatureString: Option[String] =
s.fields.get(SpecialSig.Key).flatMap(_.kind.stringValue)
def specialSignature: Option[SpecialSig] =
s.specialSignatureString.flatMap(SpecialSig.fromString)

end ProtobufUtil

case class SecretValue(value: Value)
175 changes: 95 additions & 80 deletions core/src/main/scala/besom/internal/codecs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,34 @@ import scala.util.*

//noinspection ScalaFileName
object Constants:
final val SpecialSigKey = "4dabf18193072939515e22adb298388d"
final val SpecialAssetSig = "c44067f5952c0a294b673a41bacd8c17"
final val SpecialArchiveSig = "0def7320c3a5731c473e5ecbe6d01bc7"
final val SpecialSecretSig = "1b47061264138c4ac30d75fd1eb44270"
final val SpecialResourceSig = "5cf8f73096256a8f31e491e813e4eb8e"
/** Well-known signatures used in gRPC protocol, see sdk/go/common/resource/properties.go. */
enum SpecialSig(value: String):
/** Signature used to identify assets in maps in gRPC protocol */
case AssetSig extends SpecialSig("c44067f5952c0a294b673a41bacd8c17")

/** Signature used to identify archives in maps in gRPC protocol */
case ArchiveSig extends SpecialSig("0def7320c3a5731c473e5ecbe6d01bc7")

/** Signature used to identify secrets maps in gRPC protocol */
case SecretSig extends SpecialSig("1b47061264138c4ac30d75fd1eb44270")

/** Signature used to identify resources in maps in gRPC protocol */
case ResourceSig extends SpecialSig("5cf8f73096256a8f31e491e813e4eb8e")

/** @return the signature raw value */
def asString: String = value

object SpecialSig:
/** Signature used to encode type identity inside of a map in gRPC protocol.
*
* This is required when flattening into ordinary maps, like we do when performing serialization, to ensure recoverability of type
* identities later on.
*/
final val Key = "4dabf18193072939515e22adb298388d"

def fromString(s: String): Option[SpecialSig] = SpecialSig.values.find(_.asString == s)

end SpecialSig

/** Well-known sentinels used in gRPC protocol, see sdk/go/common/resource/plugin/rpc.go */

Expand Down Expand Up @@ -278,38 +301,37 @@ object Decoder extends DecoderInstancesLowPrio1:
decodeAsPossibleSecret(value, label).flatMap { odv =>
odv
.traverseValidatedResult { innerValue =>
extractSpecialStructSignature(innerValue) match
innerValue.struct.flatMap(_.specialSignature) match
case None => error(s"$label: Expected a special struct signature", label).invalidResult
case Some(specialSig) =>
if specialSig != Constants.SpecialResourceSig then
error(s"$label: Expected a special resource signature, got: '$specialSig'", label).invalidResult
else
val structValue = innerValue.getStructValue
structValue.fields
.get(Constants.ResourceUrnName)
.map(_.getStringValue)
.toValidatedResultOrError(
error(s"$label: Expected a resource urn in resource struct, not found", label)
)
.flatMap(urnString => URN.from(urnString).toEither.toValidatedResult)
.flatMap { urn =>
NonEmptyString(urn.resourceName) match
case None =>
error(s"$label: Expected a non-empty resource name in resource urn", label).invalidResult
case Some(resourceName) =>
val opts =
CustomResourceOptions(urn = urn) // triggers GetResource instead of RegisterResource
Context()
.readOrRegisterResource[R, EmptyArgs](urn.resourceType, resourceName, EmptyArgs(), opts)
.getData
.either
.map {
case Right(outpudDataOfR) => outpudDataOfR.valid
case Left(err) => err.invalid
}
.asValidatedResult
end match
}
case Some(Constants.SpecialSig.ResourceSig) =>
val structValue = innerValue.getStructValue
structValue.fields
.get(Constants.ResourceUrnName)
.map(_.getStringValue)
.toValidatedResultOrError(
error(s"$label: Expected a resource urn in resource struct, not found", label)
)
.flatMap(urnString => URN.from(urnString).toEither.toValidatedResult)
.flatMap { urn =>
NonEmptyString(urn.resourceName) match
case None =>
error(s"$label: Expected a non-empty resource name in resource urn", label).invalidResult
case Some(resourceName) =>
val opts =
CustomResourceOptions(urn = urn) // triggers GetResource instead of RegisterResource
Context()
.readOrRegisterResource[R, EmptyArgs](urn.resourceType, resourceName, EmptyArgs(), opts)
.getData
.either
.map {
case Right(outpudDataOfR) => outpudDataOfR.valid
case Left(err) => err.invalid
}
.asValidatedResult
end match
}
case Some(sig) =>
error(s"$label: Expected a special resource signature, got: '$sig'", label).invalidResult
}
.map(_.flatten)
.lmap(exception =>
Expand All @@ -336,21 +358,20 @@ object Decoder extends DecoderInstancesLowPrio1:
decodeAsPossibleSecret(value, label).flatMap { odv =>
odv
.traverseValidatedResult { innerValue =>
extractSpecialStructSignature(innerValue) match
innerValue.struct.flatMap(_.specialSignature) match
case None => error(s"$label: Expected a special struct signature", label).invalidResult
case Some(specialSig) =>
if specialSig != Constants.SpecialResourceSig then
error(s"$label: Expected a special resource signature, got: '$specialSig'", label).invalidResult
else
val structValue = innerValue.getStructValue
structValue.fields
.get(Constants.ResourceUrnName)
.map(_.getStringValue)
.toValidatedResultOrError(
error(s"$label: Expected a resource urn in resource struct, not found", label)
)
.flatMap(urnString => URN.from(urnString).toEither.toValidatedResult)
.map(urn => OutputData(DependencyResource(Output(urn))))
case Some(Constants.SpecialSig.ResourceSig) =>
val structValue = innerValue.getStructValue
structValue.fields
.get(Constants.ResourceUrnName)
.map(_.getStringValue)
.toValidatedResultOrError(
error(s"$label: Expected a resource urn in resource struct, not found", label)
)
.flatMap(urnString => URN.from(urnString).toEither.toValidatedResult)
.map(urn => OutputData(DependencyResource(Output(urn))))
case Some(sig) =>
error(s"$label: Expected a special resource signature, got: '$sig'", label).invalidResult
}
.map(_.flatten)
.lmap(exception =>
Expand All @@ -365,14 +386,14 @@ object Decoder extends DecoderInstancesLowPrio1:
override def mapping(value: Value, label: Label): Validated[DecodingError, DependencyResource] = ???

def assetArchiveDecoder[A](
specialSig: String,
specialSig: Constants.SpecialSig,
handle: Context ?=> (Label, Struct) => ValidatedResult[DecodingError, OutputData[A]]
): Decoder[A] = new Decoder[A]:
override def decode(value: Value, label: Label)(using Context): ValidatedResult[DecodingError, OutputData[A]] =
decodeAsPossibleSecret(value, label).flatMap { odv =>
odv
.traverseValidatedResult { innerValue =>
extractSpecialStructSignature(innerValue) match
innerValue.struct.flatMap(_.specialSignature) match
case None => error(s"$label: Expected a special struct signature", label).invalidResult
case Some(extractedSpecialSig) =>
if extractedSpecialSig != specialSig then
Expand All @@ -391,7 +412,7 @@ object Decoder extends DecoderInstancesLowPrio1:
override def mapping(value: Value, label: Label): Validated[DecodingError, A] = ???

given fileAssetDecoder: Decoder[FileAsset] = assetArchiveDecoder[FileAsset](
Constants.SpecialAssetSig,
Constants.SpecialSig.AssetSig,
(label, structValue) =>
structValue.fields
.get(Constants.AssetOrArchivePathName)
Expand All @@ -401,7 +422,7 @@ object Decoder extends DecoderInstancesLowPrio1:
)

given remoteAssetDecoder: Decoder[RemoteAsset] = assetArchiveDecoder[RemoteAsset](
Constants.SpecialAssetSig,
Constants.SpecialSig.AssetSig,
(label, structValue) =>
structValue.fields
.get(Constants.AssetOrArchiveUriName)
Expand All @@ -411,7 +432,7 @@ object Decoder extends DecoderInstancesLowPrio1:
)

given stringAssetDecoder: Decoder[StringAsset] = assetArchiveDecoder(
Constants.SpecialAssetSig,
Constants.SpecialSig.AssetSig,
(label, structValue) =>
structValue.fields
.get(Constants.AssetTextName)
Expand All @@ -421,7 +442,7 @@ object Decoder extends DecoderInstancesLowPrio1:
)

given fileArchiveDecoder: Decoder[FileArchive] = assetArchiveDecoder[FileArchive](
Constants.SpecialArchiveSig,
Constants.SpecialSig.ArchiveSig,
(label, structValue) =>
structValue.fields
.get(Constants.AssetOrArchivePathName)
Expand All @@ -431,7 +452,7 @@ object Decoder extends DecoderInstancesLowPrio1:
)

given remoteArchiveDecoder: Decoder[RemoteArchive] = assetArchiveDecoder[RemoteArchive](
Constants.SpecialArchiveSig,
Constants.SpecialSig.ArchiveSig,
(label, structValue) =>
structValue.fields
.get(Constants.AssetOrArchiveUriName)
Expand All @@ -442,7 +463,7 @@ object Decoder extends DecoderInstancesLowPrio1:

// noinspection NoTailRecursionAnnotation
given assetArchiveDecoder: Decoder[AssetArchive] = assetArchiveDecoder[AssetArchive](
Constants.SpecialArchiveSig,
Constants.SpecialSig.ArchiveSig,
(label, structValue) =>
val nested = structValue.fields
.get(Constants.ArchiveAssetsName)
Expand Down Expand Up @@ -639,8 +660,8 @@ trait DecoderHelpers:
override def mapping(value: Value, label: Label): Validated[DecodingError, A] = ???

def decodeAsPossibleSecret(value: Value, label: Label)(using Context): ValidatedResult[DecodingError, OutputData[Value]] =
extractSpecialStructSignature(value) match
case Some(sig) if sig == SpecialSecretSig =>
value.struct.flatMap(_.specialSignature) match
case Some(SpecialSig.SecretSig) =>
val innerValue = value.getStructValue.fields
.get(SecretValueName)
.map(ValidatedResult.valid)
Expand All @@ -652,14 +673,6 @@ trait DecoderHelpers:
ValidatedResult.valid(OutputData.unknown(isSecret = false))
else ValidatedResult.valid(OutputData(value))

def extractSpecialStructSignature(value: Value): Option[String] =
Iterator(value)
.filter(_.kind.isStructValue)
.flatMap(_.getStructValue.fields)
.filter((k, _) => k == SpecialSigKey)
.flatMap((_, v) => v.kind.stringValue)
.nextOption // TODO: log error if the signature is not recognized

def accumulatedOutputDatasOrErrors[A](
acc: ValidatedResult[DecodingError, Vector[OutputData[A]]],
elementValidatedResult: ValidatedResult[DecodingError, OutputData[A]],
Expand Down Expand Up @@ -752,11 +765,12 @@ object Encoder:
if ctx.featureSupport.keepResources then
outputURNEnc.encode(a.urn).flatMap { (urnResources, urnValue) =>
val fixedIdValue =
if idValue.kind.isStringValue && idValue.getStringValue == UnknownStringValue then Value(Kind.StringValue(""))
if idValue.kind.isStringValue && idValue.getStringValue == UnknownStringValue
then Value(Kind.StringValue(""))
else idValue

val result = Map(
SpecialSigKey -> SpecialResourceSig.asValue,
SpecialSig.Key -> SpecialSig.ResourceSig.asValue,
ResourceUrnName -> urnValue,
ResourceIdName -> fixedIdValue
)
Expand All @@ -774,7 +788,7 @@ object Encoder:
outputURNEnc.encode(a.urn).flatMap { (urnResources, urnValue) =>
if ctx.featureSupport.keepResources then
val result = Map(
SpecialSigKey -> SpecialResourceSig.asValue,
SpecialSig.Key -> SpecialSig.ResourceSig.asValue,
ResourceUrnName -> urnValue
)

Expand All @@ -790,7 +804,7 @@ object Encoder:
outputURNEnc.encode(a.urn).flatMap { (urnResources, urnValue) =>
if ctx.featureSupport.keepResources then
val result = Map(
SpecialSigKey -> SpecialResourceSig.asValue,
SpecialSig.Key -> SpecialSig.ResourceSig.asValue,
ResourceUrnName -> urnValue
)

Expand Down Expand Up @@ -858,12 +872,12 @@ object Encoder:
}

private def assetWrapper(key: String, value: Value): Value = Map(
Constants.SpecialSigKey -> Constants.SpecialAssetSig.asValue,
SpecialSig.Key -> SpecialSig.AssetSig.asValue,
key -> value
).asValue

private def archiveWrapper(key: String, value: Value): Value = Map(
Constants.SpecialSigKey -> Constants.SpecialArchiveSig.asValue,
SpecialSig.Key -> SpecialSig.ArchiveSig.asValue,
key -> value
).asValue

Expand Down Expand Up @@ -1004,12 +1018,13 @@ object Encoder:
case Left(a) => innerA.encode(a)
case Right(b) => innerB.encode(b)

def isEmptySecretValue(value: Value): Boolean =
value.kind.isStructValue &&
value.getStructValue.fields.contains(SpecialSigKey) &&
value.getStructValue.fields(SpecialSigKey).getStringValue == SpecialSecretSig &&
value.getStructValue.fields.contains(SecretValueName) &&
value.getStructValue.fields(SecretValueName).kind.isNullValue
private [internal] def isEmptySecretValue(value: Value): Boolean =
value.withSpecialSignature {
case (struct, SpecialSig.SecretSig) =>
struct.fields.get(SecretValueName).exists(_.kind.isNullValue)
case (_, _) => false
}.getOrElse(false)
end isEmptySecretValue

end Encoder

Expand Down
4 changes: 2 additions & 2 deletions core/src/test/scala/besom/internal/DecoderTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,10 @@ class DecoderTest extends munit.FunSuite:

test("special struct signature can be extracted") {
val secretStructSample: Value = Map(
SpecialSigKey -> SpecialSecretSig.asValue
SpecialSig.Key -> SpecialSig.SecretSig.asValue
).asValue

assert(extractSpecialStructSignature(secretStructSample).get == SpecialSecretSig)
assert(secretStructSample.struct.flatMap(_.specialSignature).get == SpecialSig.SecretSig)
}

test("decode case class") {
Expand Down
21 changes: 21 additions & 0 deletions core/src/test/scala/besom/internal/EncoderTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -516,3 +516,24 @@ class RecurrentArgsTest extends munit.FunSuite with ValueAssertions:
assertEqualsValue(encoded, expected, encoded.toProtoString)
}
end RecurrentArgsTest

class InternalTest extends munit.FunSuite:
import ProtobufUtil.*

for isSecret <- List(true, false)
do {
test(s"isEmptySecretValue (isSecret: $isSecret)") {
val value = if isSecret then Null.asSecret else Null
assertEquals(isEmptySecretValue(value), isSecret)
}
}

test("SpecialSig from String") {
import Constants.SpecialSig
assertEquals(SpecialSig.fromString(SpecialSig.AssetSig.asString), Some(SpecialSig.AssetSig))
assertEquals(SpecialSig.fromString(SpecialSig.ArchiveSig.asString), Some(SpecialSig.ArchiveSig))
assertEquals(SpecialSig.fromString(SpecialSig.SecretSig.asString), Some(SpecialSig.SecretSig))
assertEquals(SpecialSig.fromString(SpecialSig.ResourceSig.asString), Some(SpecialSig.ResourceSig))
assertEquals(SpecialSig.fromString("wrong"), None)
}
end InternalTest

0 comments on commit 350500a

Please sign in to comment.