Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Repo sync #757

Merged
merged 3 commits into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions bazel/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ def _libpsi():
http_archive,
name = "psi",
urls = [
"https://github.com/secretflow/psi/archive/refs/tags/v0.4.0.dev240524.tar.gz",
"https://github.com/secretflow/psi/archive/refs/tags/v0.4.0beta.tar.gz",
],
strip_prefix = "psi-0.4.0.dev240524",
sha256 = "c2868fa6a9d804e6bbed9922dab6dc819ec6e180e15eafe7eb1b661302508c88",
strip_prefix = "psi-0.4.0beta",
sha256 = "c2fbf486a66eca9d3ec1725a81d93a7c6e80a9206ef1c9263a1608e0bef95e1a",
)

def _rules_proto_grpc():
Expand Down Expand Up @@ -169,10 +169,10 @@ def _com_github_pybind11():
http_archive,
name = "pybind11",
build_file = "@pybind11_bazel//:pybind11.BUILD",
sha256 = "bf8f242abd1abcd375d516a7067490fb71abd79519a282d22b6e4d19282185a7",
strip_prefix = "pybind11-2.12.0",
sha256 = "51631e88960a8856f9c497027f55c9f2f9115cafb08c0005439838a05ba17bfc",
strip_prefix = "pybind11-2.13.1",
urls = [
"https://github.com/pybind/pybind11/archive/refs/tags/v2.12.0.tar.gz",
"https://github.com/pybind/pybind11/archive/refs/tags/v2.13.1.tar.gz",
],
)

Expand Down
2 changes: 1 addition & 1 deletion docs/reference/np_op_status.json

Large diffs are not rendered by default.

14 changes: 14 additions & 0 deletions docs/reference/np_op_status.md
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,20 @@ Please check *Supported Dtypes* as well.
- uint16
- uint32

## bitwise_count

JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.bitwise_count.html
### Status

**PASS**
Please check *Supported Dtypes* as well.
### Supported Dtypes

- int16
- int32
- uint16
- uint32

## bitwise_not

JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.bitwise_not.html
Expand Down
6 changes: 3 additions & 3 deletions docs/reference/pphlo_doc.rst
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
PPHlo API reference
PPHLO API reference
===================

PPHlo is short for (SPU High level ops), it's the assembly language of SPU.
PPHLO is short for (Privacy-Preserving High-Level Operations), it's the assembly language of SPU.

PPHlo is built on `MLIR <https://mlir.llvm.org/>`_ infrastructure, the concrete ops definition could be found :spu_code_host:`here <spu/blob/main/libspu/dialect/pphlo/IR/ops.td>`.
PPHLO is built on `MLIR <https://mlir.llvm.org/>`_ infrastructure, the concrete ops definition could be found :spu_code_host:`here <spu/blob/main/libspu/dialect/pphlo/IR/ops.td>`.

Op List
~~~~~~~
Expand Down
58 changes: 38 additions & 20 deletions docs/reference/pphlo_op_doc.md
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,7 @@ Ref https://www.tensorflow.org/xla/operation_semantics#dot.

Traits: `AlwaysSpeculatableImplTrait`

Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`
Interfaces: `ConditionallySpeculatable`, `InferShapedTypeOpInterface`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)`

Effects: `MemoryEffects::Effect{}`

Expand Down Expand Up @@ -1626,55 +1626,63 @@ Effects: `MemoryEffects::Effect{}`
| :----: | ----------- |
| `result` | statically shaped tensor of 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values values

### `pphlo.power` (spu::pphlo::PowOp)
### `pphlo.popcnt` (spu::pphlo::PopcntOp)

_Power operator_
_Popcnt operator, ties away from zero_


Syntax:

```
operation ::= `pphlo.power` $lhs `,` $rhs attr-dict
`:` custom<SameOperandsAndResultType>(type($lhs), type($rhs), type($result))
operation ::= `pphlo.popcnt` $operand attr-dict `:` custom<SameOperandsAndResultType>(type($operand), type($result))
```

Performs element-wise exponentiation of `lhs` tensor by `rhs` tensor and produces a `result` tensor.
Performs element-wise count of the number of bits set in the `operand` tensor and produces a `result` tensor.

Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#power
Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#popcnt

Traits: `AlwaysSpeculatableImplTrait`, `Elementwise`, `SameOperandsAndResultShape`
Traits: `AlwaysSpeculatableImplTrait`, `Elementwise`, `SameOperandsAndResultShape`, `SameOperandsAndResultType`

Interfaces: `ConditionallySpeculatable`, `InferShapedTypeOpInterface`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)`
Interfaces: `ConditionallySpeculatable`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)`

Effects: `MemoryEffects::Effect{}`

#### Attributes:

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>bits</code></td><td>::mlir::IntegerAttr</td><td>64-bit signless integer attribute</td></tr>
</table>

#### Operands:

| Operand | Description |
| :-----: | ----------- |
| `lhs` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values
| `rhs` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values
| `operand` | statically shaped tensor of 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values values

#### Results:

| Result | Description |
| :----: | ----------- |
| `result` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values
| `result` | statically shaped tensor of 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values values

### `pphlo.prefer_a` (spu::pphlo::PreferAOp)
### `pphlo.power` (spu::pphlo::PowOp)

_Prefer AShare operator_
_Power operator_


Syntax:

```
operation ::= `pphlo.prefer_a` $operand attr-dict `:` custom<SameOperandsAndResultType>(type($operand), type($result))
operation ::= `pphlo.power` $lhs `,` $rhs attr-dict
`:` custom<SameOperandsAndResultType>(type($lhs), type($rhs), type($result))
```

Convert input to AShare if possible.
Performs element-wise exponentiation of `lhs` tensor by `rhs` tensor and produces a `result` tensor.

Traits: `AlwaysSpeculatableImplTrait`, `Elementwise`, `SameOperandsAndResultShape`, `SameOperandsAndResultType`
Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#power

Traits: `AlwaysSpeculatableImplTrait`, `Elementwise`, `SameOperandsAndResultShape`

Interfaces: `ConditionallySpeculatable`, `InferShapedTypeOpInterface`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)`

Expand All @@ -1684,7 +1692,8 @@ Effects: `MemoryEffects::Effect{}`

| Operand | Description |
| :-----: | ----------- |
| `operand` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values
| `lhs` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values
| `rhs` | statically shaped tensor of pred (AKA boolean or 1-bit integer) or Secret of 1-bit signless integer values or 8/16/32/64-bit signless integer or Secret of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or 8/16/32/64-bit unsigned integer or Secret of 8/16/32/64-bit unsigned integer values or 16-bit float or 32-bit float or 64-bit float or Secret of 16-bit float or 32-bit float or 64-bit float values or complex type with 32-bit float or 64-bit float elements or Secret of complex type with 32-bit float or 64-bit float elements values values

#### Results:

Expand Down Expand Up @@ -2270,12 +2279,21 @@ Returns the sign of the `operand` element-wise and produces a `result` tensor.

Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sign

PPHLO Extension: when `ignore_zero` is set to true, sign does not enforce sign(0) to 0

Traits: `AlwaysSpeculatableImplTrait`, `Elementwise`, `SameOperandsAndResultShape`, `SameOperandsAndResultType`

Interfaces: `ConditionallySpeculatable`, `InferShapedTypeOpInterface`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)`

Effects: `MemoryEffects::Effect{}`

#### Attributes:

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>ignore_zero</code></td><td>::mlir::BoolAttr</td><td>bool attribute</td></tr>
</table>

#### Operands:

| Operand | Description |
Expand Down Expand Up @@ -2377,7 +2395,7 @@ Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#slice

Traits: `AlwaysSpeculatableImplTrait`, `SameOperandsAndResultElementType`

Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`
Interfaces: `ConditionallySpeculatable`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)`

Effects: `MemoryEffects::Effect{}`

Expand Down Expand Up @@ -2551,7 +2569,7 @@ Ref https://github.com/openxla/stablehlo/blob/main/docs/spec.md#transpose

Traits: `AlwaysSpeculatableImplTrait`, `SameOperandsAndResultElementType`

Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`
Interfaces: `ConditionallySpeculatable`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)`

Effects: `MemoryEffects::Effect{}`

Expand Down
3 changes: 2 additions & 1 deletion docs/reference/runtime_config.md
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,9 @@ The SPU runtime configuration.
| Field | Type | Description |
| ----- | ---- | ----------- |
| server_host | [ string](#string) | TrustedThirdParty beaver server's remote ip:port or load-balance uri. |
| session_id | [ string](#string) | if empty, use link id as session id. |
| adjust_rank | [ int32](#int32) | which rank do adjust rpc call, usually choose the rank closer to the server. |
| asym_crypto_schema | [ string](#string) | asym_crypto_schema: support ["SM2"] Will support 25519 in the future, after yacl supported it. |
| server_public_key | [ bytes](#bytes) | server's public key |
<!-- end Fields -->
<!-- end HasFields -->

Expand Down
2 changes: 1 addition & 1 deletion libspu/compiler/common/compilation_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace {

void SPUErrorHandler(void * /*use_data*/, const char *reason,
bool /*gen_crash_diag*/) {
SPU_THROW(reason);
SPU_THROW("{}", reason);
}

} // namespace
Expand Down
2 changes: 2 additions & 0 deletions libspu/compiler/common/ir_printer_config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ void IRPrinterConfig::printBeforeIfEnabled(Pass *pass, Operation *,
if (ec.value() != 0) {
spdlog::error("Open file {} failed, error = {}", file_name.c_str(),
ec.message());
return;
}
print_callback(f);
}
Expand All @@ -64,6 +65,7 @@ void IRPrinterConfig::printAfterIfEnabled(Pass *pass, Operation *,
if (ec.value() != 0) {
spdlog::error("Open file {} failed, error = {}", file_name.c_str(),
ec.message());
return;
}
print_callback(f);
}
Expand Down
2 changes: 2 additions & 0 deletions libspu/compiler/front_end/fe.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ mlir::OwningOpRef<mlir::ModuleOp> FE::doit(const CompilationSource &source) {
module = mlir::parseSourceString<mlir::ModuleOp>(source.ir_txt(),
ctx_->getMLIRContext());

SPU_ENFORCE(module, "MLIR parser failure");

// Convert stablehlo to mhlo first
mlir::PassManager pm(ctx_->getMLIRContext());
pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
Expand Down
6 changes: 3 additions & 3 deletions libspu/compiler/front_end/hlo_importer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,12 +196,12 @@ HloImporter::parseXlaModuleFromString(const std::string &content) {
auto module_config =
xla::HloModule::CreateModuleConfigFromProto(hlo_module, debug_options);
if (!module_config.status().ok()) {
SPU_THROW(module_config.status().message());
SPU_THROW("{}", module_config.status().message());
}

auto module = xla::HloModule::CreateFromProto(hlo_module, *module_config);
if (!module.status().ok()) {
SPU_THROW(module.status().message());
SPU_THROW("{}", module.status().message());
}

xla::runHloPasses((*module).get());
Expand All @@ -214,7 +214,7 @@ HloImporter::parseXlaModuleFromString(const std::string &content) {

auto status = importer.Import(**module);
if (!status.ok()) {
SPU_THROW(status.message());
SPU_THROW("{}", status.message());
}

return mlir_hlo;
Expand Down
2 changes: 1 addition & 1 deletion libspu/device/api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ void printProfilingData(spu::SPUContext *sctx, const std::string &name,
void SPUErrorHandler(void *use_data, const char *reason, bool gen_crash_diag) {
(void)use_data;
(void)gen_crash_diag;
SPU_THROW(reason);
SPU_THROW("{}", reason);
}

std::mutex ErrorHandlerMutex;
Expand Down
28 changes: 22 additions & 6 deletions libspu/kernel/hal/permute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "libspu/core/bit_utils.h"
#include "libspu/core/context.h"
#include "libspu/core/trace.h"
#include "libspu/core/vectorize.h"
#include "libspu/kernel/hal/constants.h"
#include "libspu/kernel/hal/polymorphic.h"
#include "libspu/kernel/hal/prot_wrapper.h"
Expand All @@ -43,6 +44,12 @@ inline bool _has_same_owner(const Value &x, const Value &y) {
return _get_owner(x) == _get_owner(y);
}

void _hint_nbits(const Value &a, size_t nbits) {
if (a.storage_type().isa<BShare>()) {
const_cast<Type &>(a.storage_type()).as<BShare>()->setNbits(nbits);
}
}

// generate inverse permutation
Index _inverse_index(const Index &p) {
Index q(p.size());
Expand Down Expand Up @@ -531,20 +538,29 @@ spu::Value _opt_apply_perm_ss(SPUContext *ctx, const spu::Value &perm,
std::vector<spu::Value> _bit_decompose(SPUContext *ctx, const spu::Value &x,
int64_t valid_bits) {
auto x_bshare = _prefer_b(ctx, x);
const auto k1 = _constant(ctx, 1U, x.shape());
std::vector<spu::Value> rets;
size_t nbits = valid_bits != -1
? static_cast<size_t>(valid_bits)
: x_bshare.storage_type().as<BShare>()->nbits();
rets.reserve(nbits);
_hint_nbits(x_bshare, nbits);
if (ctx->hasKernel("b2a_disassemble")) {
auto ret =
dynDispatch<std::vector<spu::Value>>(ctx, "b2a_disassemble", x_bshare);
return ret;
}

const auto k1 = _constant(ctx, 1U, x.shape());
std::vector<spu::Value> rets_b;
rets_b.reserve(nbits);

for (size_t bit = 0; bit < nbits; ++bit) {
auto x_bshare_shift = right_shift_logical(ctx, x_bshare, bit);
auto lowest_bit = _and(ctx, x_bshare_shift, k1);
rets.emplace_back(_prefer_a(ctx, lowest_bit));
rets_b.push_back(_and(ctx, x_bshare_shift, k1));
}

return rets;
std::vector<spu::Value> rets_a;
vmap(rets_b.begin(), rets_b.end(), std::back_inserter(rets_a),
[&](const Value &x) { return _prefer_a(ctx, x); });
return rets_a;
}

// Generate vector of bit decomposition of sorting keys
Expand Down
4 changes: 2 additions & 2 deletions libspu/mpc/cheetah/boolean_semi2k.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ NdArrayRef P2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const {
if (comm->getRank() == 0) {
ring_xor_(x, in);
}

return makeBShare(x, field, getNumBits(in));
auto nbits = getNumBits(in) == 0 ? 1 : getNumBits(in);
return makeBShare(x, field, nbits);
}

NdArrayRef AndBP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs,
Expand Down
11 changes: 11 additions & 0 deletions libspu/mpc/kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,17 @@ void ConcateKernel::evaluate(KernelEvalContext* ctx) const {
ctx->pushOutput(WrapValue(z));
}

void DisassembleKernel::evaluate(KernelEvalContext* ctx) const {
const auto& in = ctx->getParam<Value>(0);
auto z = proc(ctx, UnwrapValue(in));

std::vector<Value> wrapped(z.size());
for (size_t idx = 0; idx < z.size(); ++idx) {
wrapped[idx] = WrapValue(z[idx]);
}
ctx->pushOutput(wrapped);
};

void OramOneHotKernel::evaluate(KernelEvalContext* ctx) const {
auto target = ctx->getParam<Value>(0);
auto s = ctx->getParam<int64_t>(1);
Expand Down
8 changes: 8 additions & 0 deletions libspu/mpc/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,4 +217,12 @@ class ConcateKernel : public Kernel {
int64_t axis) const = 0;
};

class DisassembleKernel : public Kernel {
public:
void evaluate(KernelEvalContext* ctx) const override;

virtual std::vector<NdArrayRef> proc(KernelEvalContext* ctx,
const NdArrayRef& in) const = 0;
};

} // namespace spu::mpc
Loading
Loading