Skip to content

Commit

Permalink
Add basic storage array.
Browse files Browse the repository at this point in the history
commit-id:2da700e2
  • Loading branch information
gilbens-starkware committed Jul 8, 2024
1 parent 9e72a1c commit 02a0d68
Show file tree
Hide file tree
Showing 4 changed files with 452 additions and 26 deletions.
3 changes: 3 additions & 0 deletions corelib/src/starknet/storage.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ use starknet::storage_access::StorageBaseAddress;
use starknet::SyscallResult;
use starknet::storage_access::storage_base_address_from_felt252;

mod array;
pub use array::{StorageArray, StorageArrayTrait, MutableStorageArrayTrait};


/// A pointer to an address in storage, can be used to read and write values, if the generic type
/// supports it (e.g. basic types like `felt252`).
Expand Down
114 changes: 114 additions & 0 deletions corelib/src/starknet/storage/array.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
use super::{
StorageAsPath, StorageAsPointer, StoragePath, StoragePointer0Offset, Mutable, StoragePathTrait,
StoragePathUpdateTrait
};

/// A type to represent an array in storage. The length of the storage is stored in the storage
/// base, while the elements are stored in hash(storage_base, index).
#[phantom]
pub struct StorageArray<T> {}

impl StorageArrayDrop<T> of Drop<StorageArray<T>> {}
impl StorageArrayCopy<T> of Copy<StorageArray<T>> {}

/// Implement as_ptr for StorageArray.
impl StorageArrayAsPointer<T> of StorageAsPointer<StoragePath<StorageArray<T>>> {
type Value = u64;
fn as_ptr(self: @StoragePath<StorageArray<T>>) -> StoragePointer0Offset<u64> {
StoragePointer0Offset { address: (*self).finalize() }
}
}

/// Implement as_ptr for Mutable<StorageArray>.
impl MutableStorageArrayAsPointer<T> of StorageAsPointer<StoragePath<Mutable<StorageArray<T>>>> {
type Value = Mutable<u64>;
fn as_ptr(self: @StoragePath<Mutable<StorageArray<T>>>) -> StoragePointer0Offset<Mutable<u64>> {
StoragePointer0Offset { address: (*self).finalize() }
}
}


/// Trait for the interface of a storage array.
pub trait StorageArrayTrait<T> {
type ElementType;
fn at(self: T, index: u64) -> StoragePath<Self::ElementType>;
fn len(self: T) -> u64;
}

/// Implement `StorageArrayTrait` for `StoragePath<StorageArray<T>>`.
impl StorageArrayImpl<T> of StorageArrayTrait<StoragePath<StorageArray<T>>> {
type ElementType = T;
fn at(self: StoragePath<StorageArray<T>>, index: u64) -> StoragePath<T> {
let array_len = self.len();
assert!(index < array_len, "Index out of bounds");
self.update(index).into()
}
fn len(self: StoragePath<StorageArray<T>>) -> u64 {
self.as_ptr().read()
}
}

/// Implement `StorageArrayTrait` for any type that implements StorageAsPath into a storage path
/// that implements StorageArrayTrait.
impl PathableStorageArrayImpl<
T,
+Drop<T>,
impl PathImpl: StorageAsPath<T>,
impl ArrayTraitImpl: StorageArrayTrait<StoragePath<PathImpl::Value>>
> of StorageArrayTrait<T> {
type ElementType = ArrayTraitImpl::ElementType;
fn at(self: T, index: u64) -> StoragePath<ArrayTraitImpl::ElementType> {
self.as_path().at(index)
}
fn len(self: T) -> u64 {
self.as_path().len()
}
}

/// Trait for the interface of a mutable storage array.
pub trait MutableStorageArrayTrait<T> {
type ElementType;
fn at(self: T, index: u64) -> StoragePath<Mutable<Self::ElementType>>;
fn len(self: T) -> u64;
fn append(self: T) -> StoragePath<Mutable<Self::ElementType>>;
}

/// Implement `MutableStorageArrayTrait` for `StoragePath<Mutable<StorageArray<T>>`.
impl MutableStorageArrayImpl<
T, +Drop<T>
> of MutableStorageArrayTrait<StoragePath<Mutable<StorageArray<T>>>> {
type ElementType = T;
fn at(self: StoragePath<Mutable<StorageArray<T>>>, index: u64) -> StoragePath<Mutable<T>> {
let array_len = self.len();
assert!(index < array_len, "Index out of bounds");
self.update(index).into()
}
fn len(self: StoragePath<Mutable<StorageArray<T>>>) -> u64 {
self.as_ptr().read()
}
fn append(self: StoragePath<Mutable<StorageArray<T>>>) -> StoragePath<Mutable<T>> {
let array_len = self.len();
self.as_ptr().write(array_len + 1);
self.at(array_len)
}
}

/// Implement `MutableStorageArrayTrait` for any type that implements StorageAsPath into a storage
/// path that implements MutableStorageArrayTrait.
impl PathableMutableStorageArrayImpl<
T,
+Drop<T>,
impl PathImpl: StorageAsPath<T>,
impl ArrayTraitImpl: MutableStorageArrayTrait<StoragePath<PathImpl::Value>>
> of MutableStorageArrayTrait<T> {
type ElementType = ArrayTraitImpl::ElementType;
fn at(self: T, index: u64) -> StoragePath<Mutable<ArrayTraitImpl::ElementType>> {
self.as_path().at(index)
}
fn len(self: T) -> u64 {
self.as_path().len()
}
fn append(self: T) -> StoragePath<Mutable<ArrayTraitImpl::ElementType>> {
self.as_path().append()
}
}
123 changes: 122 additions & 1 deletion crates/cairo-lang-starknet/cairo_level_tests/storage_access.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use super::utils::{deserialized, serialized};
use core::integer::BoundedInt;
use core::num::traits::Zero;
use core::byte_array::ByteArrayTrait;
use starknet::storage::StorageArray;

impl StorageAddressPartialEq of PartialEq<StorageAddress> {
fn eq(lhs: @StorageAddress, rhs: @StorageAddress) -> bool {
Expand Down Expand Up @@ -78,15 +79,24 @@ struct NonZeros {
value_felt252: NonZero<felt252>,
}

#[starknet::storage_node]
struct StorageArrays {
array: StorageArray<u32>,
array_of_arrays: StorageArray<StorageArray<u32>>,
}

#[starknet::contract]
mod test_contract {
use super::{AbcEtc, ByteArrays, NonZeros};
use core::starknet::storage::StoragePointerWriteAccess;
use super::{AbcEtc, ByteArrays, NonZeros, StorageArrays,};
use starknet::storage::{StorageArrayTrait, MutableStorageArrayTrait, StorageAsPath,};

#[storage]
struct Storage {
data: AbcEtc,
byte_arrays: ByteArrays,
non_zeros: NonZeros,
arrays: StorageArrays,
}

#[external(v0)]
Expand Down Expand Up @@ -128,6 +138,46 @@ mod test_contract {
pub fn get_non_zeros(self: @ContractState) -> NonZeros {
self.non_zeros.read()
}

#[external(v0)]
pub fn append_to_array(ref self: ContractState, value: u32) {
self.arrays.array.append().write(value);
}

#[external(v0)]
pub fn get_array_length(self: @ContractState) -> u64 {
self.arrays.array.len()
}

#[external(v0)]
pub fn get_array_element(self: @ContractState, index: u64) -> u32 {
self.arrays.array.at(index).read()
}

#[external(v0)]
pub fn append_an_array(ref self: ContractState) {
self.arrays.array_of_arrays.append();
}

#[external(v0)]
pub fn append_to_nested_array(ref self: ContractState, index: u64, value: u32) {
self.arrays.array_of_arrays.at(index).append().write(value);
}

#[external(v0)]
pub fn get_array_of_arrays_length(self: @ContractState) -> u64 {
self.arrays.array_of_arrays.len()
}

#[external(v0)]
pub fn get_nested_array_length(self: @ContractState, index: u64) -> u64 {
self.arrays.array_of_arrays.at(index).len()
}

#[external(v0)]
pub fn get_nested_array_element(self: @ContractState, index: u64, nested_index: u64) -> u32 {
self.arrays.array_of_arrays.at(index).at(nested_index).read()
}
}

#[test]
Expand Down Expand Up @@ -224,3 +274,74 @@ fn test_read_write_non_zero() {
assert!(test_contract::__external::set_non_zeros(serialized(x.clone())).is_empty());
assert_eq!(deserialized(test_contract::__external::get_non_zeros(serialized(()))), x);
}

#[test]
fn test_storage_array() {
assert!(test_contract::__external::append_to_array(serialized(1_u32)).is_empty());
assert!(test_contract::__external::append_to_array(serialized(2_u32)).is_empty());
assert!(test_contract::__external::append_to_array(serialized(3_u32)).is_empty());
assert_eq!(deserialized(test_contract::__external::get_array_length(serialized(()))), 3);
assert_eq!(deserialized(test_contract::__external::get_array_element(serialized(0_u64))), 1);
assert_eq!(deserialized(test_contract::__external::get_array_element(serialized(1_u64))), 2);
assert_eq!(deserialized(test_contract::__external::get_array_element(serialized(2_u64))), 3);
}

#[test]
fn test_storage_array_of_arrays() {
assert!(test_contract::__external::append_an_array(serialized(())).is_empty());
assert!(
test_contract::__external::append_to_nested_array(serialized((0_u64, 1_u32))).is_empty()
);
assert!(
test_contract::__external::append_to_nested_array(serialized((0_u64, 2_u32))).is_empty()
);
assert!(
test_contract::__external::append_to_nested_array(serialized((0_u64, 3_u32))).is_empty()
);
assert!(test_contract::__external::append_an_array(serialized(())).is_empty());
assert!(
test_contract::__external::append_to_nested_array(serialized((1_u64, 4_u32))).is_empty()
);
assert!(
test_contract::__external::append_to_nested_array(serialized((1_u64, 5_u32))).is_empty()
);
assert_eq!(
deserialized(test_contract::__external::get_array_of_arrays_length(serialized(()))), 2
);
assert_eq!(
deserialized(test_contract::__external::get_nested_array_length(serialized(0_u64))), 3
);
assert_eq!(
deserialized(
test_contract::__external::get_nested_array_element(serialized((0_u64, 0_u64)))
),
1
);
assert_eq!(
deserialized(
test_contract::__external::get_nested_array_element(serialized((0_u64, 1_u64)))
),
2
);
assert_eq!(
deserialized(
test_contract::__external::get_nested_array_element(serialized((0_u64, 2_u64)))
),
3
);
assert_eq!(
deserialized(test_contract::__external::get_nested_array_length(serialized(1_u64))), 2
);
assert_eq!(
deserialized(
test_contract::__external::get_nested_array_element(serialized((1_u64, 0_u64)))
),
4
);
assert_eq!(
deserialized(
test_contract::__external::get_nested_array_element(serialized((1_u64, 1_u64)))
),
5
);
}
Loading

0 comments on commit 02a0d68

Please sign in to comment.