Model_define_tf_csinet.py
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import math
# This part implements the quantization and dequantization operations.
# The output of the encoder must be a bitstream.
def Num2Bit(Num, B):
    # Expand each integer in [0, 2**B) into its B-bit binary representation.
    # unpackbits yields 8 bits per uint8 (MSB first); the slice [:, :, 4:] keeps
    # the 4 low-order bits and therefore assumes B = 4.
    Num_ = Num.numpy()
    bit = (np.unpackbits(np.array(Num_, np.uint8), axis=1)
           .reshape(-1, Num_.shape[1], 8)[:, :, 4:]).reshape(-1, Num_.shape[1] * B)
    bit = bit.astype(np.float32)
    return tf.convert_to_tensor(bit, dtype=tf.float32)
def Bit2Num(Bit, B):
    # Collapse groups of B bits back into the corresponding integer values.
    Bit_ = Bit.numpy()
    Bit_ = Bit_.astype(np.float32)
    Bit_ = np.reshape(Bit_, [-1, int(Bit_.shape[1] / B), B])
    num = np.zeros(shape=np.shape(Bit_[:, :, 0]))
    for i in range(B):
        num = num + Bit_[:, :, i] * 2 ** (B - 1 - i)
    return tf.cast(num, dtype=tf.float32)
@tf.custom_gradient
def QuantizationOp(x, B):
    # Uniformly quantize x in [0, 1) to B-bit integers, then expand to a bitstream.
    step = tf.cast((2 ** B), dtype=tf.float32)
    result = tf.cast((tf.round(x * step - 0.5)), dtype=tf.float32)
    result = tf.py_function(func=Num2Bit, inp=[result, B], Tout=tf.float32)
    def custom_grad(dy):
        # Straight-through estimator: pass the gradient through unchanged.
        grad = dy
        return (grad, grad)
    return result, custom_grad
class QuantizationLayer(tf.keras.layers.Layer):
    def __init__(self, B, **kwargs):
        self.B = B
        super(QuantizationLayer, self).__init__(**kwargs)
    def call(self, x):
        return QuantizationOp(x, self.B)
    def get_config(self):
        # Implement get_config to enable serialization. This is optional.
        base_config = super(QuantizationLayer, self).get_config()
        base_config['B'] = self.B
        return base_config
@tf.custom_gradient
def DequantizationOp(x, B):
    # Map the bitstream back to real values at the center of each quantization bin.
    x = tf.py_function(func=Bit2Num, inp=[x, B], Tout=tf.float32)
    step = tf.cast((2 ** B), dtype=tf.float32)
    result = tf.cast((x + 0.5) / step, dtype=tf.float32)
    def custom_grad(dy):
        # Straight-through estimator: pass the gradient through unchanged.
        grad = dy
        return (grad, grad)
    return result, custom_grad
class DeuantizationLayer(tf.keras.layers.Layer):
    def __init__(self, B, **kwargs):
        self.B = B
        super(DeuantizationLayer, self).__init__(**kwargs)
    def call(self, x):
        return DequantizationOp(x, self.B)
    def get_config(self):
        base_config = super(DeuantizationLayer, self).get_config()
        base_config['B'] = self.B
        return base_config
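# A minimal sanity-check sketch, not part of the original pipeline: assuming eager
# execution, it verifies that quantization followed by dequantization reproduces
# the encoder output to within one quantization step (1 / 2**B). The function
# name and default sizes below are illustrative only.
def _quantization_roundtrip_check(batch=4, dim=128, B=4):
    x = tf.random.uniform((batch, dim), minval=0.0, maxval=1.0)
    bits = QuantizationLayer(B)(x)       # shape (batch, dim * B), entries in {0, 1}
    x_rec = DeuantizationLayer(B)(bits)  # shape (batch, dim), bin centers in (0, 1)
    max_err = float(tf.reduce_max(tf.abs(x - x_rec)))
    assert max_err <= 1.0 / (2 ** B), max_err
    return max_err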
# More details about the neural networks can be found in [1].
# [1] C. Wen, W. Shih and S. Jin, "Deep Learning for Massive MIMO CSI Feedback,"
# in IEEE Wireless Communications Letters, vol. 7, no. 5, pp. 748-751, Oct. 2018, doi: 10.1109/LWC.2018.2818160.
def Encoder(x, feedback_bits, trainable=True):
    B = 4  # quantization bits per real-valued codeword entry
    with tf.compat.v1.variable_scope('Encoder'):
        x = layers.Conv2D(6, 3, padding='same', activation='relu', trainable=trainable)(x)
        x = layers.Conv2D(6, 3, padding='same', activation='relu', trainable=trainable)(x)
        x = layers.Flatten()(x)
        # feedback_bits // B values in (0, 1); each is quantized to B bits below.
        x = layers.Dense(units=int(feedback_bits // B), activation='sigmoid', trainable=trainable)(x)
        encoder_output = QuantizationLayer(B)(x)
    return encoder_output
def Decoder(x, feedback_bits, trainable=True):
    B = 4
    decoder_input = DeuantizationLayer(B)(x)
    x = tf.reshape(decoder_input, (-1, int(feedback_bits // B)))
    x = layers.Dense(32256, activation='sigmoid', trainable=trainable)(x)
    x_ini = layers.Reshape((126, 128, 2))(x)  # 126 * 128 * 2 = 32256
    # Three residual refinement blocks in the style of CsiNet [1].
    for _ in range(3):
        x = layers.Conv2D(8, 3, padding='same', activation='relu', trainable=trainable)(x_ini)
        x = layers.Conv2D(16, 3, padding='same', activation='relu', trainable=trainable)(x)
        x = layers.Conv2D(2, 3, padding='same', activation='relu', trainable=trainable)(x)
        x_ini = keras.layers.Add()([x_ini, x])
    decoder_output = layers.Conv2D(2, 3, padding='same', activation="sigmoid", trainable=trainable)(x_ini)
    return decoder_output
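# A sketch, not from the original file, of how Encoder and Decoder could be wired
# into one end-to-end Keras model for training. The input shape (126, 128, 2)
# matches the Reshape layer in Decoder; feedback_bits = 512 and the optimizer/loss
# choices are assumed example values, not prescribed by this module.
def build_autoencoder(feedback_bits=512):
    enc_input = keras.Input(shape=(126, 128, 2), name='encoder_input')
    enc_output = Encoder(enc_input, feedback_bits)   # bitstream of length feedback_bits
    dec_output = Decoder(enc_output, feedback_bits)  # reconstructed channel tensor
    model = keras.Model(inputs=enc_input, outputs=dec_output, name='autoencoder')
    model.compile(optimizer='adam', loss='mse')
    return model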
def NMSE(x, x_hat):
    # Normalized MSE between the complex-valued channel matrices; subtracting 0.5
    # recenters data that were mapped into [0, 1].
    x_real = np.reshape(x[:, :, :, 0], (len(x), -1))
    x_imag = np.reshape(x[:, :, :, 1], (len(x), -1))
    x_hat_real = np.reshape(x_hat[:, :, :, 0], (len(x_hat), -1))
    x_hat_imag = np.reshape(x_hat[:, :, :, 1], (len(x_hat), -1))
    x_C = x_real - 0.5 + 1j * (x_imag - 0.5)
    x_hat_C = x_hat_real - 0.5 + 1j * (x_hat_imag - 0.5)
    power = np.sum(abs(x_C) ** 2, axis=1)
    mse = np.sum(abs(x_C - x_hat_C) ** 2, axis=1)
    nmse = np.mean(mse / power)
    return nmse
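# A hypothetical convenience wrapper, not in the original file: run a trained
# autoencoder on a validation array of shape (N, 126, 128, 2) and report the NMSE
# of the reconstruction.
def evaluate_nmse(model, x_val, batch_size=64):
    x_hat_val = model.predict(x_val, batch_size=batch_size)
    return NMSE(x_val, x_hat_val)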
# Return the custom layers by name so that the model
# can be loaded successfully in the test file.
def get_custom_objects():
    return {"QuantizationLayer": QuantizationLayer, "DeuantizationLayer": DeuantizationLayer}