Skip to content

Commit

Permalink
Extending functionality of ReplaceCallCastPass to `ResolveCallExpre…
Browse files Browse the repository at this point in the history
…ssionAmbiguityPass` (#1680)
  • Loading branch information
oxisto authored Sep 11, 2024
1 parent a5936b0 commit e871e2b
Show file tree
Hide file tree
Showing 15 changed files with 147 additions and 77 deletions.
6 changes: 6 additions & 0 deletions cpg-all/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ publishing {
}
}

repositories {
maven {
setUrl("https://jitpack.io")
}
}

dependencies {
// this exposes all of our (published) modules as dependency
api(projects.cpgConsole)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ private constructor(
registerPass<TypeResolver>()
registerPass<ControlFlowSensitiveDFGPass>()
registerPass<FilenameMapper>()
registerPass<ReplaceCallCastPass>()
registerPass<ResolveCallExpressionAmbiguityPass>()
useDefaultPasses = true
return this
}
Expand Down
17 changes: 6 additions & 11 deletions cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/TypeManager.kt
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ package de.fraunhofer.aisec.cpg
import de.fraunhofer.aisec.cpg.frontends.CastNotPossible
import de.fraunhofer.aisec.cpg.frontends.CastResult
import de.fraunhofer.aisec.cpg.frontends.Language
import de.fraunhofer.aisec.cpg.graph.Name
import de.fraunhofer.aisec.cpg.graph.declarations.RecordDeclaration
import de.fraunhofer.aisec.cpg.graph.declarations.TemplateDeclaration
import de.fraunhofer.aisec.cpg.graph.scopes.Scope
Expand Down Expand Up @@ -140,8 +139,7 @@ class TypeManager {
val node = scope.astNode

// We need an additional check here, because of parsing or other errors, the AST node
// might
// not necessarily be a template declaration.
// might not necessarily be a template declaration.
if (node is TemplateDeclaration) {
val parameterizedType = getTypeParameter(node, name)
if (parameterizedType != null) {
Expand Down Expand Up @@ -224,12 +222,9 @@ class TypeManager {
return t
}

fun typeExists(name: String): Boolean {
return firstOrderTypes.any { type: Type -> type.root.name.toString() == name }
}

fun typeExists(name: Name): Type? {
return firstOrderTypes.firstOrNull { type: Type -> type.root.name == name }
/** Checks, whether a [Type] with the given [name] exists. */
fun typeExists(name: CharSequence): Boolean {
return firstOrderTypes.any { type: Type -> type.root.name == name }
}

fun resolvePossibleTypedef(alias: Type, scopeManager: ScopeManager): Type {
Expand All @@ -243,7 +238,7 @@ class TypeManager {
* is [Type.Origin.RESOLVED].
*/
fun lookupResolvedType(
fqn: String,
fqn: CharSequence,
generics: List<Type>? = null,
language: Language<*>? = null
): Type? {
Expand All @@ -254,7 +249,7 @@ class TypeManager {

return firstOrderTypes.firstOrNull {
(it.typeOrigin == Type.Origin.RESOLVED || it.typeOrigin == Type.Origin.GUESSED) &&
it.name.toString() == fqn &&
it.root.name == fqn &&
if (generics != null) {
(it as? ObjectType)?.generics == generics
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ abstract class Language<T : LanguageFrontend<*, *>> : Node() {
* [builtInTypes] map, it returns null. The [typeString] must precisely match the key in the
* map.
*/
fun getSimpleTypeOf(typeString: String) = builtInTypes[typeString]
fun getSimpleTypeOf(typeString: CharSequence) = builtInTypes[typeString.toString()]

/** Returns true if the [file] can be handled by the frontend of this language. */
fun handlesFile(file: File): Boolean {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,13 +193,32 @@ interface HasAnonymousIdentifier : LanguageTrait {
*/
interface HasGlobalVariables : LanguageTrait

/**
* A common super-class for all language traits that arise because they are an ambiguity of a
* function call, e.g., function-style casts. This means that we cannot differentiate between a
* [CallExpression] and other expressions during the frontend and we need to invoke the
* [ResolveCallExpressionAmbiguityPass] to resolve this.
*/
sealed interface HasCallExpressionAmbiguity : LanguageTrait

/**
* A language trait, that specifies that the language has so-called functional style casts, meaning
* that they look like regular call expressions. Since we can therefore not distinguish between a
* [CallExpression] and a [CastExpression], we need to employ an additional pass
* ([ReplaceCallCastPass]) after the initial language frontends are done.
* ([ResolveCallExpressionAmbiguityPass]) after the initial language frontends are done.
*/
interface HasFunctionStyleCasts : HasCallExpressionAmbiguity

/**
* A language trait, that specifies that the language has functional style (object) construction,
* meaning that constructor calls look like regular call expressions (usually meaning that the
* language has no dedicated `new` keyword).
*
* Since we can therefore not distinguish between a [CallExpression] and a [ConstructExpression] in
* the frontend, we need to employ an additional pass ([ResolveCallExpressionAmbiguityPass]) after
* the initial language frontends are done.
*/
interface HasFunctionalCasts : LanguageTrait
interface HasFunctionStyleConstruction : HasCallExpressionAmbiguity

/**
* A language trait that specifies that this language allowed overloading functions, meaning that
Expand Down
16 changes: 9 additions & 7 deletions cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/graph/Name.kt
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,13 @@ class Name(

override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is Name) return false
if (other is String) return this.fullName == other
if (other is Name)
return localName == other.localName &&
parent == other.parent &&
delimiter == other.delimiter

return localName == other.localName &&
parent == other.parent &&
delimiter == other.delimiter
return false
}

override fun get(index: Int) = fullName[index]
Expand Down Expand Up @@ -153,9 +155,9 @@ internal fun parseName(fqn: CharSequence, delimiter: String, vararg splitDelimit
}

/** Returns a new [Name] based on the [localName] and the current name as parent. */
fun Name?.fqn(localName: String) =
fun Name?.fqn(localName: String, delimiter: String = this?.delimiter ?: ".") =
if (this == null) {
Name(localName)
Name(localName, null, delimiter)
} else {
Name(localName, this, this.delimiter)
Name(localName, this, delimiter)
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,30 +27,34 @@ package de.fraunhofer.aisec.cpg.passes

import de.fraunhofer.aisec.cpg.TranslationContext
import de.fraunhofer.aisec.cpg.frontends.Handler
import de.fraunhofer.aisec.cpg.frontends.HasFunctionalCasts
import de.fraunhofer.aisec.cpg.frontends.HasCallExpressionAmbiguity
import de.fraunhofer.aisec.cpg.frontends.HasFunctionStyleCasts
import de.fraunhofer.aisec.cpg.frontends.HasFunctionStyleConstruction
import de.fraunhofer.aisec.cpg.frontends.Language
import de.fraunhofer.aisec.cpg.frontends.LanguageFrontend
import de.fraunhofer.aisec.cpg.graph.*
import de.fraunhofer.aisec.cpg.graph.declarations.RecordDeclaration
import de.fraunhofer.aisec.cpg.graph.declarations.TranslationUnitDeclaration
import de.fraunhofer.aisec.cpg.graph.statements.expressions.*
import de.fraunhofer.aisec.cpg.graph.types.ObjectType
import de.fraunhofer.aisec.cpg.graph.types.Type
import de.fraunhofer.aisec.cpg.helpers.SubgraphWalker
import de.fraunhofer.aisec.cpg.passes.configuration.DependsOn
import de.fraunhofer.aisec.cpg.passes.configuration.ExecuteBefore
import de.fraunhofer.aisec.cpg.passes.configuration.RequiresLanguageTrait

/**
* If a [Language] has the trait [HasFunctionalCasts], we cannot distinguish between a
* [CallExpression] and a [CastExpression] during the initial translation. This stems from the fact
* that we might not know all the types yet. We therefore need to handle them as regular call
* expression in a [LanguageFrontend] or [Handler] and then later replace them with a
* [CastExpression], if the [CallExpression.callee] refers to name of a [Type] rather than a
* function.
* If a [Language] has the trait [HasCallExpressionAmbiguity], we cannot distinguish between
* [CallExpression], [CastExpression] or [ConstructExpression] during the initial translation. This
* stems from the fact that we might not know all the types yet. We therefore need to handle them as
* regular call expression in a [LanguageFrontend] or [Handler] and then later replace them with a
* [CastExpression] or [ConstructExpression], if the [CallExpression.callee] refers to name of a
* [Type] / [RecordDeclaration] rather than a function.
*/
@ExecuteBefore(EvaluationOrderGraphPass::class)
@DependsOn(TypeResolver::class)
@RequiresLanguageTrait(HasFunctionalCasts::class)
class ReplaceCallCastPass(ctx: TranslationContext) : TranslationUnitPass(ctx) {
@RequiresLanguageTrait(HasCallExpressionAmbiguity::class)
class ResolveCallExpressionAmbiguityPass(ctx: TranslationContext) : TranslationUnitPass(ctx) {
private lateinit var walker: SubgraphWalker.ScopedWalker

override fun accept(tu: TranslationUnitDeclaration) {
Expand All @@ -71,13 +75,44 @@ class ReplaceCallCastPass(ctx: TranslationContext) : TranslationUnitPass(ctx) {
return
}

// We really need a parent, otherwise we cannot replace the node
if (parent == null) {
return
}

// Some local copies for easier smart casting
var callee = call.callee
val language = callee.language

// Check, if this is cast is really a construct expression (if the language supports
// functional-constructs)
if (language is HasFunctionStyleConstruction) {
// Make sure, we do not accidentally "construct" primitive types
if (language.builtInTypes.contains(callee.name.toString()) == true) {
return
}

val fqn =
if (callee.name.parent == null) {
scopeManager.currentNamespace.fqn(
callee.name.localName,
delimiter = callee.name.delimiter
)
} else {
callee.name
}

// Check for our type. We are only interested in object types
val type = typeManager.lookupResolvedType(fqn)
if (type is ObjectType) {
walker.replaceCallWithConstruct(type, parent, call)
}
}

// We need to check, whether the "callee" refers to a type and if yes, convert it into a
// cast expression. And this is only really necessary, if the function call has a single
// argument.
var callee = call.callee
if (parent != null && call.arguments.size == 1) {
val language = parent.language

if (language is HasFunctionStyleCasts && call.arguments.size == 1) {
var pointer = false
// If the argument is a UnaryOperator, unwrap them
if (callee is UnaryOperator && callee.operatorCode == "*") {
Expand All @@ -86,20 +121,25 @@ class ReplaceCallCastPass(ctx: TranslationContext) : TranslationUnitPass(ctx) {
}

// First, check if this is a built-in type
if (language?.builtInTypes?.contains(callee.name.toString()) == true) {
walker.replaceCallWithCast(callee.name.toString(), parent, call, false)
var builtInType = language.getSimpleTypeOf(callee.name)
if (builtInType != null) {
walker.replaceCallWithCast(builtInType, parent, call, false)
} else {
// If not, then this could still refer to an existing type. We need to make sure
// that we take the current namespace into account
val fqn =
if (callee.name.parent == null) {
scopeManager.currentNamespace.fqn(callee.name.localName)
scopeManager.currentNamespace.fqn(
callee.name.localName,
delimiter = callee.name.delimiter
)
} else {
callee.name
}

if (typeManager.typeExists(fqn.toString())) {
walker.replaceCallWithCast(fqn, parent, call, pointer)
val type = typeManager.lookupResolvedType(fqn)
if (type != null) {
walker.replaceCallWithCast(type, parent, call, pointer)
}
}
}
Expand All @@ -112,7 +152,7 @@ class ReplaceCallCastPass(ctx: TranslationContext) : TranslationUnitPass(ctx) {

context(ContextProvider)
fun SubgraphWalker.ScopedWalker.replaceCallWithCast(
typeName: CharSequence,
type: Type,
parent: Node,
call: CallExpression,
pointer: Boolean,
Expand All @@ -123,16 +163,34 @@ fun SubgraphWalker.ScopedWalker.replaceCallWithCast(
cast.location = call.location
cast.castType =
if (pointer) {
call.objectType(typeName).pointer()
type.pointer()
} else {
call.objectType(typeName)
type
}
cast.expression = call.arguments.single()
cast.name = cast.castType.name

replaceArgument(parent, call, cast)
}

context(ContextProvider)
fun SubgraphWalker.ScopedWalker.replaceCallWithConstruct(
type: ObjectType,
parent: Node,
call: CallExpression
) {
val construct = newConstructExpression()
construct.code = call.code
construct.language = call.language
construct.location = call.location
construct.callee = call.callee
(construct.callee as? Reference)?.resolutionHelper = construct
construct.arguments = call.arguments
construct.type = type

replaceArgument(parent, call, construct)
}

context(ContextProvider)
fun SubgraphWalker.ScopedWalker.replaceArgument(parent: Node?, old: Expression, new: Expression) {
if (parent !is ArgumentHolder) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -932,7 +932,7 @@ fun TranslationContext.tryRecordInference(
// At this point, we need to check whether we have any type reference to our parent
// name. If we have (e.g. it is used in a function parameter, variable, etc.), then we
// have a high chance that this is actually a parent record and not a namespace
var parentType = typeManager.typeExists(parentName)
var parentType = typeManager.lookupResolvedType(parentName)
holder =
if (parentType != null) {
tryRecordInference(parentType, locationHint = locationHint)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ open class CPPLanguage :
HasStructs,
HasClasses,
HasUnknownType,
HasFunctionalCasts,
HasFunctionStyleCasts,
HasFunctionOverloading,
HasOperatorOverloading {
override val fileExtensions = listOf("cpp", "cc", "cxx", "c++", "hpp", "hh")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ import de.fraunhofer.aisec.cpg.passes.configuration.ExecuteBefore
* type information.
*/
@ExecuteBefore(EvaluationOrderGraphPass::class)
@ExecuteBefore(ReplaceCallCastPass::class)
@ExecuteBefore(ResolveCallExpressionAmbiguityPass::class)
@DependsOn(TypeResolver::class)
class CXXExtraPass(ctx: TranslationContext) : ComponentPass(ctx) {

Expand Down Expand Up @@ -76,7 +76,7 @@ class CXXExtraPass(ctx: TranslationContext) : ComponentPass(ctx) {
* the graph.
*/
private fun removeBracketOperators(node: UnaryOperator, parent: Node?) {
if (node.operatorCode == "()" && !typeManager.typeExists(node.input.name.toString())) {
if (node.operatorCode == "()" && !typeManager.typeExists(node.input.name)) {
// It was really just parenthesis around an identifier, but we can only make this
// distinction now.
//
Expand All @@ -92,9 +92,10 @@ class CXXExtraPass(ctx: TranslationContext) : ComponentPass(ctx) {
* operator where some arguments are wrapped in parentheses. This function tries to resolve
* this.
*
* Note: This is done especially for the C++ frontend. [ReplaceCallCastPass.handleCall] handles
* the more general case (which also applies to C++), in which a cast and a call are
* indistinguishable and need to be resolved once all types are known.
* Note: This is done especially for the C++ frontend.
* [ResolveCallExpressionAmbiguityPass.handleCall] handles the more general case (which also
* applies to C++), in which a cast and a call are indistinguishable and need to be resolved
* once all types are known.
*/
private fun convertOperators(binOp: BinaryOperator, parent: Node?) {
val fakeUnaryOp = binOp.lhs
Expand All @@ -107,7 +108,7 @@ class CXXExtraPass(ctx: TranslationContext) : ComponentPass(ctx) {
language != null &&
fakeUnaryOp is UnaryOperator &&
fakeUnaryOp.operatorCode == "()" &&
typeManager.typeExists(fakeUnaryOp.input.name.toString())
typeManager.typeExists(fakeUnaryOp.input.name)
) {
// If the name (`long` in the example) is a type, then the unary operator (`(long)`)
// is really a cast and our binary operator is really a unary operator `&addr`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class GoLanguage :
HasStructs,
HasFirstClassFunctions,
HasAnonymousIdentifier,
HasFunctionalCasts {
HasFunctionStyleCasts {
override val fileExtensions = listOf("go")
override val namespaceDelimiter = "."
@Transient override val frontend = GoLanguageFrontend::class
Expand Down
Loading

0 comments on commit e871e2b

Please sign in to comment.