Skip to content

Commit

Permalink
Support Rust types by retrieving them from debug info (rust-lang#307)
Browse files Browse the repository at this point in the history
* Complete the prototype of Rust debug info parser

* Change the uses of TypeTree class to a more appropriate pattern

* Complete the Rust debug info parser for pointers and arrays

* Complete the support for structs

* Complete the debug info type parser for tuples and fix some bugs

* Add support for Vecs and Boxes

* Document the Rust debug info parsing code

* Wrap Rust type info parsing into an if-statement, so it won't be invoked when the Rust type option is switched off

* Add support for unions

* Add a regression test for rust f32 type

* Update rustf32.ll

* Reduce the rustf32.ll test case to the minimum

* Add build dir generated by Clion and .idea dir to .gitignore

* Add a regression test for rust f64 type

* Add regression tests for rust integer types

* Add a test case for the rust struct type

* Delete some unnecessary chars from f32 and i8's test cases

* Add test cases for the rust array type

* Add a test case for the rust Vec type

* Add a test case for the rust Box type

* Add test cases for rust ref types

* Add test cases for rust pointer types

* Fix a bug related to the union type

* Add a regression test for the rust union type

* Revert "Add build dir generated by Clion and .idea dir to .gitignore"

This reverts commit b08016cb93e8ccf5cde8034c03a5b7f2ba2a185b.

* Make the rust type parser's code compatible with LLVM version under 9

* Make the test cases compatible with LLVM version under 9

* Change some code format

Co-authored-by: William Moses <gh@wsmoses.com>
Co-authored-by: Manuel Drehwald <git@manuel.drehwald.info>
  • Loading branch information
3 people authored Dec 1, 2021
1 parent deba550 commit 779ac0f
Show file tree
Hide file tree
Showing 27 changed files with 1,655 additions and 2 deletions.
4 changes: 2 additions & 2 deletions enzyme/Enzyme/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

list(APPEND ENZYME_SRC SCEV/ScalarEvolutionExpander.cpp)
list(APPEND ENZYME_SRC TypeAnalysis/TypeTree.cpp TypeAnalysis/TypeAnalysis.cpp TypeAnalysis/TypeAnalysisPrinter.cpp)
list(APPEND ENZYME_SRC TypeAnalysis/TypeTree.cpp TypeAnalysis/TypeAnalysis.cpp TypeAnalysis/TypeAnalysisPrinter.cpp TypeAnalysis/RustDebugInfo.cpp)

if (${LLVM_VERSION_MAJOR} LESS 8)
add_llvm_loadable_module( LLVMEnzyme-${LLVM_VERSION_MAJOR}
Expand Down Expand Up @@ -68,7 +68,7 @@ endif()
if (${ENZYME_EXTERNAL_SHARED_LIB})
add_library( Enzyme-${LLVM_VERSION_MAJOR}
SHARED
${ENZYME_SRC}
${ENZYME_SRC}
)
target_link_libraries(Enzyme-${LLVM_VERSION_MAJOR} LLVM)
install(TARGETS Enzyme-${LLVM_VERSION_MAJOR}
Expand Down
197 changes: 197 additions & 0 deletions enzyme/Enzyme/TypeAnalysis/RustDebugInfo.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
//===- RustDebugInfo.cpp - Implementation of Rust Debug Info Parser --===//
//
// Enzyme Project
//
// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// If using this code in an academic setting, please cite the following:
// @incollection{enzymeNeurips,
// title = {Instead of Rewriting Foreign Code for Machine Learning,
// Automatically Synthesize Fast Gradients},
// author = {Moses, William S. and Churavy, Valentin},
// booktitle = {Advances in Neural Information Processing Systems 33},
// year = {2020},
// note = {To appear in},
// }
//
//===-------------------------------------------------------------------===//
//
// This file implements the Rust debug info parsing function. It gets the
// description of a type from the debug info of an instruction, passes it to
// a concrete function according to the kind of the description, and
// constructs the type tree recursively.
//
//===-------------------------------------------------------------------===//
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfo.h"
#include "llvm/Support/CommandLine.h"

#include "RustDebugInfo.h"

// Forward declaration of the generic dispatcher. The kind-specific overloads
// defined below call it to parse the types they refer to (array element types,
// struct/union members, pointees), so it must be visible before them.
TypeTree parseDIType(DIType &Type, Instruction &I, DataLayout &DL);

/// Translate a Rust primitive type into a type tree, keyed on the type name
/// recorded in its debug info.
///
/// `f64` and `f32` map to the corresponding LLVM floating-point type, the
/// signed/unsigned integer names (8 to 128 bits, plus `isize`/`usize`) map to
/// BaseType::Integer, and every other name is built from BaseType::Unknown.
/// The resulting tree is wrapped with `Only(0)`.
///
/// \p I is only used to obtain the LLVMContext; \p DL is not used here.
TypeTree parseDIType(DIBasicType &Type, Instruction &I, DataLayout &DL) {
  const std::string Name = Type.getName().str();

  // Floating-point types carry their concrete LLVM type.
  if (Name == "f64")
    return TypeTree(Type::getDoubleTy(I.getContext())).Only(0);
  if (Name == "f32")
    return TypeTree(Type::getFloatTy(I.getContext())).Only(0);

  // Every Rust integer type is treated the same way, whatever its width.
  static const char *const IntegerNames[] = {
      "i8",  "i16", "i32", "i64", "isize", "u8",
      "u16", "u32", "u64", "usize", "i128", "u128"};
  bool IsInteger = false;
  for (const char *Candidate : IntegerNames) {
    if (Name == Candidate) {
      IsInteger = true;
      break;
    }
  }

  return TypeTree(ConcreteType(IsInteger ? BaseType::Integer
                                         : BaseType::Unknown))
      .Only(0);
}

/// Build the type tree of a Rust array, struct, or union from its debug info.
///
/// - Arrays: the element type is parsed once and its tree is replicated at the
///   byte offset of every element. After each element the offset advances by
///   the element size, rounded up to the array's alignment.
/// - Structs: each member's tree is shifted to the member's byte offset and
///   merged (`|=`) into the result.
/// - Unions: member trees are shifted the same way but intersected (`&=`), so
///   only information shared by all members is kept.
///
/// Input that the parser does not support trips an assertion. In builds where
/// assertions are compiled out, an empty TypeTree (no type information) is
/// returned instead of running into undefined behavior.
TypeTree parseDIType(DICompositeType &Type, Instruction &I, DataLayout &DL) {
  TypeTree Result;
  const auto Tag = Type.getTag();

  if (Tag == dwarf::DW_TAG_array_type) {
#if LLVM_VERSION_MAJOR >= 9
    DIType *SubType = Type.getBaseType();
#else
    DIType *SubType = Type.getBaseType().resolve();
#endif
    // Without an element type there is nothing to deduce.
    if (!SubType)
      return TypeTree();

    TypeTree SubTT = parseDIType(*SubType, I, DL);
    // Debug info may omit the alignment (reported as 0); treat that as "no
    // padding between elements" rather than dividing by zero below.
    size_t Align = Type.getAlignInBytes();
    if (Align == 0)
      Align = 1;
    const size_t SubSize = SubType->getSizeInBits() / 8;
    const size_t Size = Type.getSizeInBits() / 8;
    DINodeArray Subranges = Type.getElements();
    size_t pos = 0;
    for (auto r : Subranges) {
      auto *Subrange = dyn_cast<DISubrange>(r);
      if (!Subrange) {
        assert(0 && "Array elements other than subranges are not supported "
                    "by Rust debug info parser");
        return TypeTree();
      }
      // PointerUnion::dyn_cast yields null when the count is not a constant;
      // PointerUnion::get would assert (or misinterpret the pointer when
      // assertions are off) in that case.
      auto *Count = Subrange->getCount().dyn_cast<ConstantInt *>();
      if (!Count) {
        assert(0 && "There shouldn't be non-constant-size arrays in Rust");
        return TypeTree();
      }
      const int64_t count = Count->getSExtValue();
      // A count of -1 stops the scan; presumably this is the encoding for an
      // array of unknown length -- TODO confirm.
      if (count == -1)
        break;
      for (int64_t i = 0; i < count; i++) {
        Result |= SubTT.ShiftIndices(DL, 0, Size, pos);
        // Advance to the next element, rounding up to the alignment.
        const size_t tmp = pos + SubSize;
        pos = (tmp % Align != 0) ? (tmp / Align + 1) * Align : tmp;
      }
    }
    return Result;
  }

  if (Tag == dwarf::DW_TAG_structure_type ||
      Tag == dwarf::DW_TAG_union_type) {
    const bool IsUnion = Tag == dwarf::DW_TAG_union_type;
    DINodeArray Elements = Type.getElements();
    const size_t Size = Type.getSizeInBits() / 8;
    bool firstSubTT = true;
    for (auto e : Elements) {
      auto *Member = dyn_cast<DIDerivedType>(e);
      if (!Member || Member->getTag() != dwarf::DW_TAG_member) {
        assert(0 && "Struct and union elements other than members are not "
                    "supported by Rust debug info parser");
        return TypeTree();
      }
      // Go through the generic DIType overload on purpose, so that its
      // zero-size check is applied to the member as well.
      DIType *SubType = Member;
      TypeTree SubTT = parseDIType(*SubType, I, DL);
      const size_t Offset = SubType->getOffsetInBits() / 8;
      SubTT = SubTT.ShiftIndices(DL, 0, Size, Offset);
      if (!IsUnion) {
        Result |= SubTT;
      } else if (firstSubTT) {
        // Seed the intersection with the first member.
        Result = SubTT;
      } else {
        Result &= SubTT;
      }
      firstSubTT = false;
    }
    return Result;
  }

  assert(0 && "Composite types other than arrays, structs and unions are not "
              "supported by Rust debug info parser");
  return TypeTree();
}

/// Build the type tree of a pointer or of a struct/union member.
///
/// - Members: the result is simply the tree of the member's base type.
/// - Pointers: the result starts as BaseType::Pointer and is merged with the
///   pointee's tree. When the pointee is a basic type its tree is first
///   passed through `ShiftIndices(DL, 0, 1, -1)`. The merged tree is wrapped
///   with `Only(0)`.
///
/// A missing base type contributes no information: a member yields an empty
/// tree and a pointer yields a bare pointer. Any other derived type trips an
/// assertion; with assertions compiled out an empty TypeTree is returned.
TypeTree parseDIType(DIDerivedType &Type, Instruction &I, DataLayout &DL) {
  const auto Tag = Type.getTag();
  if (Tag != dwarf::DW_TAG_pointer_type && Tag != dwarf::DW_TAG_member) {
    assert(0 && "Derived types other than pointers and members are not "
                "supported by Rust debug info parser");
    return TypeTree();
  }

#if LLVM_VERSION_MAJOR >= 9
  DIType *SubType = Type.getBaseType();
#else
  DIType *SubType = Type.getBaseType().resolve();
#endif

  if (Tag == dwarf::DW_TAG_member) {
    if (!SubType)
      return TypeTree();
    return parseDIType(*SubType, I, DL);
  }

  // Pointer type.
  TypeTree Result(BaseType::Pointer);
  if (SubType) {
    TypeTree SubTT = parseDIType(*SubType, I, DL);
    if (isa<DIBasicType>(SubType)) {
      Result |= SubTT.ShiftIndices(DL, 0, 1, -1);
    } else {
      Result |= SubTT;
    }
  }
  return Result.Only(0);
}

/// Generic entry point: dispatch \p Type to the overload matching its kind
/// (basic, composite, or derived).
///
/// Types whose size is zero yield an empty TypeTree without being inspected
/// further. Any other kind of type trips an assertion; with assertions
/// compiled out an empty TypeTree is returned so that control never falls off
/// the end of the function.
TypeTree parseDIType(DIType &Type, Instruction &I, DataLayout &DL) {
  if (Type.getSizeInBits() == 0) {
    return TypeTree();
  }

  if (auto *BT = dyn_cast<DIBasicType>(&Type))
    return parseDIType(*BT, I, DL);
  if (auto *CT = dyn_cast<DICompositeType>(&Type))
    return parseDIType(*CT, I, DL);
  if (auto *DT = dyn_cast<DIDerivedType>(&Type))
    return parseDIType(*DT, I, DL);

  assert(0 && "Types other than floating-points, integers, arrays, pointers, "
              "slices, and structs are not supported by debug info parser");
  return TypeTree();
}

/// Return true iff \p type is a pointer type whose pointee is the basic type
/// named `u8`.
///
/// Pointers without a base type, and nodes tagged as pointers that are not
/// DIDerivedType, are reported as "not a u8 pointer" instead of being
/// dereferenced.
bool isU8PointerType(DIType &type) {
  if (type.getTag() != dwarf::DW_TAG_pointer_type)
    return false;

  auto *PTy = dyn_cast<DIDerivedType>(&type);
  if (!PTy)
    return false;

#if LLVM_VERSION_MAJOR >= 9
  DIType *SubType = PTy->getBaseType();
#else
  DIType *SubType = PTy->getBaseType().resolve();
#endif
  // dyn_cast must not be handed a null pointer.
  if (!SubType)
    return false;

  auto *BTy = dyn_cast<DIBasicType>(SubType);
  return BTy && BTy->getName() == "u8";
}

/// Construct the type tree described by the debug info of the variable that
/// the `llvm.dbg.declare` call \p I refers to.
///
/// Returns an empty TypeTree when the variable has no recorded type or when
/// its type is a pointer to `u8`.
TypeTree parseDIType(DbgDeclareInst &I, DataLayout &DL) {
#if LLVM_VERSION_MAJOR >= 9
  DIType *type = I.getVariable()->getType();
#else
  DIType *type = I.getVariable()->getType().resolve();
#endif

  // No type recorded for the variable: nothing can be deduced.
  if (!type) {
    return TypeTree();
  }

  // If the type is *u8, do nothing, since the underlying type of data pointed
  // by a *u8 can be anything
  if (isU8PointerType(*type)) {
    return TypeTree();
  }
  return parseDIType(*type, I, DL);
}
40 changes: 40 additions & 0 deletions enzyme/Enzyme/TypeAnalysis/RustDebugInfo.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
//===- RustDebugInfo.h - Declaration of Rust Debug Info Parser -------===//
//
// Enzyme Project
//
// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// If using this code in an academic setting, please cite the following:
// @incollection{enzymeNeurips,
// title = {Instead of Rewriting Foreign Code for Machine Learning,
// Automatically Synthesize Fast Gradients},
// author = {Moses, William S. and Churavy, Valentin},
// booktitle = {Advances in Neural Information Processing Systems 33},
// year = {2020},
// note = {To appear in},
// }
//
//===-------------------------------------------------------------------===//
//
// This file contains the declaration of the Rust debug info parsing function
// which parses the debug info appended to LLVM IR generated by rustc and
// extracts useful type info from it. The type info will be used to initialize
// the following type analysis.
//
//===-------------------------------------------------------------------===//
#ifndef ENZYME_RUSTDEBUGINFO_H
#define ENZYME_RUSTDEBUGINFO_H 1

#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"

using namespace llvm;

#include "TypeTree.h"

/// Construct the type tree described by the debug info of the variable that
/// the `llvm.dbg.declare` call \p I refers to. \p DL is forwarded to the
/// recursive parsing of nested types.
TypeTree parseDIType(DbgDeclareInst &I, DataLayout &DL);

#endif // ENZYME_RUSTDEBUGINFO_H
26 changes: 26 additions & 0 deletions enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
#include "../FunctionUtils.h"
#include "../LibraryFuncs.h"

#include "RustDebugInfo.h"
#include "TBAA.h"

extern "C" {
Expand Down Expand Up @@ -1645,6 +1646,10 @@ void TypeAnalyzer::visitBitCastInst(BitCastInst &I) {
Type *et1 = cast<PointerType>(I.getType())->getElementType();
Type *et2 = cast<PointerType>(I.getOperand(0)->getType())->getElementType();

TypeTree Debug = getAnalysis(I.getOperand(0)).Data0();
DataLayout DL = fntypeinfo.Function->getParent()->getDataLayout();
TypeTree Debug1 = Debug.KeepForCast(DL, et2, et1);

if (direction & DOWN)
updateAnalysis(
&I,
Expand Down Expand Up @@ -4546,6 +4551,9 @@ TypeResults TypeAnalysis::analyzeFunction(const FnTypeInfo &fn) {
}

analysis.prepareArgs();
if (RustTypeRules) {
analysis.considerRustDebugInfo();
}
analysis.considerTBAA();
analysis.run();

Expand Down Expand Up @@ -4726,6 +4734,24 @@ ConcreteType TypeResults::firstPointer(size_t num, Value *val,
return dt;
}

/// Parse the debug info generated by rustc and retrieve useful type info if
/// possible.
///
/// Every `llvm.dbg.declare` in the analyzed function is inspected: the type
/// tree parsed from the declared variable's debug info is combined with
/// BaseType::Pointer, wrapped with `Only(-1)`, and recorded for the declared
/// address. Declarations whose address is no longer available, or whose debug
/// info yields no known type, are skipped.
void TypeAnalyzer::considerRustDebugInfo() {
  // parseDIType takes a mutable DataLayout reference, hence the local copy.
  DataLayout DL = fntypeinfo.Function->getParent()->getDataLayout();
  for (BasicBlock &BB : *fntypeinfo.Function) {
    for (Instruction &I : BB) {
      auto *DDI = dyn_cast<DbgDeclareInst>(&I);
      if (!DDI)
        continue;

      // getAddress() returns null once the described location was dropped;
      // there is no value to attach type information to in that case.
      Value *Address = DDI->getAddress();
      if (!Address)
        continue;

      TypeTree TT = parseDIType(*DDI, DL);
      if (!TT.isKnown())
        continue;

      TT |= TypeTree(BaseType::Pointer);
      updateAnalysis(Address, TT.Only(-1), DDI);
    }
  }
}

/// Return the Function stored in the FnTypeInfo of the underlying analyzer.
Function *TypeResults::getFunction() const {
return analyzer.fntypeinfo.Function;
}
Expand Down
4 changes: 4 additions & 0 deletions enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,10 @@ class TypeAnalyzer : public llvm::InstVisitor<TypeAnalyzer> {
/// Analyze type info given by the TBAA, possibly adding to work queue
void considerTBAA();

/// Parse the debug info generated by rustc and retrieve useful type info if
/// possible
void considerRustDebugInfo();

/// Run the interprocedural type analysis starting from this function
void run();

Expand Down
75 changes: 75 additions & 0 deletions enzyme/test/TypeAnalysis/rust3darray.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
; RUN: %opt < %s %loadEnzyme -enzyme-rust-type -print-type-analysis -type-analysis-func=callee -o /dev/null | FileCheck %s



declare void @llvm.dbg.declare(metadata, metadata, metadata)

; Function under test: %arg is bitcast to a pointer to a 2x2x2 float array and
; the dbg.declare attaches local variable !384, whose debug type is the nested
; array type !378, to the bitcast result.
define internal void @callee(i8* %arg) !dbg !373 {
start:
%t = bitcast i8* %arg to [2 x [2 x [2 x float]]]*
call void @llvm.dbg.declare(metadata [2 x [2 x [2 x float]]]* %t, metadata !384, metadata !DIExpression()), !dbg !385
ret void
}

!llvm.module.flags = !{!14, !15, !16, !17}
!llvm.dbg.cu = !{!18}

!0 = !DIGlobalVariableExpression(var: !1, expr: !DIExpression())
!1 = distinct !DIGlobalVariable(name: "vtable", scope: null, file: !2, type: !3, isLocal: true, isDefinition: true)
!2 = !DIFile(filename: "<unknown>", directory: "")
!3 = !DICompositeType(tag: DW_TAG_structure_type, name: "vtable", file: !2, align: 64, flags: DIFlagArtificial, elements: !4, vtableHolder: !5, identifier: "vtable")
!4 = !{}
!5 = !DICompositeType(tag: DW_TAG_structure_type, name: "{closure#0}", scope: !6, file: !2, size: 64, align: 64, elements: !9, templateParams: !4, identifier: "c211ca2a5a4c8dd717d1e5fba4a6ae0")
!6 = !DINamespace(name: "lang_start", scope: !7)
!7 = !DINamespace(name: "rt", scope: !8)
!8 = !DINamespace(name: "std", scope: null)
!9 = !{!10}
!10 = !DIDerivedType(tag: DW_TAG_member, name: "main", scope: !5, file: !2, baseType: !11, size: 64, align: 64)
!11 = !DIDerivedType(tag: DW_TAG_pointer_type, name: "fn()", baseType: !12, size: 64, align: 64, dwarfAddressSpace: 0)
!12 = !DISubroutineType(types: !13)
!13 = !{null}
!14 = !{i32 7, !"PIC Level", i32 2}
!15 = !{i32 7, !"PIE Level", i32 2}
!16 = !{i32 2, !"RtLibUseGOT", i32 1}
!17 = !{i32 2, !"Debug Info Version", i32 3}
!18 = distinct !DICompileUnit(language: DW_LANG_Rust, file: !19, producer: "clang LLVM (rustc version 1.56.0 (09c42c458 2021-10-18))", isOptimized: false, runtimeVersion: 0, emissionKind: FullDebug, enums: !20, globals: !37)
!19 = !DIFile(filename: "rust3darray.rs", directory: "/home/nomanous/Space/Tmp/Enzyme")
!20 = !{!21, !28}
!21 = !DICompositeType(tag: DW_TAG_enumeration_type, name: "Result", scope: !22, file: !2, baseType: !24, size: 8, align: 8, elements: !25)
!22 = !DINamespace(name: "result", scope: !23)
!23 = !DINamespace(name: "core", scope: null)
!24 = !DIBasicType(name: "u8", size: 8, encoding: DW_ATE_unsigned)
!25 = !{!26, !27}
!26 = !DIEnumerator(name: "Ok", value: 0)
!27 = !DIEnumerator(name: "Err", value: 1)
!28 = !DICompositeType(tag: DW_TAG_enumeration_type, name: "Alignment", scope: !29, file: !2, baseType: !24, size: 8, align: 8, elements: !32)
!29 = !DINamespace(name: "v1", scope: !30)
!30 = !DINamespace(name: "rt", scope: !31)
!31 = !DINamespace(name: "fmt", scope: !23)
!32 = !{!33, !34, !35, !36}
!33 = !DIEnumerator(name: "Left", value: 0)
!34 = !DIEnumerator(name: "Right", value: 1)
!35 = !DIEnumerator(name: "Center", value: 2)
!36 = !DIEnumerator(name: "Unknown", value: 3)
!37 = !{!0}
!156 = !DIBasicType(name: "f32", size: 32, encoding: DW_ATE_float)
!373 = distinct !DISubprogram(name: "callee", linkageName: "_ZN11rust3darray6callee17h37b114a70360ce19E", scope: !375, file: !374, line: 1, type: !376, scopeLine: 1, flags: DIFlagPrototyped, unit: !18, templateParams: !4, retainedNodes: !383)
!374 = !DIFile(filename: "rust3darray.rs", directory: "/home/nomanous/Space/Tmp/Enzyme", checksumkind: CSK_MD5, checksum: "adf66a1fcb26c178e41abd9c50aa582a")
!375 = !DINamespace(name: "rust3darray", scope: null)
!376 = !DISubroutineType(types: !377)
!377 = !{!156, !378}
!378 = !DICompositeType(tag: DW_TAG_array_type, baseType: !379, size: 256, align: 32, elements: !381)
!379 = !DICompositeType(tag: DW_TAG_array_type, baseType: !380, size: 128, align: 32, elements: !381)
!380 = !DICompositeType(tag: DW_TAG_array_type, baseType: !156, size: 64, align: 32, elements: !381)
!381 = !{!382}
!382 = !DISubrange(count: 2, lowerBound: 0)
!383 = !{!384}
!384 = !DILocalVariable(name: "t", arg: 1, scope: !373, file: !374, line: 1, type: !378)
!385 = !DILocation(line: 1, column: 11, scope: !373)

; CHECK: callee - {} |{}:{}
; CHECK-NEXT: i8* %arg: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float, [-1,8]:Float@float, [-1,12]:Float@float, [-1,16]:Float@float, [-1,20]:Float@float, [-1,24]:Float@float, [-1,28]:Float@float}
; CHECK-NEXT: start
; CHECK-NEXT: %t = bitcast i8* %arg to [2 x [2 x [2 x float]]]*: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float, [-1,8]:Float@float, [-1,12]:Float@float, [-1,16]:Float@float, [-1,20]:Float@float, [-1,24]:Float@float, [-1,28]:Float@float}
; CHECK-NEXT: call void @llvm.dbg.declare(metadata [2 x [2 x [2 x float]]]* %t, metadata !50, metadata !DIExpression()), !dbg !51: {}
; CHECK-NEXT: ret void: {}
Loading

0 comments on commit 779ac0f

Please sign in to comment.