improve I/O performance
riccardobl committed Apr 27, 2024
1 parent 19d3479 commit 6f5e239
Showing 2 changed files with 189 additions and 50 deletions.
162 changes: 136 additions & 26 deletions src/OpenAgentsNode.py
@@ -7,6 +7,51 @@
import json
import asyncio
import pickle
import queue
import concurrent

class BlobWriter :
def __init__(self,writeQueue,res ):
self.writeQueue = writeQueue
self.res = res

async def write(self, data):
self.writeQueue.put_nowait(data)

async def writeInt(self, data):
self.writeQueue.put_nowait(data.to_bytes(4, byteorder='big'))

async def end(self):
self.writeQueue.put_nowait(None)

async def close(self):
self.writeQueue.put_nowait(None)
res= await self.res
return res.success

class BlobReader:
def __init__(self, chunksQueue , req):
self.chunksQueue = chunksQueue
self.buffer = bytearray()
self.req = req


async def read(self, n = 1):
while len(self.buffer) < n:
v = await self.chunksQueue.get()
if v is None: break
self.buffer.extend(v)
result, self.buffer = self.buffer[:n], self.buffer[n:]
return result

async def readInt(self):
return int.from_bytes(await self.read(4), byteorder='big')


async def close(self):
self.chunksQueue.task_done()
return await self.req
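
A minimal sketch (not part of the commit) of how BlobReader reassembles arbitrarily chunked input into exact-length reads; it uses only the class above and an in-memory queue, and never calls close(), so the unused req argument can stay None:

import asyncio

async def demo_reader():
    chunks = asyncio.Queue()
    # a 4-byte big-endian length prefix followed by the payload, split across two chunks
    chunks.put_nowait((5).to_bytes(4, byteorder='big') + b"he")
    chunks.put_nowait(b"llo")
    chunks.put_nowait(None)  # end-of-stream marker

    reader = BlobReader(chunks, req=None)  # req is only awaited by close(), unused here
    assert await reader.readInt() == 5
    assert await reader.read(5) == b"hello"

# asyncio.run(demo_reader())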


class BlobStorage:
def __init__(self, id, url, node):
@@ -27,7 +72,7 @@ async def delete(self, path):

async def writeBytes(self, path, dataBytes):
client = self.node.getClient()
CHUNK_SIZE = 1024*1024*100
CHUNK_SIZE = 1024*1024*15
def write_data():
for j in range(0, len(dataBytes), CHUNK_SIZE):
chunk = bytes(dataBytes[j:min(j+CHUNK_SIZE, len(dataBytes))])
@@ -36,6 +81,39 @@ def write_data():
res=await client.diskWriteFile(write_data())
return res.success


async def openWriteStream(self, path):
client = self.node.getClient()
writeQueue = asyncio.Queue()
CHUNK_SIZE = 1024*1024*15


async def write_data():
while True:
dataBytes = await writeQueue.get()
if dataBytes is None: # End of stream
break
for j in range(0, len(dataBytes), CHUNK_SIZE):
chunk = bytes(dataBytes[j:min(j+CHUNK_SIZE, len(dataBytes))])
request = rpc_pb2.RpcDiskWriteFileRequest(diskId=str(self.id), path=path, data=chunk)
yield request
writeQueue.task_done()

res=client.diskWriteFile(write_data())

return BlobWriter(writeQueue, res)


async def openReadStream(self, path):
client = self.node.getClient()
readQueue = asyncio.Queue()

async def read_data():
async for chunk in client.diskReadFile(rpc_pb2.RpcDiskReadFileRequest(diskId=self.id, path=path)):
readQueue.put_nowait(chunk.data)
r = asyncio.create_task(read_data())
return BlobReader(readQueue, r)
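
A possible calling pattern for the two streaming methods above, as a sketch rather than code from this commit; the path name and payload are placeholders and error handling is omitted:

async def roundtrip(blobStorage, payload: bytes):
    writer = await blobStorage.openWriteStream("data.bin")
    await writer.writeInt(len(payload))  # 4-byte big-endian length prefix
    await writer.write(payload)
    ok = await writer.close()            # close() enqueues the end-of-stream marker and awaits the RPC result

    reader = await blobStorage.openReadStream("data.bin")
    n = await reader.readInt()
    data = await reader.read(n)
    await reader.close()
    return ok, data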

async def readBytes(self, path):
client = self.node.getClient()
bytesOut = bytearray()
@@ -69,6 +147,7 @@ class JobRunner:
_meta = None
_sockets = None
_nextAnnouncementTimestamp = 0
cachePath = None
def __init__(self, filters, meta, template, sockets):
self._filters = filters

@@ -84,40 +163,64 @@ def __init__(self, filters, meta, template, sockets):
self._sockets = json.dumps(sockets)
else:
self._sockets = sockets

self.cachePath = os.getenv('CACHE_PATH', os.path.join(os.path.dirname(__file__), "cache"))
if not os.path.exists(self.cachePath):
os.makedirs(self.cachePath)


async def cacheSet(self, path, value, version=0, expireAt=0):
async def cacheSet(self, path, value, version=0, expireAt=0, local=False):
try:
dataBytes = pickle.dumps(value)
client = self._node.getClient()
CHUNK_SIZE = 1024*1024*100
def write_data():
for j in range(0, len(dataBytes), CHUNK_SIZE):
chunk = bytes(dataBytes[j:min(j+CHUNK_SIZE, len(dataBytes))])
request = rpc_pb2.RpcCacheSetRequest(
key=path,
data=chunk,
expireAt=expireAt,
version=version
)
yield request
res=await client.cacheSet(write_data())
return res.success
if local:
fullPath = os.path.join(self.cachePath, path)
with open(fullPath, "wb") as f:
f.write(dataBytes)
with open(fullPath+".meta.json", "w") as f:
f.write(json.dumps({"version":version, "expireAt":expireAt}))
else:
client = self._node.getClient()
CHUNK_SIZE = 1024*1024*15
def write_data():
for j in range(0, len(dataBytes), CHUNK_SIZE):
chunk = bytes(dataBytes[j:min(j+CHUNK_SIZE, len(dataBytes))])
request = rpc_pb2.RpcCacheSetRequest(
key=path,
data=chunk,
expireAt=expireAt,
version=version
)
yield request
res=await client.cacheSet(write_data())
return res.success
except Exception as e:
print("Error setting cache "+str(e))
return False


async def cacheGet(self, path, lastVersion = 0):
async def cacheGet(self, path, lastVersion = 0, local=False):
try:
client = self._node.getClient()
bytesOut = bytearray()
stream = client.cacheGet(rpc_pb2.RpcCacheGetRequest(key=path, lastVersion = lastVersion))
async for chunk in stream:
if not chunk.exists:
if local:
fullPath = os.path.join(self.cachePath, path)
if not os.path.exists(fullPath) or not os.path.exists(fullPath+".meta.json"):
return None
with open(fullPath+".meta.json", "r") as f:
meta = json.loads(f.read())
if lastVersion > 0 and meta["version"] != lastVersion:
return None
bytesOut.extend(chunk.data)
return pickle.loads(bytesOut)
if meta["expireAt"] > 0 and time.time()*1000 > meta["expireAt"]:
return None
with open(fullPath, "rb") as f:
return pickle.load(f)
else:
client = self._node.getClient()
bytesOut = bytearray()
stream = client.cacheGet(rpc_pb2.RpcCacheGetRequest(key=path, lastVersion = lastVersion))
async for chunk in stream:
if not chunk.exists:
return None
bytesOut.extend(chunk.data)
return pickle.loads(bytesOut)
except Exception as e:
print("Error getting cache "+str(e))
return None
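
With local=True the cache round-trip never leaves the node: the pickled value is written under CACHE_PATH next to a .meta.json sidecar holding version and expireAt (milliseconds since the epoch). A hedged usage sketch inside a hypothetical JobRunner subclass; the key, version, TTL and buildIndex callable are placeholders, not part of the project:

async def getOrBuildIndex(self, buildIndex):
    cached = await self.cacheGet("my-index", lastVersion=3, local=True)
    if cached is not None:
        return cached
    index = buildIndex()
    oneHourFromNow = int((time.time() + 3600) * 1000)  # expireAt is in milliseconds
    await self.cacheSet("my-index", index, version=3, expireAt=oneHourFromNow, local=True)
    return index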
@@ -223,10 +326,17 @@ def getClient(self):
except Exception as e:
print("Error closing channel "+str(e))
print("Connect to "+self.poolAddress+":"+str(self.poolPort)+" with ssl "+str(self.poolSsl))

options=[
# 20 MB
('grpc.max_send_message_length', 1024*1024*20),
('grpc.max_receive_message_length', 1024*1024*20)
]

if self.poolSsl:
self.channel = grpc.aio.secure_channel(self.poolAddress+":"+str(self.poolPort), grpc.ssl_channel_credentials())
self.channel = grpc.aio.secure_channel(self.poolAddress+":"+str(self.poolPort), grpc.ssl_channel_credentials(),options)
else:
self.channel = grpc.aio.insecure_channel(self.poolAddress+":"+str(self.poolPort))
self.channel = grpc.aio.insecure_channel(self.poolAddress+":"+str(self.poolPort),options)
self.rpcClient = rpc_pb2_grpc.PoolConnectorStub(self.channel)
return self.rpcClient
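
The raised message-size cap exists so that the 15 MB chunks produced by the streaming helpers fit in a single gRPC message with headroom for the other request fields. A tiny consistency-check sketch, not part of the commit:

GRPC_MAX_MESSAGE_BYTES = 1024 * 1024 * 20  # mirrors the channel options above
CHUNK_SIZE = 1024 * 1024 * 15              # mirrors the chunk size used by the streaming helpers
assert CHUNK_SIZE < GRPC_MAX_MESSAGE_BYTES # leaves room for diskId/path/key metadata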

77 changes: 53 additions & 24 deletions src/main.py
@@ -21,35 +21,63 @@ def __init__(self, filters, meta, template, sockets):
self.MAX_MEMORY_CACHE_GB = float(os.getenv('MAX_MEMORY_CACHE_GB', self.MAX_MEMORY_CACHE_GB))


async def loadEmbeddingsFromBlobstore(self, i, blobStorage, f, out_vectors, out_content):
# Binary read
sentence_bytes=await blobStorage.readBytes(f)
vectors_bytes=await blobStorage.readBytes(f+".vectors")
shape_bytes=await blobStorage.readBytes(f+".shape")
dtype_bytes=await blobStorage.readBytes(f+".dtype")
# sentence_marker_bytes=blobStorage.readBytes(f+".kind")
# Decode
sentence = sentence_bytes.decode("utf-8")
dtype = dtype_bytes.decode("utf-8")
shape = json.loads(shape_bytes.decode("utf-8"))
embeddings = np.frombuffer(vectors_bytes, dtype=dtype).reshape(shape)
out_vectors[i] = embeddings
out_content[i] = sentence



async def deserializeFromBlob(self, url, out_vectors , out_content):
blobStorage = await self.openStorage( url)
files = await blobStorage.list()
blobDisk = await self.openStorage( url)
self.log("Reading embeddings from "+url)
for f in files:
print(f)


# Find embeddings files
embeddings_files = [f for f in files if f.endswith(".embeddings")]
sentences = []
vectors = []
sentencesIn = await blobDisk.openReadStream("sentences.bin")
embeddingsIn = await blobDisk.openReadStream("embeddings.bin")

# embeddings_files = [f for f in files if f.endswith(".embeddings")]
# self.log("Found "+str(len(embeddings_files))+" embeddings files")
# sentences = []
# vectors = []
dtype = None
shape = None
print("Found files "+str(embeddings_files))
for f in embeddings_files:
# Binary read
sentence_bytes=await blobStorage.readBytes(f)
vectors_bytes=await blobStorage.readBytes(f+".vectors")
shape_bytes=await blobStorage.readBytes(f+".shape")
dtype_bytes=await blobStorage.readBytes(f+".dtype")
# sentence_marker_bytes=blobStorage.readBytes(f+".kind")
# Decode
sentence = sentence_bytes.decode("utf-8")
dtype = dtype_bytes.decode("utf-8")
shape = json.loads(shape_bytes.decode("utf-8"))
embeddings = np.frombuffer(vectors_bytes, dtype=dtype).reshape(shape)
out_vectors.append(embeddings)

nSentences = await sentencesIn.readInt()
for i in range(nSentences):
lenSentence = await sentencesIn.readInt()
sentence = await sentencesIn.read(lenSentence)
sentence=sentence.decode()
out_content.append(sentence)
await blobStorage.close()

nEmbeddings = await embeddingsIn.readInt()
for i in range(nEmbeddings):
shape = []
lenShape = await embeddingsIn.readInt()
for j in range(lenShape):
shape.append(await embeddingsIn.readInt())

lenDtype = await embeddingsIn.readInt()
dtype = (await embeddingsIn.read(lenDtype)).decode()

lenBs = await embeddingsIn.readInt()
bs = await embeddingsIn.read(lenBs)
embeddings = np.frombuffer(bs, dtype=dtype).reshape(shape)
out_vectors.append(embeddings)



await blobDisk.close()
return [dtype,shape]
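
The loops above read two length-prefixed binary files: sentences.bin (sentence count, then per sentence a byte length followed by UTF-8 bytes) and embeddings.bin (embedding count, then per embedding the shape rank and dims, the dtype string length and dtype, and the raw buffer length and bytes). The matching writer is not part of this diff; a sketch of what it could look like, with serializeToBlob and the blobDisk parameter as assumed names:

import numpy as np

async def serializeToBlob(blobDisk, sentences, vectors):
    # sketch only: writes the layout that deserializeFromBlob above reads back
    sentencesOut = await blobDisk.openWriteStream("sentences.bin")
    await sentencesOut.writeInt(len(sentences))
    for s in sentences:
        data = s.encode("utf-8")
        await sentencesOut.writeInt(len(data))
        await sentencesOut.write(data)
    await sentencesOut.close()

    embeddingsOut = await blobDisk.openWriteStream("embeddings.bin")
    await embeddingsOut.writeInt(len(vectors))
    for v in vectors:
        v = np.asarray(v)
        await embeddingsOut.writeInt(len(v.shape))  # shape rank
        for dim in v.shape:
            await embeddingsOut.writeInt(dim)       # each dimension
        dtype = str(v.dtype).encode("utf-8")
        await embeddingsOut.writeInt(len(dtype))
        await embeddingsOut.write(dtype)
        buf = v.tobytes()
        await embeddingsOut.writeInt(len(buf))
        await embeddingsOut.write(buf)
    await embeddingsOut.close()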

async def deserializeFromJSON( self, data, out_vectors ,out_content):
@@ -150,7 +178,7 @@ def getParamValue(key,default=None):

index = self.INDEXES.get(indexId)
if not index:
self.log("Preparing index")
self.log("Loading index")
index_vectors = []
index_content = []
dtype = None
@@ -160,6 +188,7 @@
continue
[dtype,shape] = await self.deserialize(jin,index_vectors ,index_content)

self.log("Preparing index")
index_vectors = np.array(index_vectors)
if normalize and dtype == "float32":
faiss.normalize_L2(index_vectors)
