Skip to content

Commit

Permalink
refactor(compiler): clean statistic passes
Browse files Browse the repository at this point in the history
  • Loading branch information
youben11 committed Aug 29, 2023
1 parent 9e8c44e commit 67246b0
Show file tree
Hide file tree
Showing 17 changed files with 599 additions and 578 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.

#ifndef CONCRETELANG_ANALYSIS_UTILS_H
#define CONCRETELANG_ANALYSIS_UTILS_H

#include <boost/outcome.h>
#include <concretelang/Common/Error.h>
#include <mlir/Dialect/SCF/IR/SCF.h>
#include <mlir/IR/Location.h>

namespace mlir {
namespace concretelang {

/// Get the string representation of a location
std::string locationString(mlir::Location loc);

/// Compute the number of iterations based on loop info
int64_t calculateNumberOfIterations(int64_t start, int64_t stop, int64_t step);

/// Compute the number of iterations of an scf for loop
outcome::checked<int64_t, ::concretelang::error::StringError>
calculateNumberOfIterations(scf::ForOp &op);

} // namespace concretelang
} // namespace mlir

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#ifndef CONCRETELANG_DIALECT_CONCRETE_ANALYSIS
#define CONCRETELANG_DIALECT_CONCRETE_ANALYSIS

include "mlir/Pass/PassBase.td"

def MemoryUsage : Pass<"MemoryUsage", "::mlir::ModuleOp"> {
let summary = "Compute memory usage";
let description = [{
Computes memory usage per location, and provides those numbers throught the CompilationFeedback.
}];
}

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
set(LLVM_TARGET_DEFINITIONS Analysis.td)
mlir_tablegen(Analysis.h.inc -gen-pass-decls -name Analysis)
add_public_tablegen_target(ConcretelangConcreteAnalysisPassIncGen)
add_dependencies(mlir-headers ConcretelangConcreteAnalysisPassIncGen)
Original file line number Diff line number Diff line change
Expand Up @@ -7,59 +7,16 @@
#define CONCRETELANG_DIALECT_CONCRETE_MEMORY_USAGE_H

#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/Operation.h>
#include <mlir/Pass/Pass.h>

#include <concretelang/Support/CompilationFeedback.h>

namespace mlir {
namespace concretelang {
namespace Concrete {

struct MemoryUsagePass
: public PassWrapper<MemoryUsagePass, OperationPass<ModuleOp>> {
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createMemoryUsagePass(CompilationFeedback &feedback);

CompilationFeedback &feedback;

MemoryUsagePass(CompilationFeedback &feedback) : feedback{feedback} {};

void runOnOperation() override {
WalkResult walk =
getOperation()->walk([&](Operation *op, const WalkStage &stage) {
if (stage.isBeforeAllRegions()) {
std::optional<StringError> error = this->enter(op);
if (error.has_value()) {
op->emitError() << error->mesg;
return WalkResult::interrupt();
}
}

if (stage.isAfterAllRegions()) {
std::optional<StringError> error = this->exit(op);
if (error.has_value()) {
op->emitError() << error->mesg;
return WalkResult::interrupt();
}
}

return WalkResult::advance();
});

if (walk.wasInterrupted()) {
signalPassFailure();
}
}

std::optional<StringError> enter(Operation *op);

std::optional<StringError> exit(Operation *op);

std::map<std::string, std::vector<mlir::Value>> visitedValuesPerLoc;

size_t iterations = 1;
};

} // namespace Concrete
} // namespace concretelang
} // namespace mlir

Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
add_subdirectory(Analysis)
add_subdirectory(IR)
add_subdirectory(Transforms)
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#ifndef CONCRETELANG_DIALECT_TFHE_ANALYSIS
#define CONCRETELANG_DIALECT_TFHE_ANALYSIS

include "mlir/Pass/PassBase.td"

def ExtractStatistics : Pass<"ExtractStatistics", "::mlir::ModuleOp"> {
let summary = "Extracts statistics";
let description = [{
Extracts different statistics (e.g. number of certain crypto operations),
and provides those numbers throught the CompilationFeedback.
}];
}

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
set(LLVM_TARGET_DEFINITIONS Analysis.td)
mlir_tablegen(Analysis.h.inc -gen-pass-decls -name Analysis)
add_public_tablegen_target(ConcretelangTFHEAnalysisPassIncGen)
add_dependencies(mlir-headers ConcretelangTFHEAnalysisPassIncGen)
Original file line number Diff line number Diff line change
Expand Up @@ -7,58 +7,15 @@
#define CONCRETELANG_DIALECT_TFHE_ANALYSIS_EXTRACT_STATISTICS_H

#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/Operation.h>
#include <mlir/Pass/Pass.h>

#include <concretelang/Support/CompilationFeedback.h>

namespace mlir {
namespace concretelang {
namespace TFHE {

struct ExtractTFHEStatisticsPass
: public PassWrapper<ExtractTFHEStatisticsPass, OperationPass<ModuleOp>> {

CompilationFeedback &feedback;

ExtractTFHEStatisticsPass(CompilationFeedback &feedback)
: feedback{feedback} {};

void runOnOperation() override {
WalkResult walk =
getOperation()->walk([&](Operation *op, const WalkStage &stage) {
if (stage.isBeforeAllRegions()) {
std::optional<StringError> error = this->enter(op);
if (error.has_value()) {
op->emitError() << error->mesg;
return WalkResult::interrupt();
}
}

if (stage.isAfterAllRegions()) {
std::optional<StringError> error = this->exit(op);
if (error.has_value()) {
op->emitError() << error->mesg;
return WalkResult::interrupt();
}
}

return WalkResult::advance();
});

if (walk.wasInterrupted()) {
signalPassFailure();
}
}

std::optional<StringError> enter(Operation *op);

std::optional<StringError> exit(Operation *op);

size_t iterations = 1;
};

} // namespace TFHE
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createStatisticExtractionPass(CompilationFeedback &feedback);
} // namespace concretelang
} // namespace mlir

Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
add_subdirectory(Analysis)
add_subdirectory(IR)
add_subdirectory(Transforms)
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
add_mlir_library(
AnalysisUtils
Utils.cpp
DEPENDS
mlir-headers
LINK_LIBS
PUBLIC
MLIRIR
)
67 changes: 67 additions & 0 deletions compilers/concrete-compiler/compiler/lib/Analysis/Utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#include <concretelang/Analysis/Utils.h>
#include <mlir/Dialect/Arith/IR/Arith.h>

using ::concretelang::error::StringError;

namespace mlir {
namespace concretelang {
std::string locationString(mlir::Location loc) {
auto location = std::string();
auto locationStream = llvm::raw_string_ostream(location);
loc->print(locationStream);
return location;
}

int64_t calculateNumberOfIterations(int64_t start, int64_t stop, int64_t step) {
int64_t high;
int64_t low;

if (step > 0) {
low = start;
high = stop;
} else {
low = stop;
high = start;
step = -step;
}

if (low >= high) {
return 0;
}

return ((high - low - 1) / step) + 1;
}

outcome::checked<int64_t, StringError>
calculateNumberOfIterations(scf::ForOp &op) {
mlir::Value startValue = op.getLowerBound();
mlir::Value stopValue = op.getUpperBound();
mlir::Value stepValue = op.getStep();

auto startOp =
llvm::dyn_cast_or_null<arith::ConstantOp>(startValue.getDefiningOp());
auto stopOp =
llvm::dyn_cast_or_null<arith::ConstantOp>(stopValue.getDefiningOp());
auto stepOp =
llvm::dyn_cast_or_null<arith::ConstantOp>(stepValue.getDefiningOp());

if (!startOp || !stopOp || !stepOp) {
return StringError("only static loops can be analyzed");
}

auto startAttr = startOp.getValue().cast<mlir::IntegerAttr>();
auto stopAttr = stopOp.getValue().cast<mlir::IntegerAttr>();
auto stepAttr = stepOp.getValue().cast<mlir::IntegerAttr>();

if (!startOp || !stopOp || !stepOp) {
return StringError("only integer loops can be analyzed");
}

int64_t start = startAttr.getInt();
int64_t stop = stopAttr.getInt();
int64_t step = stepAttr.getInt();

return calculateNumberOfIterations(start, stop, step);
}
} // namespace concretelang
} // namespace mlir
1 change: 1 addition & 0 deletions compilers/concrete-compiler/compiler/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
add_subdirectory(Analysis)
add_subdirectory(Dialect)
add_subdirectory(Conversion)
add_subdirectory(Transforms)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ add_mlir_library(
LINK_LIBS
PUBLIC
MLIRIR
ConcreteDialect)
ConcreteDialect
AnalysisUtils)
Loading

0 comments on commit 67246b0

Please sign in to comment.