diff --git a/doc/source/api/nr.rst b/doc/source/api/nr.rst index 01e2bf6a..e20ae19d 100644 --- a/doc/source/api/nr.rst +++ b/doc/source/api/nr.rst @@ -106,6 +106,18 @@ PUSCHTransmitter :exclude-members: build, call :members: +PUSCHTransformDeprecoder +------------------------ +.. autoclass:: sionna.nr.PUSCHTransformDeprecoder + :exclude-members: call + :members: + +PUSCHTransformPrecoder +---------------------- +.. autoclass:: sionna.nr.PUSCHTransformPrecoder + :exclude-members: call + :members: + Transport Block *************** diff --git a/sionna/mimo/detection.py b/sionna/mimo/detection.py index 01373114..39f3e372 100644 --- a/sionna/mimo/detection.py +++ b/sionna/mimo/detection.py @@ -52,6 +52,11 @@ class LinearDetector(Layer): constellation point indices instead of soft-values. Defaults to `False`. + post_equalizer_transformation: None or Layer + Optional layer that applies a transformation after the equalizer and + before the demapper. This can be used to apply transform precoding + when DFT-s-OFDM is enabled in NR PUSCH. + dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype) The dtype of ``y``. Defaults to tf.complex64. The output dtype is the corresponding real dtype (tf.float32 or tf.float64). @@ -96,11 +101,13 @@ def __init__(self, num_bits_per_symbol=None, constellation=None, hard_out=False, + post_equalizer_transformation=None, dtype=tf.complex64, **kwargs): super().__init__(dtype=dtype, **kwargs) self._output = output self._hard_out = hard_out + self._post_equalizer_transformation = post_equalizer_transformation # Determine the equalizer to use if isinstance(equalizer, str): @@ -137,6 +144,8 @@ def __init__(self, def call(self, inputs): x_hat, no_eff = self._equalizer(*inputs) + if self._post_equalizer_transformation is not None: + x_hat = self._post_equalizer_transformation(x_hat) z = self._demapper([x_hat, no_eff]) # Reshape to the expected output shape diff --git a/sionna/nr/__init__.py b/sionna/nr/__init__.py index fc3adafb..db179b21 100644 --- a/sionna/nr/__init__.py +++ b/sionna/nr/__init__.py @@ -11,11 +11,12 @@ from .pusch_dmrs_config import PUSCHDMRSConfig from .pusch_pilot_pattern import PUSCHPilotPattern from .pusch_precoder import PUSCHPrecoder +from .pusch_transform_precoder import PUSCHTransformPrecoder, PUSCHTransformDeprecoder from .pusch_transmitter import PUSCHTransmitter from .pusch_receiver import PUSCHReceiver from .pusch_channel_estimation import PUSCHLSChannelEstimator from .tb_config import TBConfig -from .utils import generate_prng_seq, select_mcs, calculate_tb_size +from .utils import generate_prng_seq, generate_low_papr_seq_type_1, select_mcs, calculate_tb_size from .tb_encoder import TBEncoder from .tb_decoder import TBDecoder from .layer_mapping import LayerMapper, LayerDemapper diff --git a/sionna/nr/pusch_config.py b/sionna/nr/pusch_config.py index c413dbf6..64c26f87 100644 --- a/sionna/nr/pusch_config.py +++ b/sionna/nr/pusch_config.py @@ -6,8 +6,9 @@ """ # pylint: disable=line-too-long +import functools import numpy as np -from .utils import generate_prng_seq +from .utils import generate_prng_seq, generate_low_papr_seq_type_1 from .config import Config from sionna import nr from .utils import calculate_tb_size @@ -233,7 +234,7 @@ def n_rnti(self, value): assert value in range(65536), "n_rnti must be in [0, 65535]" self._n_rnti = value - #---transform_precoding---# + #---precoding---# @property def precoding(self): """ @@ -427,9 +428,9 @@ def n(self): used for DMRS generation """ if self.dmrs.config_type==1: - n_max = 
self.num_resource_blocks*12//4 -1 + n_max = self.num_effective_subcarriers//4 -1 elif self.dmrs.config_type==2: - n_max = self.num_resource_blocks*12//6 -1 + n_max = self.num_effective_subcarriers//6 -1 return list(range(n_max+1)) @property @@ -450,6 +451,31 @@ def num_resource_blocks(self): else: return self.n_size_bwp + @property + def num_effective_resource_blocks(self): + """ + int, read-only : Number of allocated resource blocks for the + PUSCH transmissions, that are actually used (can differ from + num_subcarriers when transform precoding is enabled, + because of constraints on the largest prime factor of the + subcarrier count) + """ + @functools.lru_cache + def adjust_prbs_to_prime_factor_constraints(prbs): + # Decreases the number of PRBs until the largest prime factor is at most 5 + for eff_prbs in range(prbs, 1, -1): + n = eff_prbs + for p in [2, 3, 5]: + while n % p == 0: + n /= p + if n == 1: + return eff_prbs + + if self.transform_precoding: + return adjust_prbs_to_prime_factor_constraints(self.num_resource_blocks) + else: + return self.num_resource_blocks + @property def num_subcarriers(self): """ @@ -458,6 +484,17 @@ def num_subcarriers(self): """ return 12*self.num_resource_blocks + @property + def num_effective_subcarriers(self): + """ + int, read-only : Number of allocated subcarriers for the + PUSCH transmissions, that are actually used (can differ from + num_subcarriers when transform precoding is enabled, + because of constraints on the largest prime factor of the + subcarrier count) + """ + return 12 * self.num_effective_resource_blocks + @property def num_res_per_prb(self): """ @@ -488,7 +525,7 @@ def dmrs_mask(self): resource elements in the resource grid. `True` corresponds to resource elements on which no data is transmitted. """ - mask = np.zeros([self.num_subcarriers, + mask = np.zeros([self.num_effective_subcarriers, self.carrier.num_symbols_per_slot], dtype=bool) @@ -503,7 +540,7 @@ def dmrs_mask(self): cdm_ind[:,i] = np.array([0,1, 6, 7])+2*i for i in self.dmrs_symbol_indices: - for j in range(self.num_resource_blocks): + for j in range(self.num_effective_resource_blocks): for k in range(num_cdm_groups): mask[cdm_ind[:, k] + 12*j, i] = True return mask @@ -518,7 +555,7 @@ def dmrs_grid(self): This property returns for each configured DMRS port an empty resource grid filled with DMRS signals as defined in Section 6.4.1.1 [3GPP38211]. Not all possible options are implemented, - e.g., frequency hopping and transform precoding are not available. + e.g., frequency hopping is not available. This property provides the *unprecoded* DMRS for each configured DMRS port. Precoding might be applied to map the DMRS to the antenna ports. 
However, @@ -536,7 +573,7 @@ def dmrs_grid(self): # Generate empty resource grid for each port a_tilde = np.zeros([len(self.dmrs.dmrs_port_set), - self.num_subcarriers, + self.num_effective_subcarriers, self.carrier.num_symbols_per_slot], dtype=complex) @@ -546,15 +583,23 @@ def dmrs_grid(self): # For every l_prime for l_prime in self.l_prime: - # Compute c_init l = l_bar + l_prime - c_init = self.c_init(l) - # Generate RNG - c = generate_prng_seq(2*self.num_subcarriers, c_init=c_init) + if self.transform_precoding: + if self.dmrs.n_sid is None: + n_id = self.carrier.n_cell_id + else: + n_id = self.dmrs.n_sid + r = generate_low_papr_seq_type_1(self.num_effective_subcarriers // 2, n_id % 30, 0, 0) + else: + # Compute c_init + c_init = self.c_init(l) + + # Generate RNG + c = generate_prng_seq(2*self.num_effective_subcarriers, c_init=c_init) - # Map to QAM - r = 1/np.sqrt(2)*((1-2*c[::2]) + 1j*(1-2*c[1::2])) + # Map to QAM + r = 1/np.sqrt(2)*((1-2*c[::2]) + 1j*(1-2*c[1::2])) # For every port in the dmrs port set for j_ind, _ in enumerate(self.dmrs.dmrs_port_set): @@ -625,8 +670,38 @@ def precoding_matrix(self): w /= np.sqrt(2) + # Table 6.3.1.5-2 + elif self.transform_precoding and self.num_antenna_ports == 4: + w = np.zeros([28, 4, 1], complex) + + # TPMI index 0-7 + w[:8,0,0] = [ 1, 0, 0, 0, 1, 1, 1, 1] + w[:8,1,0] = [ 0, 1, 0, 0, 0, 0, 0, 0] + w[:8,2,0] = [ 0, 0, 1, 0, 1, -1, 1j,-1j] + w[:8,3,0] = [ 0, 0, 0, 1, 0, 0, 0, 0] + + # TPMI index 8-15 + w[8:16,0,0] = [ 0, 0, 0, 0, 1, 1, 1, 1] + w[8:16,1,0] = [ 1, 1, 1, 1, 1, 1, 1, 1] + w[8:16,2,0] = [ 0, 0, 0, 0, 1, 1j, -1,-1j] + w[8:16,3,0] = [ 1, -1, 1j,-1j, -1, 1j, 1,-1j] + + # TPMI index 16-23 + w[16:24,0,0] = [ 1, 1, 1, 1, 1, 1, 1, 1] + w[16:24,1,0] = [ 1j, 1j, 1j, 1j, -1, -1, -1, -1] + w[16:24,2,0] = [ 1, 1j, -1,-1j, 1, 1j, -1,-1j] + w[16:24,3,0] = [ 1j, 1,-1j, -1, 1,-1j, -1, 1j] + + # TPMI index 24-27 + w[24:28,0,0] = [ 1, 1, 1, 1] + w[24:28,1,0] = [-1j,-1j,-1j,-1j] + w[24:28,2,0] = [ 1, 1j, -1,-1j] + w[24:28,3,0] = [-1j, -1, 1j, 1] + + w /= 2 + # Table 6.3.1.5-3 - elif self.num_antenna_ports==4: + elif not self.transform_precoding and self.num_antenna_ports==4: w = np.zeros([28,4,1], complex) # TPMI index 0-7 @@ -825,7 +900,7 @@ def num_coded_bits(self): n_re_per_prb = self.num_res_per_prb - self.num_ov # number of allocated REs - n_re = n_re_per_prb * self.num_resource_blocks + n_re = n_re_per_prb * self.num_effective_resource_blocks # total number of bits per slot num_coded_bits = int(self.tb.tb_scaling * self.tb.num_bits_per_symbol \ @@ -842,7 +917,7 @@ def tb_size(self): # number of allocated REs # the max. number of REs per PRB is limited to 156 in 38.214 - n_re = min(156, n_re_per_prb) * self.num_resource_blocks + n_re = min(156, n_re_per_prb) * self.num_effective_resource_blocks # include tb_scaling as defined in Tab. 
5.1.3.2-2 38.214 target_tb_size = int(self.tb.target_coderate * self.tb.tb_scaling \ @@ -924,6 +999,14 @@ def check_config(self): assert self.num_layers == self.num_antenna_ports,\ "num_layers must be == num_antenna_ports" + if self.transform_precoding: + assert self.num_layers == 1,\ + "When transform precoding is used, only a single MIMO layer is supported" + assert self.dmrs.config_type == 1, \ + "When transform precoding is used, DMRS config type must be 1" + assert self.dmrs.num_cdm_groups_without_data == 2, \ + "When transform precoding is used, num_cdm_groups_without_data must be 2" + # Check Tables 6.4.1.1.3-3/4 are valid if self.dmrs.length==1: if self.mapping_type=="A": @@ -1033,11 +1116,13 @@ def check_pusch_configs(pusch_configs): "num_tx" : len(pusch_configs), "num_layers" : pc.num_layers, "num_subcarriers" : pc.num_subcarriers, + "num_effective_subcarriers": pc.num_effective_subcarriers, "num_ofdm_symbols" : pc.symbol_allocation[1], "subcarrier_spacing" : pc.carrier.subcarrier_spacing*1e3, "num_antenna_ports" : pc.num_antenna_ports, "precoding" : pc.precoding, "precoding_matrices" : [], + "transform_precoding" : pc.transform_precoding, "pusch_config" : pc, "carrier_config" : pc.carrier, "num_coded_bits" : pc.num_coded_bits, diff --git a/sionna/nr/pusch_dmrs_config.py b/sionna/nr/pusch_dmrs_config.py index 026c584b..789fa3cc 100644 --- a/sionna/nr/pusch_dmrs_config.py +++ b/sionna/nr/pusch_dmrs_config.py @@ -151,14 +151,29 @@ def n_id(self, value): if value is None: self._n_id = None elif isinstance(value, int): - assert value in list(range(65536)), "n_id must be in [0, 65535]" + assert value in range(65536), "n_id must be in [0, 65535]" self._n_id = [value, value] else: assert len(value)==2, "n_id must be either [] or a two-tuple" for e in value: - assert e in list(range(65536)), "Each element of n_id must be in [0, 65535]" + assert e in range(65536), "Each element of n_id must be in [0, 65535]" self._n_id = value + #---n_sid---# + @property + def n_sid(self): + r""" + None (default), [0,...,1007] : DMRS scrambling identity for DFT-s-OFDM + :math:`n_\text{ID}^\text{PUSCH}` + """ + self._ifndef("n_sid", None) + return self._n_sid + + @n_sid.setter + def n_sid(self, value): + assert value is None or (isinstance(value, int) and value in range(1008)), "n_sid must None or in [0, 1007]" + self._n_sid = value + #---n_scid---# @property def n_scid(self): diff --git a/sionna/nr/pusch_receiver.py b/sionna/nr/pusch_receiver.py index 996e1382..feebb45a 100644 --- a/sionna/nr/pusch_receiver.py +++ b/sionna/nr/pusch_receiver.py @@ -12,6 +12,7 @@ from sionna.ofdm import OFDMDemodulator, LinearDetector from sionna.utils import insert_dims from sionna.channel import time_to_ofdm_channel +from .pusch_transform_precoder import PUSCHTransformDeprecoder class PUSCHReceiver(Layer): # pylint: disable=line-too-long @@ -197,14 +198,19 @@ def __init__(self, # Use or create default MIMODetector if mimo_detector is None: # Default MIMO detector + transformation = PUSCHTransformDeprecoder(pusch_transmitter.resource_grid.num_effective_subcarriers, + dtype=dtype) if pusch_transmitter._transform_precoding else None self._mimo_detector = LinearDetector("lmmse", "bit", "maxlog", - pusch_transmitter.resource_grid, - self._stream_management, - "qam", - pusch_transmitter._num_bits_per_symbol, - dtype=dtype) + pusch_transmitter.resource_grid, + self._stream_management, + "qam", + pusch_transmitter._num_bits_per_symbol, + post_equalizer_transformation=transformation, + dtype=dtype) else: # User-provided MIMO 
detector + if pusch_transmitter._transform_precoding: + print("WARNING: Using custom mimo detector which might not support transform precoding.") self._mimo_detector = mimo_detector # Create LayerDemapper @@ -248,7 +254,6 @@ def call(self, inputs): if self._input_domain=="time": h = time_to_ofdm_channel(h, self.resource_grid, self._l_min) - if self._w is not None: # Reshape h to put channel matrix dimensions last # [batch size, num_rx, num_tx, num_ofdm_symbols,... diff --git a/sionna/nr/pusch_transform_precoder.py b/sionna/nr/pusch_transform_precoder.py new file mode 100644 index 00000000..022993af --- /dev/null +++ b/sionna/nr/pusch_transform_precoder.py @@ -0,0 +1,115 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2021-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +"""Class definitions for PUSCH transform precoder and deprecoder""" + +import tensorflow as tf +from tensorflow.keras.layers import Layer + + +def _check_largest_prime_factor_not_larger_then_5(n): + for p in [2, 3, 5]: + while n % p == 0: + n /= p + if n > 1: + raise ValueError( + "Number of subcarriers shouldn't have a prime factor > 5") + + +class PUSCHTransformPrecoder(Layer): + r"""PUSCHTransformPrecoder(num_subcarriers, dtype=tf.complex64, **kwargs) + Performs transform precoding of layer mapped symbols as defined in + [3GPP38211]_ Sec. 6.3.1.4. + + The input will be reshaped into blocks of size ``num_subcarriers`` to which + the FFT will be applied individually. + + The class inherits from the Keras layer class and can be used as layer in a + Keras model. + + Parameters + ---------- + num_subcarriers: int + Number of subcarriers. The largest prime factor must not be larger + than 5. + + dtype : One of [tf.complex64, tf.complex128] + Dtype of inputs and outputs. Defaults to tf.complex64. + + Input + ----- + inputs: [...,n], tf.complex + Tensor containing the sequence of symbols to be transform precoded. + + Output + ------ + : [...,n], tf.complex + Tensor containing the sequence of symbols that have been transform + precoded. + """ + + def __init__(self, + num_subcarriers, + dtype=tf.complex64, + **kwargs): + super().__init__(dtype=dtype, **kwargs) + _check_largest_prime_factor_not_larger_then_5(num_subcarriers) + self._num_subcarriers = num_subcarriers + + def call(self, y): + orig_shape = tf.shape(y) + y_reshaped = tf.reshape(y, [-1, self._num_subcarriers]) + y_transformed = tf.cast(tf.sqrt(1 / self._num_subcarriers), + self._dtype) * tf.signal.fft(y_reshaped) + y_result = tf.reshape(y_transformed, orig_shape) + return y_result + + +class PUSCHTransformDeprecoder(Layer): + r"""PUSCHTransformDeprecoder(num_subcarriers, dtype=tf.complex64, **kwargs) + Performs transform deprecoding of layer mapped symbols as defined in + [3GPP38211]_ Sec. 6.3.1.4. + + The input will be reshaped into blocks of size ``num_subcarriers`` to which + the IFFT will be applied individually. + + The class inherits from the Keras layer class and can be used as layer in a + Keras model. + + Parameters + ---------- + num_subcarriers: int + Number of subcarriers. The largest prime factor must not be larger + than 5. + + dtype : One of [tf.complex64, tf.complex128] + Dtype of inputs and outputs. Defaults to tf.complex64. + + Input + ----- + inputs: [...,n], tf.complex + Tensor containing the sequence of symbols after transform precoding. + + Output + ------ + : [...,n], tf.complex + Tensor containing the sequence of symbols before transform + precoding. 
+ """ + + def __init__(self, + num_subcarriers, + dtype=tf.complex64, + **kwargs): + super().__init__(dtype=dtype, **kwargs) + _check_largest_prime_factor_not_larger_then_5(num_subcarriers) + self._num_subcarriers = num_subcarriers + + def call(self, y): + orig_shape = tf.shape(y) + y_reshaped = tf.reshape(y, [-1, self._num_subcarriers]) + y_transformed = tf.cast(tf.sqrt(float(self._num_subcarriers)), + self._dtype) * tf.signal.ifft(y_reshaped) + y_result = tf.reshape(y_transformed, orig_shape) + return y_result diff --git a/sionna/nr/pusch_transmitter.py b/sionna/nr/pusch_transmitter.py index 7a4d62b8..4893b227 100644 --- a/sionna/nr/pusch_transmitter.py +++ b/sionna/nr/pusch_transmitter.py @@ -14,6 +14,7 @@ from .pusch_config import PUSCHConfig, check_pusch_configs from .pusch_pilot_pattern import PUSCHPilotPattern from .pusch_precoder import PUSCHPrecoder +from .pusch_transform_precoder import PUSCHTransformPrecoder from .tb_encoder import TBEncoder from .layer_mapping import LayerMapper @@ -169,9 +170,15 @@ def __init__(self, num_tx=self._num_tx, num_streams_per_tx=self._num_layers, cyclic_prefix_length=self._cyclic_prefix_length, + num_guard_carriers=(0, self._num_subcarriers - self._num_effective_subcarriers), pilot_pattern=self._pilot_pattern, dtype=dtype) + # Create PUSCHTransformPrecoder + if self._transform_precoding: + self._transform_precoder = PUSCHTransformPrecoder(self.resource_grid.num_effective_subcarriers, + dtype=dtype) + # Create ResourceGridMapper self._resource_grid_mapper = ResourceGridMapper(self._resource_grid, dtype=dtype) @@ -227,8 +234,14 @@ def call(self, inputs): # Map to layers x_layer = self._layer_mapper(x_map) + # (Optionally) apply PUSCH transform precoding (DFT-s-OFDM) + if self._transform_precoding: + x_trans_pre = self._transform_precoder(x_layer) + else: + x_trans_pre = x_layer + # Apply resource grid mapping - x_grid = self._resource_grid_mapper(x_layer) + x_grid = self._resource_grid_mapper(x_trans_pre) # (Optionally) apply PUSCH precoding if self._precoding=="codebook": diff --git a/sionna/nr/utils.py b/sionna/nr/utils.py index 9265c011..80ce5ce3 100644 --- a/sionna/nr/utils.py +++ b/sionna/nr/utils.py @@ -28,7 +28,7 @@ def generate_prng_seq(length, c_init): Note ---- The initialization sequence ``c_init`` is application specific and is - usually provided be higher layer protocols. + usually provided by higher layer protocols. """ # check inputs for consistency @@ -70,6 +70,225 @@ def generate_prng_seq(length, c_init): return c + +def generate_low_papr_seq_type_1(length, u, v, alpha): + r"""Implements low-PAPR sequence generator as defined in Sec. 5.2.2 + in [3GPP38211]_ based on Zadoff-Chu sequence. + + Parameters + ---------- + length: int + Desired output sequence length. + + u: int + Base sequence group. Must be in the range of 0 to 29. + + v: int + Base sequence number. Must be 0 if ``length`` < 72 or + in [0, 1] otherwise. + + alpha: float + Cyclic shift that will be applied to sequence + + Output + ------ + :[``length``], ndarray of floating point values + Containing the low-PAPR sequence. + + Note + ---- + The parameters ``u``, ``v`` and ``alpha`` are application specific + and are usually provided by higher layer protocols. 
+ """ + + if not 0 <= u <= 29: + raise ValueError("u has to be between 0 and 29") + if (length < 72 and v != 0) or v not in [0, 1]: + raise ValueError("v has to be 0 if length < 72 or 0 or 1") + if alpha < 0: + raise ValueError("alpha has to be non-negative") + if length <= 0 or (length < 36 and length % 6 != 0): + raise ValueError( + "sequence length has to be positive and dividable by 6 when < 36") + + n_idx = np.arange(length) + if length < 30: + phi = _get_phi(u, length) + base_seq = np.exp(1j * phi * np.pi / 4) + elif length == 30: + base_seq = np.exp(-1j * np.pi * (u + 1) * (n_idx + 1) * + (n_idx + 2) / 31) + else: + n_zc = _largest_prime_lt_n(length) + q_bar = n_zc * (u + 1) / 31 + q = np.floor(q_bar + 0.5) + v * (-1) ** np.floor(2 * q_bar) + m_idx = np.arange(n_zc) + base_seq = np.exp(-1j * np.pi * q * m_idx * (m_idx + 1) / n_zc) + base_seq = np.pad(base_seq, (0, length - n_zc), 'wrap') + + seq = np.exp(1j * n_idx * alpha) * base_seq + return seq + + +def _largest_prime_lt_n(n): + """Return the largest prime number that is less than or equal to `n`.""" + # Only check uneven numbers + if n % 2 == 0: + n -= 1 + else: + n -= 2 + + for i in range(n, 2, -2): + for j in range(3, np.floor(np.sqrt(i)).astype(int) + 1, 2): + if i % j == 0: + break + else: + return i + + # when n == 3 + return 2 + + +def _get_phi(u, m): + """Return vector according to tables 5.2.2.2-x in [3GPP38211]_.""" + if m == 6: + phi_table = [ + [-3, -1, 3, 3, -1, -3], + [-3, 3, -1, -1, 3, -3], + [-3, -3, -3, 3, 1, -3], + [1, 1, 1, 3, -1, -3], + [1, 1, 1, -3, -1, 3], + [-3, 1, -1, -3, -3, -3], + [-3, 1, 3, -3, -3, -3], + [-3, -1, 1, -3, 1, -1], + [-3, -1, -3, 1, -3, -3], + [-3, -3, 1, -3, 3, -3], + [-3, 1, 3, 1, -3, -3], + [-3, -1, -3, 1, 1, -3], + [1, 1, 3, -1, -3, 3], + [1, 1, 3, 3, -1, 3], + [1, 1, 1, -3, 3, -1], + [1, 1, 1, -1, 3, -3], + [-3, -1, -1, -1, 3, -1], + [-3, -3, -1, 1, -1, -3], + [-3, -3, -3, 1, -3, -1], + [-3, 1, 1, -3, -1, -3], + [-3, 3, -3, 1, 1, -3], + [-3, 1, -3, -3, -3, -1], + [1, 1, -3, 3, 1, 3], + [1, 1, -3, -3, 1, -3], + [1, 1, 3, -1, 3, 3], + [1, 1, -3, 1, 3, 3], + [1, 1, -1, -1, 3, -1], + [1, 1, -1, 3, -1, -1], + [1, 1, -1, 3, -3, -1], + [1, 1, -3, 1, -1, -1] + ] + elif m == 12: + phi_table = [ + [-3, 1, -3, -3, -3, 3, -3, -1, 1, 1, 1, -3], + [-3, 3, 1, -3, 1, 3, -1, -1, 1, 3, 3, 3], + [-3, 3, 3, 1, -3, 3, -1, 1, 3, -3, 3, -3], + [-3, -3, -1, 3, 3, 3, -3, 3, -3, 1, -1, -3], + [-3, -1, -1, 1, 3, 1, 1, -1, 1, -1, -3, 1], + [-3, -3, 3, 1, -3, -3, -3, -1, 3, -1, 1, 3], + [1, -1, 3, -1, -1, -1, -3, -1, 1, 1, 1, -3], + [-1, -3, 3, -1, -3, -3, -3, -1, 1, -1, 1, -3], + [-3, -1, 3, 1, -3, -1, -3, 3, 1, 3, 3, 1], + [-3, -1, -1, -3, -3, -1, -3, 3, 1, 3, -1, -3], + [-3, 3, -3, 3, 3, -3, -1, -1, 3, 3, 1, -3], + [-3, -1, -3, -1, -1, -3, 3, 3, -1, -1, 1, -3], + [-3, -1, 3, -3, -3, -1, -3, 1, -1, -3, 3, 3], + [-3, 1, -1, -1, 3, 3, -3, -1, -1, -3, -1, -3], + [1, 3, -3, 1, 3, 3, 3, 1, -1, 1, -1, 3], + [-3, 1, 3, -1, -1, -3, -3, -1, -1, 3, 1, -3], + [-1, -1, -1, -1, 1, -3, -1, 3, 3, -1, -3, 1], + [-1, 1, 1, -1, 1, 3, 3, -1, -1, -3, 1, -3], + [-3, 1, 3, 3, -1, -1, -3, 3, 3, -3, 3, -3], + [-3, -3, 3, -3, -1, 3, 3, 3, -1, -3, 1, -3], + [3, 1, 3, 1, 3, -3, -1, 1, 3, 1, -1, -3], + [-3, 3, 1, 3, -3, 1, 1, 1, 1, 3, -3, 3], + [-3, 3, 3, 3, -1, -3, -3, -1, -3, 1, 3, -3], + [3, -1, -3, 3, -3, -1, 3, 3, 3, -3, -1, -3], + [-3, -1, 1, -3, 1, 3, 3, 3, -1, -3, 3, 3], + [-3, 3, 1, -1, 3, 3, -3, 1, -1, 1, -1, 1], + [-1, 1, 3, -3, 1, -1, 1, -1, -1, -3, 1, -1], + [-3, -3, 3, 3, 3, -3, -1, 1, -3, 3, 1, -3], + [1, -1, 3, 1, 1, -1, 
-1, -1, 1, 3, -3, 1], + [-3, 3, -3, 3, -3, -3, 3, -1, -1, 1, 3, -3] + ] + elif m == 18: + phi_table = [ + [-1, 3, -1, -3, 3, 1, -3, -1, 3, -3, -1, -1, 1, 1, 1, -1, -1, -1], + [3, -3, 3, -1, 1, 3, -3, -1, -3, -3, -1, -3, 3, 1, -1, 3, -3, 3], + [-3, 3, 1, -1, -1, 3, -3, -1, 1, 1, 1, 1, 1, -1, 3, -1, -3, -1], + [-3, -3, 3, 3, 3, 1, -3, 1, 3, 3, 1, -3, -3, 3, -1, -3, -1, 1], + [1, 1, -1, -1, -3, -1, 1, -3, -3, -3, 1, -3, -1, -1, 1, -1, 3, 1], + [3, -3, 1, 1, 3, -1, 1, -1, -1, -3, 1, 1, -1, 3, 3, -3, 3, -1], + [-3, 3, -1, 1, 3, 1, -3, -1, 1, 1, -3, 1, 3, 3, -1, -3, -3, -3], + [1, 1, -3, 3, 3, 1, 3, -3, 3, -1, 1, 1, -1, 1, -3, -3, -1, 3], + [-3, 1, -3, -3, 1, -3, -3, 3, 1, -3, -1, -3, -3, -3, -1, 1, 1, 3], + [3, -1, 3, 1, -3, -3, -1, 1, -3, -3, 3, 3, 3, 1, 3, -3, 3, -3], + [-3, -3, -3, 1, -3, 3, 1, 1, 3, -3, -3, 1, 3, -1, 3, -3, -3, 3], + [-3, -3, 3, 3, 3, -1, -1, -3, -1, -1, -1, 3, 1, -3, -3, -1, 3, -1], + [-3, -1, -3, -3, 1, 1, -1, -3, -1, -3, -1, -1, 3, 3, -1, 3, 1, 3], + [1, 1, -3, -3, -3, -3, 1, 3, -3, 3, 3, 1, -3, -1, 3, -1, -3, 1], + [-3, 3, -1, -3, -1, -3, 1, 1, -3, -3, -1, -1, 3, -3, 1, 3, 1, 1], + [3, 1, -3, 1, -3, 3, 3, -1, -3, -3, -1, -3, -3, 3, -3, -1, 1, 3], + [-3, -1, -3, -1, -3, 1, 3, -3, -1, 3, 3, 3, 1, -1, -3, 3, -1, -3], + [-3, -1, 3, 3, -1, 3, -1, -3, -1, 1, -1, -3, -1, -1, -1, 3, 3, 1], + [-3, 1, -3, -1, -1, 3, 1, -3, -3, -3, -1, -3, -3, 1, 1, 1, -1, -1], + [3, 3, 3, -3, -1, -3, -1, 3, -1, 1, -1, -3, 1, -3, -3, -1, 3, 3], + [-3, 1, 1, -3, 1, 1, 3, -3, -1, -3, -1, 3, -3, 3, -1, -1, -1, -3], + [1, -3, -1, -3, 3, 3, -1, -3, 1, -3, -3, -1, -3, -1, 1, 3, 3, 3], + [-3, -3, 1, -1, -1, 1, 1, -3, -1, 3, 3, 3, 3, -1, 3, 1, 3, 1], + [3, -1, -3, 1, -3, -3, -3, 3, 3, -1, 1, -3, -1, 3, 1, 1, 3, 3], + [3, -1, -1, 1, -3, -1, -3, -1, -3, -3, -1, -3, 1, 1, 1, -3, -3, 3], + [-3, -3, 1, -3, 3, 3, 3, -1, 3, 1, 1, -3, -3, -3, 3, -3, -1, -1], + [-3, -1, -1, -3, 1, -3, 3, -1, -1, -3, 3, 3, -3, -1, 3, -1, -1, -1], + [-3, -3, 3, 3, -3, 1, 3, -1, -3, 1, -1, -3, 3, -3, -1, -1, -1, 3], + [-1, -3, 1, -3, -3, -3, 1, 1, 3, 3, -3, 3, 3, -3, -1, 3, -3, 1], + [-3, 3, 1, -1, -1, -1, -1, 1, -1, 3, 3, -3, -1, 1, 3, -1, 3, -1] + ] + elif m == 24: + phi_table = [ + [-1,-3,3,-1,3,1,3,-1,1,-3,-1,-3,-1,1,3,-3,-1,-3,3,3,3,-3,-3,-3], + [-1,-3,3,1,1,-3,1,-3,-3,1,-3,-1,-1,3,-3,3,3,3,-3,1,3,3,-3,-3], + [-1,-3,-3,1,-1,-1,-3,1,3,-1,-3,-1,-1,-3,1,1,3,1,-3,-1,-1,3,-3,-3], + [1,-3,3,-1,-3,-1,3,3,1,-1,1,1,3,-3,-1,-3,-3,-3,-1,3,-3,-1,-3,-3], + [-1,3,-3,-3,-1,3,-1,-1,1,3,1,3,-1,-1,-3,1,3,1,-1,-3,1,-1,-3,-3], + [-3,-1,1,-3,-3,1,1,-3,3,-1,-1,-3,1,3,1,-1,-3,-1,-3,1,-3,-3,-3,-3], + [-3,3,1,3,-1,1,-3,1,-3,1,-1,-3,-1,-3,-3,-3,-3,-1,-1,-1,1,1,-3,-3], + [-3,1,3,-1,1,-1,3,-3,3,-1,-3,-1,-3,3,-1,-1,-1,-3,-1,-1,-3,3,3,-3], + [-3,1,-3,3,-1,-1,-1,-3,3,1,-1,-3,-1,1,3,-1,1,-1,1,-3,-3,-3,-3,-3], + [1,1,-1,-3,-1,1,1,-3,1,-1,1,-3,3,-3,-3,3,-1,-3,1,3,-3,1,-3,-3], + [-3,-3,-3,-1,3,-3,3,1,3,1,-3,-1,-1,-3,1,1,3,1,-1,-3,3,1,3,-3], + [-3,3,-1,3,1,-1,-1,-1,3,3,1,1,1,3,3,1,-3,-3,-1,1,-3,1,3,-3], + [3,-3,3,-1,-3,1,3,1,-1,-1,-3,-1,3,-3,3,-1,-1,3,3,-3,-3,3,-3,-3], + [-3,3,-1,3,-1,3,3,1,1,-3,1,3,-3,3,-3,-3,-1,1,3,-3,-1,-1,-3,-3], + [-3,1,-3,-1,-1,3,1,3,-3,1,-1,3,3,-1,-3,3,-3,-1,-1,-3,-3,-3,3,-3], + [-3,-1,-1,-3,1,-3,-3,-1,-1,3,-1,1,-1,3,1,-3,-1,3,1,1,-1,-1,-3,-3], + [-3,-3,1,-1,3,3,-3,-1,1,-1,-1,1,1,-1,-1,3,-3,1,-3,1,-1,-1,-1,-3], + [3,-1,3,-1,1,-3,1,1,-3,-3,3,-3,-1,-1,-1,-1,-1,-3,-3,-1,1,1,-3,-3], + [-3,1,-3,1,-3,-3,1,-3,1,-3,-3,-3,-3,-3,1,-3,-3,1,1,-3,1,1,-3,-3], + [-3,-3,3,3,1,-1,-1,-1,1,-3,-1,1,-1,3,-3,-1,-3,-1,-1,1,-3,3,-1,-3], + 
[-3,-3,-1,-1,-1,-3,1,-1,-3,-1,3,-3,1,-3,3,-3,3,3,1,-1,-1,1,-3,-3], + [3,-1,1,-1,3,-3,1,1,3,-1,-3,3,1,-3,3,-1,-1,-1,-1,1,-3,-3,-3,-3], + [-3,1,-3,3,-3,1,-3,3,1,-1,-3,-1,-3,-3,-3,-3,1,3,-1,1,3,3,3,-3], + [-3,-1,1,-3,-1,-1,1,1,1,3,3,-1,1,-1,1,-1,-1,-3,-3,-3,3,1,-1,-3], + [-3,3,-1,-3,-1,-1,-1,3,-1,-1,3,-3,-1,3,-3,3,-3,-1,3,1,1,-1,-3,-3], + [-3,1,-1,-3,-3,-1,1,-3,-1,-3,1,1,-1,1,1,3,3,3,-1,1,-1,1,-1,-3], + [-1,3,-1,-1,3,3,-1,-1,-1,3,-1,-3,1,3,1,1,-3,-3,-3,-1,-3,-1,-3,-3], + [3,-3,-3,-1,3,3,-3,-1,3,1,1,1,3,-1,3,-3,-1,3,-1,3,1,-1,-3,-3], + [-3,1,-3,1,-3,1,1,3,1,-3,-3,-1,1,3,-1,-3,3,1,-1,-3,-3,-3,-3,-3], + [3,-3,-1,1,3,-1,-1,-3,-1,3,-1,-3,-1,-3,3,-1,3,1,1,-3,3,-3,-3,-3] + ] + else: + raise ValueError("Invalid u") + + return np.array(phi_table[u]) + + def select_mcs(mcs_index, table_index=1, channel_type="PUSCH", diff --git a/sionna/ofdm/detection.py b/sionna/ofdm/detection.py index ddcf0843..18ed5b02 100644 --- a/sionna/ofdm/detection.py +++ b/sionna/ofdm/detection.py @@ -828,6 +828,11 @@ class LinearDetector(OFDMDetector): constellation point indices instead of soft-values. Defaults to `False`. + post_equalizer_transformation: None or Layer + Optional layer that applies a transformation after the equalizer and + before the demapper. This can be used to apply transform precoding + when DFT-s-OFDM is enabled in NR PUSCH. + dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype) The dtype of `y`. Defaults to tf.complex64. The output dtype is the corresponding real dtype (tf.float32 or tf.float64). @@ -878,6 +883,7 @@ def __init__(self, num_bits_per_symbol=None, constellation=None, hard_out=False, + post_equalizer_transformation=None, dtype=tf.complex64, **kwargs): @@ -889,6 +895,7 @@ def __init__(self, num_bits_per_symbol=num_bits_per_symbol, constellation=constellation, hard_out=hard_out, + post_equalizer_transformation=post_equalizer_transformation, dtype=dtype, **kwargs) diff --git a/test/unit/nr/pusch_transform_precoding_270_prbs.npz b/test/unit/nr/pusch_transform_precoding_270_prbs.npz new file mode 100644 index 00000000..b34f2f84 Binary files /dev/null and b/test/unit/nr/pusch_transform_precoding_270_prbs.npz differ diff --git a/test/unit/nr/pusch_transform_precoding_2_prbs.npz b/test/unit/nr/pusch_transform_precoding_2_prbs.npz new file mode 100644 index 00000000..a50f3f47 Binary files /dev/null and b/test/unit/nr/pusch_transform_precoding_2_prbs.npz differ diff --git a/test/unit/nr/pusch_transmitter_transform_precoding.npz b/test/unit/nr/pusch_transmitter_transform_precoding.npz new file mode 100644 index 00000000..80f1bbaf Binary files /dev/null and b/test/unit/nr/pusch_transmitter_transform_precoding.npz differ diff --git a/test/unit/nr/reference_dmrs_transform_precoding.npy b/test/unit/nr/reference_dmrs_transform_precoding.npy new file mode 100644 index 00000000..4033dae4 Binary files /dev/null and b/test/unit/nr/reference_dmrs_transform_precoding.npy differ diff --git a/test/unit/nr/test_nr_utils.py b/test/unit/nr/test_nr_utils.py index 3f9aa920..7ff951b2 100644 --- a/test/unit/nr/test_nr_utils.py +++ b/test/unit/nr/test_nr_utils.py @@ -22,7 +22,7 @@ except RuntimeError as e: print(e) -from sionna.nr.utils import select_mcs, generate_prng_seq, calculate_tb_size +from sionna.nr.utils import select_mcs, generate_prng_seq, generate_low_papr_seq_type_1, calculate_tb_size class TestNRUtils(unittest.TestCase): @@ -310,6 +310,64 @@ def test_gen_rand_seq(self): s = generate_prng_seq(l, c_init+1) self.assertFalse(np.array_equal(s, s_ref)) + def test_generate_low_papr_seq_type_1(self): + # 
test against invalid inputs + testcases = [[36, 30, 0, 0], [5, 0, 0, 0], [36, 0, 0, -1], [36, 0, 1, 0], [72, 0, 2, 0]] + for inputs in testcases: + with self.assertRaises(ValueError): + generate_low_papr_seq_type_1(*inputs) + + testcases = [ + [(6, 1, 0, 2), + [-0.70711 - 0.70711j, -0.34871 - 0.93723j, -0.99734 - 0.072944j, 0.48137 - 0.87652j, -0.5967 - 0.80247j, + 0.20863 + 0.97799j]], + [(12, 28, 0, 3), + [0.70711 + 0.70711j, -0.60024 + 0.79982j, -0.48137 + 0.87652j, -0.93568 - 0.35285j, 0.97611 + 0.21728j, + -0.077358 + 0.997j, -0.064114 - 0.99794j, 0.2043 + 0.97891j, 0.94028 - 0.3404j, -0.46969 - 0.88283j, + -0.80772 + 0.58957j, -0.71643 + 0.69766j]], + [(30, 23, 0, 4), + [0.15143 + 0.98847j, -0.3916 + 0.92014j, -0.6933 - 0.72065j, 0.49314 + 0.86995j, 0.91416 - 0.40534j, + 0.8911 - 0.4538j, 0.62623 + 0.77964j, -0.85955 - 0.51104j, -0.026691 + 0.99964j, 0.5932 + 0.80506j, + -0.12174 + 0.99256j, -0.74711 - 0.6647j, 0.3808 + 0.92466j, 0.99599 - 0.089513j, 0.99824 + 0.059348j, + -0.056355 + 0.99841j, -0.098464 - 0.99514j, -0.91885 + 0.39461j, -0.64888 + 0.76089j, + -0.99548 - 0.094925j, 0.78507 - 0.61941j, -0.99992 + 0.012268j, -0.4719 + 0.88165j, -0.74674 + 0.66512j, + -0.50378 - 0.86383j, 0.46204 + 0.88686j, 0.83391 - 0.55189j, 0.66673 - 0.7453j, 0.94878 + 0.31594j, + -0.97159 + 0.23666j]], + [(36, 20, 0, 1), + [1 + 0j, -0.99342 + 0.11451j, -0.22459 + 0.97445j, -0.85411 + 0.52009j, 0.6491 - 0.76071j, + -0.66374 - 0.74796j, -0.1308 - 0.99141j, 0.60622 + 0.7953j, 0.75484 - 0.65591j, 0.94815 - 0.31783j, + -0.50082 + 0.86555j, 0.96696 + 0.25493j, 0.90173 + 0.4323j, -0.88771 + 0.4604j, 0.81219 + 0.58339j, + 0.81995 + 0.57244j, -0.86847 + 0.49575j, 0.92868 + 0.37088j, 0.9866 + 0.16313j, -0.39291 + 0.91958j, + 0.8911 - 0.4538j, 0.62956 - 0.77695j, 0.75296 + 0.65806j, -0.35158 - 0.93616j, -0.8309 - 0.55642j, + 0.41199 - 0.91119j, -0.6558 + 0.75493j, 0.10869 + 0.99408j, -0.88837 + 0.45913j, 0.92525 - 0.37935j, + 0.15425 - 0.98803j, 0.91474 - 0.40404j, -0.86246 + 0.50612j, 0.18828 + 0.98212j, -0.57115 + 0.82084j, + 0.2864 - 0.95811j]], + [(100, 15, 1, 1), + [1 + 0j, -0.6689 - 0.74335j, -0.056579 - 0.9984j, -0.44178 + 0.89713j, -0.72417 + 0.68962j, + 0.84155 - 0.54019j, 0.85653 - 0.5161j, -0.78018 + 0.62556j, -0.56418 + 0.82565j, 0.14153 - 0.98993j, + -0.45945 - 0.8882j, 0.95169 + 0.30706j, 0.80735 - 0.59007j, 0.16479 + 0.98633j, 0.99048 + 0.13769j, + -0.27622 + 0.96109j, 0.96648 + 0.25674j, -0.077426 + 0.997j, 0.96468 - 0.26343j, 0.698 + 0.7161j, + 0.12987 - 0.99153j, 0.76506 - 0.64396j, -0.99272 + 0.12041j, -0.95642 - 0.29201j, 0.85179 + 0.52389j, + 0.79936 + 0.60085j, -0.83877 - 0.54448j, -0.94106 - 0.33824j, 0.99887 - 0.047603j, 0.82408 - 0.56648j, + -0.24937 + 0.96841j, 0.58627 + 0.81012j, -0.99539 + 0.095917j, -0.26902 + 0.96313j, + -0.88751 - 0.46079j, -0.49987 + 0.8661j, -0.91868 - 0.39501j, -0.12643 + 0.99198j, + -0.95096 + 0.30932j, 0.79418 + 0.60768j, 0.11175 + 0.99374j, 0.50704 - 0.86192j, 0.84872 - 0.52885j, + -0.97196 + 0.23516j, -0.99684 + 0.079376j, 0.99652 - 0.083377j, 0.96905 - 0.24685j, + -0.83793 + 0.54578j, -0.48263 + 0.87583j, -0.14757 - 0.98905j, -0.82023 - 0.57203j, 0.93353 - 0.3585j, + 0.066509 - 0.99779j, 0.94347 + 0.33145j, 0.43243 - 0.90167j, 0.92316 + 0.38443j, 0.17908 - 0.98383j, + 0.98077 - 0.19516j, -0.67046 - 0.74195j, 0.13521 - 0.99082j, -0.7474 + 0.66438j, -0.98383 + 0.17912j, + 0.97915 + 0.20313j, 0.91011 + 0.41437j, -0.88327 - 0.46887j, -0.92611 - 0.37726j, 0.99236 + 0.1234j, + 0.95494 - 0.2968j, -0.63071 + 0.77602j, 0.066408 + 0.99779j, -0.829 - 0.55924j, 
-0.8873 + 0.46118j, + -0.14278 - 0.98975j, -0.99952 - 0.030838j, 0.043009 - 0.99907j, -0.99431 + 0.10653j, + -0.40633 - 0.91373j, -0.62781 + 0.77837j, -0.99907 - 0.043146j, 0.68481 + 0.72872j, + 0.14251 + 0.98979j, 0.30164 - 0.95342j, 0.55959 - 0.82877j, -0.65895 + 0.75219j, -0.63114 + 0.77567j, + 0.46593 - 0.88482j, 0.12408 - 0.99227j, 0.3874 + 0.92191j, 0.88288 + 0.46959j, -0.93686 + 0.34971j, + -0.20407 + 0.97896j, -0.82174 - 0.56986j, -0.74812 + 0.66356j, -0.60408 - 0.79692j, + -0.74277 + 0.66954j, -0.83078 - 0.5566j, -0.18043 + 0.98359j, -0.92515 + 0.37961j, 0.90102 + 0.43379j, + 0.43134 + 0.90219j]] + ] + for inputs, outputs in testcases: + np.testing.assert_array_almost_equal(generate_low_papr_seq_type_1(*inputs), outputs, decimal=5) def test_tb_size(self): """Test TB size calculation""" diff --git a/test/unit/nr/test_pusch_config.py b/test/unit/nr/test_pusch_config.py index b27a2222..94c3ccd6 100644 --- a/test/unit/nr/test_pusch_config.py +++ b/test/unit/nr/test_pusch_config.py @@ -28,7 +28,7 @@ class TestPUSCHDMRS(unittest.TestCase): """Tests for the PUSCHDMRS Class""" def test_against_reference_1(self): - """Test that DMRS pattenrs match a reference implementation""" + """Test that DMRS patterns match a reference implementation""" reference_dmrs = np.load("unit/nr/reference_dmrs_1.npy") pusch_config = PUSCHConfig() pusch_config.carrier.n_size_grid = 1 @@ -52,7 +52,7 @@ def test_against_reference_1(self): self.assertTrue(np.allclose(pilots, reference_dmrs)) def test_against_reference_2(self): - """Test that DMRS pattenrs match a reference implementation""" + """Test that DMRS patterns match a reference implementation""" reference_dmrs = np.load("unit/nr/reference_dmrs_2.npy") pusch_config = PUSCHConfig() pusch_config.carrier.n_size_grid = 4 @@ -75,9 +75,31 @@ def test_against_reference_2(self): pilots = np.transpose(np.array(p)) self.assertTrue(np.allclose(pilots, reference_dmrs)) + def test_against_reference_transform_precoding(self): + """Test that DMRS patterns match a reference implementation""" + reference_dmrs = np.load("unit/nr/reference_dmrs_transform_precoding.npy") + pusch_config = PUSCHConfig() + pusch_config.transform_precoding = True + + pusch_config.carrier.subcarrier_spacing = 30 + pusch_config.carrier.n_size_grid = 273 + pusch_config.dmrs.config_type = 1 + pusch_config.dmrs.length = 1 + pusch_config.dmrs.additional_position = 0 + pusch_config.dmrs.num_cdm_groups_without_data = 2 + p = [] + for n_cell_id in [0, 1, 10, 24, 99, 1006]: + pusch_config.carrier.n_cell_id = n_cell_id + a = pusch_config.dmrs_grid + pilots = np.concatenate([a[0, :, 2], a[0, :, 3], a[0, :, 10], a[0, :, 11]]) + pilots = pilots[np.where(pilots)]/np.sqrt(2) + p.append(pilots) + pilots = np.transpose(np.array(p)) + self.assertTrue(np.allclose(pilots, reference_dmrs)) + def test_orthogonality_over_resource_grid(self): """Test that DMRS for different ports are orthogonal - accross a resource grid by computing the LS estimate + across a resource grid by computing the LS estimate on a noise less block-constant channel """ def ls_estimate(pusch_config): @@ -179,7 +201,7 @@ def ls_estimate(pusch_config): def test_precoding_against_reference(self): - "Test precoded DMRS against reference implementation" + """Test precoded DMRS against reference implementation""" pusch_config = PUSCHConfig() pusch_config.carrier.n_size_grid = 1 diff --git a/test/unit/nr/test_pusch_receiver.py b/test/unit/nr/test_pusch_receiver.py index c6556125..60359d77 100644 --- a/test/unit/nr/test_pusch_receiver.py +++ 
b/test/unit/nr/test_pusch_receiver.py @@ -43,7 +43,7 @@ def run_test(pusch_configs, channel_estimator="perfect", domain="freq", num_rx=1 stream_management = None if num_rx==2: rx_tx_association = np.eye(2, dtype=bool) - stream_management = StreamManagement(rx_tx_association, pusch_config.num_layers) + stream_management = StreamManagement(rx_tx_association, pusch_configs[0].num_layers) pusch_receiver = PUSCHReceiver(pusch_transmitter, stream_management=stream_management, @@ -289,3 +289,14 @@ def test_07(self): self.assertEqual(ber, 0.0) ber = run_test(pusch_configs, channel_estimator=None, domain="time", dtype=tf.complex128) self.assertEqual(ber, 0.0) + + def test_08(self): + """Transform precoding""" + tf.random.set_seed(1) + pusch_config = PUSCHConfig() + pusch_config.n_size_bwp = 273 + pusch_config.tb.mcs_index = 16 + pusch_config.transform_precoding = True + pusch_configs = [pusch_config] + ber = run_test(pusch_configs, channel_estimator=None, batch_size=2) + self.assertEqual(ber, 0.0) diff --git a/test/unit/nr/test_pusch_transform_precoder.py b/test/unit/nr/test_pusch_transform_precoder.py new file mode 100644 index 00000000..752fb63d --- /dev/null +++ b/test/unit/nr/test_pusch_transform_precoder.py @@ -0,0 +1,53 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2021-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +try: + import sionna +except ImportError as e: + import sys + + sys.path.append("../") + +import unittest +import numpy as np +import tensorflow as tf + +gpus = tf.config.list_physical_devices('GPU') +print('Number of GPUs available :', len(gpus)) +if gpus: + gpu_num = 0 # Number of the GPU to be used + try: + tf.config.set_visible_devices(gpus[gpu_num], 'GPU') + print('Only GPU number', gpu_num, 'used.') + tf.config.experimental.set_memory_growth(gpus[gpu_num], True) + except RuntimeError as e: + print(e) + +from sionna.nr import PUSCHTransformPrecoder, PUSCHTransformDeprecoder + + +class TestPUSCHTransformPrecoder(unittest.TestCase): + """Test PUSCHTransformPrecoder and PUSCHTransformDeprecoder""" + + def test_precoder_against_reference(self): + for prbs in [2, 270]: + ref_data = np.load(f"unit/nr/pusch_transform_precoding_{prbs}_prbs.npz") + tp = PUSCHTransformPrecoder(num_subcarriers=12 * prbs) + x_transform_precoded = tp(ref_data["x_layer_mapped"]) + np.testing.assert_array_almost_equal(x_transform_precoded, + ref_data["x_transform_precoded"]) + + def test_deprecoder_against_reference(self): + for prbs in [2, 270]: + ref_data = np.load(f"unit/nr/pusch_transform_precoding_{prbs}_prbs.npz") + tp = PUSCHTransformDeprecoder(num_subcarriers=12 * prbs) + x_layer_mapped = tp(ref_data["x_transform_precoded"]) + np.testing.assert_array_almost_equal(x_layer_mapped, + ref_data["x_layer_mapped"]) + + def test_invalid_subcarrier_count(self): + with self.assertRaises(ValueError): + PUSCHTransformPrecoder(num_subcarriers=273 * 12) + with self.assertRaises(ValueError): + PUSCHTransformDeprecoder(num_subcarriers=273 * 12) diff --git a/test/unit/nr/test_pusch_transmitter.py b/test/unit/nr/test_pusch_transmitter.py index d482b546..a469b27c 100644 --- a/test/unit/nr/test_pusch_transmitter.py +++ b/test/unit/nr/test_pusch_transmitter.py @@ -76,3 +76,22 @@ def tests_against_reference(self): for i in range(0,83): test_name = f"unit/nr/pusch_test_configs/test_{i}" self.assertTrue(run_test(test_name)) + + def test_against_reference_transform_precoding(self): + """Test PUSCHTransmitter output against reference MATLAB implementation + with 
transform precoding enabled""" + pusch_config = PUSCHConfig() + pusch_config.carrier.subcarrier_spacing = 30 + pusch_config.carrier.n_size_grid = 273 + pusch_config.carrier.n_cell_id = 1 + pusch_config.n_rnti = 42 + pusch_config.tb.mcs_index = 9 + pusch_config.transform_precoding = True + pusch_config.dmrs.n_sid = 3 + + ref_data = np.load("unit/nr/pusch_transmitter_transform_precoding.npz") + + pusch_transmitter = PUSCHTransmitter(pusch_config, return_bits=False) + x_grid = pusch_transmitter(ref_data["bits"]) + + np.testing.assert_array_almost_equal(x_grid, ref_data["grid"])
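
A minimal usage sketch of the layers added in sionna/nr/pusch_transform_precoder.py: the DFT spreading and de-spreading form a unitary pair, so a tensor pushed through PUSCHTransformPrecoder and then PUSCHTransformDeprecoder is recovered up to numerical precision. Shapes and values below are arbitrary examples.

import tensorflow as tf
from sionna.nr import PUSCHTransformPrecoder, PUSCHTransformDeprecoder

num_subcarriers = 24   # 2 PRBs; largest prime factor of 24 is 3, i.e., <= 5
# Random complex symbols; the last dimension must be a multiple of num_subcarriers
x = tf.complex(tf.random.normal([4, num_subcarriers]),
               tf.random.normal([4, num_subcarriers]))

precoder = PUSCHTransformPrecoder(num_subcarriers)      # normalized FFT per block
deprecoder = PUSCHTransformDeprecoder(num_subcarriers)  # scaled IFFT per block

x_rec = deprecoder(precoder(x))
assert tf.reduce_max(tf.abs(x - x_rec)) < 1e-5          # round trip recovers x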
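
The low-PAPR type-1 base-sequence generator added to sionna/nr/utils.py (and exported from sionna.nr) can be exercised on its own; u, v and alpha are normally supplied by higher-layer signalling, so the values here are placeholders.

import numpy as np
from sionna.nr import generate_low_papr_seq_type_1

# Length-24 type-1 base sequence (38.211 Sec. 5.2.2), group u=0, no cyclic shift
r = generate_low_papr_seq_type_1(length=24, u=0, v=0, alpha=0)
print(r.shape, np.allclose(np.abs(r), 1.0))   # (24,) True -- constant-modulus sequence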
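
Finally, a sketch of an end-to-end configuration with transform precoding enabled, mirroring the new unit tests; grid size, subcarrier spacing and n_sid are example values, and PUSCHReceiver is assumed to be built with its default channel estimator.

from sionna.nr import PUSCHConfig, PUSCHTransmitter, PUSCHReceiver

pusch_config = PUSCHConfig()
pusch_config.carrier.subcarrier_spacing = 30
pusch_config.carrier.n_size_grid = 273            # 270 effective PRBs after the prime-factor constraint
pusch_config.transform_precoding = True           # requires a single layer and DMRS config type 1
pusch_config.dmrs.num_cdm_groups_without_data = 2
pusch_config.dmrs.n_sid = 3                       # DMRS scrambling identity for DFT-s-OFDM

transmitter = PUSCHTransmitter(pusch_config)      # applies PUSCHTransformPrecoder after layer mapping
x_grid, bits = transmitter(2)                     # batch size of 2
receiver = PUSCHReceiver(transmitter)             # default LMMSE detector deprecodes after equalization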