From 6f5e239d4eef109c2cb7e47fc586b24e01b0c89c Mon Sep 17 00:00:00 2001 From: Riccardo Balbo Date: Sat, 27 Apr 2024 19:55:59 +0200 Subject: [PATCH] improve i/o performances --- src/OpenAgentsNode.py | 162 +++++++++++++++++++++++++++++++++++------- src/main.py | 77 +++++++++++++------- 2 files changed, 189 insertions(+), 50 deletions(-) diff --git a/src/OpenAgentsNode.py b/src/OpenAgentsNode.py index f5ed69e..a92215d 100644 --- a/src/OpenAgentsNode.py +++ b/src/OpenAgentsNode.py @@ -7,6 +7,51 @@ import json import asyncio import pickle +import queue +import concurrent + +class BlobWriter : + def __init__(self,writeQueue,res ): + self.writeQueue = writeQueue + self.res = res + + async def write(self, data): + self.writeQueue.put_nowait(data) + + async def writeInt(self, data): + self.writeQueue.put_nowait(data.to_bytes(4, byteorder='big')) + + async def end(self): + self.writeQueue.put_nowait(None) + + async def close(self): + self.writeQueue.put_nowait(None) + res= await self.res + return res.success + +class BlobReader: + def __init__(self, chunksQueue , req): + self.chunksQueue = chunksQueue + self.buffer = bytearray() + self.req = req + + + async def read(self, n = 1): + while len(self.buffer) < n: + v = await self.chunksQueue.get() + if v is None: break + self.buffer.extend(v) + result, self.buffer = self.buffer[:n], self.buffer[n:] + return result + + async def readInt(self): + return int.from_bytes(await self.read(4), byteorder='big') + + + async def close(self): + self.chunksQueue.task_done() + return await self.req + class BlobStorage: def __init__(self, id, url, node): @@ -27,7 +72,7 @@ async def delete(self, path): async def writeBytes(self, path, dataBytes): client = self.node.getClient() - CHUNK_SIZE = 1024*1024*100 + CHUNK_SIZE = 1024*1024*15 def write_data(): for j in range(0, len(dataBytes), CHUNK_SIZE): chunk = bytes(dataBytes[j:min(j+CHUNK_SIZE, len(dataBytes))]) @@ -36,6 +81,39 @@ def write_data(): res=await client.diskWriteFile(write_data()) return res.success + + async def openWriteStream(self, path): + client = self.node.getClient() + writeQueue = asyncio.Queue() + CHUNK_SIZE = 1024*1024*15 + + + async def write_data(): + while True: + dataBytes = await writeQueue.get() + if dataBytes is None: # End of stream + break + for j in range(0, len(dataBytes), CHUNK_SIZE): + chunk = bytes(dataBytes[j:min(j+CHUNK_SIZE, len(dataBytes))]) + request = rpc_pb2.RpcDiskWriteFileRequest(diskId=str(self.id), path=path, data=chunk) + yield request + writeQueue.task_done() + + res=client.diskWriteFile(write_data()) + + return BlobWriter(writeQueue, res) + + + async def openReadStream(self, path): + client = self.node.getClient() + readQueue = asyncio.Queue() + + async def read_data(): + async for chunk in client.diskReadFile(rpc_pb2.RpcDiskReadFileRequest(diskId=self.id, path=path)): + readQueue.put_nowait(chunk.data) + r = asyncio.create_task(read_data()) + return BlobReader(readQueue, r) + async def readBytes(self, path): client = self.node.getClient() bytesOut = bytearray() @@ -69,6 +147,7 @@ class JobRunner: _meta = None _sockets = None _nextAnnouncementTimestamp = 0 + cachePath = None def __init__(self, filters, meta, template, sockets): self._filters = filters @@ -84,40 +163,64 @@ def __init__(self, filters, meta, template, sockets): self._sockets = json.dumps(sockets) else: self._sockets = sockets + + self.cachePath = os.getenv('CACHE_PATH', os.path.join(os.path.dirname(__file__), "cache")) + if not os.path.exists(self.cachePath): + os.makedirs(self.cachePath) - async def 
cacheSet(self, path, value, version=0, expireAt=0): + async def cacheSet(self, path, value, version=0, expireAt=0, local=False): try: dataBytes = pickle.dumps(value) - client = self._node.getClient() - CHUNK_SIZE = 1024*1024*100 - def write_data(): - for j in range(0, len(dataBytes), CHUNK_SIZE): - chunk = bytes(dataBytes[j:min(j+CHUNK_SIZE, len(dataBytes))]) - request = rpc_pb2.RpcCacheSetRequest( - key=path, - data=chunk, - expireAt=expireAt, - version=version - ) - yield request - res=await client.cacheSet(write_data()) - return res.success + if local: + fullPath = os.path.join(self.cachePath, path) + with open(fullPath, "wb") as f: + f.write(dataBytes) + with open(fullPath+".meta.json", "w") as f: + f.write(json.dumps({"version":version, "expireAt":expireAt})) + else: + client = self._node.getClient() + CHUNK_SIZE = 1024*1024*15 + def write_data(): + for j in range(0, len(dataBytes), CHUNK_SIZE): + chunk = bytes(dataBytes[j:min(j+CHUNK_SIZE, len(dataBytes))]) + request = rpc_pb2.RpcCacheSetRequest( + key=path, + data=chunk, + expireAt=expireAt, + version=version + ) + yield request + res=await client.cacheSet(write_data()) + return res.success except Exception as e: print("Error setting cache "+str(e)) return False - async def cacheGet(self, path, lastVersion = 0): + async def cacheGet(self, path, lastVersion = 0, local=False): try: - client = self._node.getClient() - bytesOut = bytearray() - stream = client.cacheGet(rpc_pb2.RpcCacheGetRequest(key=path, lastVersion = lastVersion)) - async for chunk in stream: - if not chunk.exists: + if local: + fullPath = os.path.join(self.cachePath, path) + if not os.path.exists(fullPath) or not os.path.exists(fullPath+".meta.json"): + return None + with open(fullPath+".meta.json", "r") as f: + meta = json.loads(f.read()) + if lastVersion > 0 and meta["version"] != lastVersion: return None - bytesOut.extend(chunk.data) - return pickle.loads(bytesOut) + if meta["expireAt"] > 0 and time.time()*1000 > meta["expireAt"]: + return None + with open(fullPath, "rb") as f: + return pickle.load(f) + else: + client = self._node.getClient() + bytesOut = bytearray() + stream = client.cacheGet(rpc_pb2.RpcCacheGetRequest(key=path, lastVersion = lastVersion)) + async for chunk in stream: + if not chunk.exists: + return None + bytesOut.extend(chunk.data) + return pickle.loads(bytesOut) except Exception as e: print("Error getting cache "+str(e)) return None @@ -223,10 +326,17 @@ def getClient(self): except Exception as e: print("Error closing channel "+str(e)) print("Connect to "+self.poolAddress+":"+str(self.poolPort)+" with ssl "+str(self.poolSsl)) + + options=[ + # 20 MB + ('grpc.max_send_message_length', 1024*1024*20), + ('grpc.max_receive_message_length', 1024*1024*20) + ] + if self.poolSsl: - self.channel = grpc.aio.secure_channel(self.poolAddress+":"+str(self.poolPort), grpc.ssl_channel_credentials()) + self.channel = grpc.aio.secure_channel(self.poolAddress+":"+str(self.poolPort), grpc.ssl_channel_credentials(),options) else: - self.channel = grpc.aio.insecure_channel(self.poolAddress+":"+str(self.poolPort)) + self.channel = grpc.aio.insecure_channel(self.poolAddress+":"+str(self.poolPort),options) self.rpcClient = rpc_pb2_grpc.PoolConnectorStub(self.channel) return self.rpcClient diff --git a/src/main.py b/src/main.py index a0572f9..15ba199 100644 --- a/src/main.py +++ b/src/main.py @@ -21,35 +21,63 @@ def __init__(self, filters, meta, template, sockets): self.MAX_MEMORY_CACHE_GB = float(os.getenv('MAX_MEMORY_CACHE_GB', self.MAX_MEMORY_CACHE_GB)) + async def 
loadEmbeddingsFromBlobstore(self, i, blobStorage, f, out_vectors, out_content): + # Binary read + sentence_bytes=await blobStorage.readBytes(f) + vectors_bytes=await blobStorage.readBytes(f+".vectors") + shape_bytes=await blobStorage.readBytes(f+".shape") + dtype_bytes=await blobStorage.readBytes(f+".dtype") + # sentence_marker_bytes=blobStorage.readBytes(f+".kind") + # Decode + sentence = sentence_bytes.decode("utf-8") + dtype = dtype_bytes.decode("utf-8") + shape = json.loads(shape_bytes.decode("utf-8")) + embeddings = np.frombuffer(vectors_bytes, dtype=dtype).reshape(shape) + out_vectors[i] = embeddings + out_content[i] = sentence + + + async def deserializeFromBlob(self, url, out_vectors , out_content): - blobStorage = await self.openStorage( url) - files = await blobStorage.list() + blobDisk = await self.openStorage( url) self.log("Reading embeddings from "+url) - for f in files: - print(f) - + # Find embeddings files - embeddings_files = [f for f in files if f.endswith(".embeddings")] - sentences = [] - vectors = [] + sentencesIn = await blobDisk.openReadStream("sentences.bin") + embeddingsIn = await blobDisk.openReadStream("embeddings.bin") + + # embeddings_files = [f for f in files if f.endswith(".embeddings")] + # self.log("Found "+str(len(embeddings_files))+" embeddings files") + # sentences = [] + # vectors = [] dtype = None shape = None - print("Found files "+str(embeddings_files)) - for f in embeddings_files: - # Binary read - sentence_bytes=await blobStorage.readBytes(f) - vectors_bytes=await blobStorage.readBytes(f+".vectors") - shape_bytes=await blobStorage.readBytes(f+".shape") - dtype_bytes=await blobStorage.readBytes(f+".dtype") - # sentence_marker_bytes=blobStorage.readBytes(f+".kind") - # Decode - sentence = sentence_bytes.decode("utf-8") - dtype = dtype_bytes.decode("utf-8") - shape = json.loads(shape_bytes.decode("utf-8")) - embeddings = np.frombuffer(vectors_bytes, dtype=dtype).reshape(shape) - out_vectors.append(embeddings) + + nSentences = await sentencesIn.readInt() + for i in range(nSentences): + lenSentence = await sentencesIn.readInt() + sentence = await sentencesIn.read(lenSentence) + sentence=sentence.decode() out_content.append(sentence) - await blobStorage.close() + + nEmbeddings = await embeddingsIn.readInt() + for i in range(nEmbeddings): + shape = [] + lenShape = await embeddingsIn.readInt() + for j in range(lenShape): + shape.append(await embeddingsIn.readInt()) + + lenDtype = await embeddingsIn.readInt() + dtype = (await embeddingsIn.read(lenDtype)).decode() + + lenBs = await embeddingsIn.readInt() + bs = await embeddingsIn.read(lenBs) + embeddings = np.frombuffer(bs, dtype=dtype).reshape(shape) + out_vectors.append(embeddings) + + + + await blobDisk.close() return [dtype,shape] async def deserializeFromJSON( self, data, out_vectors ,out_content): @@ -150,7 +178,7 @@ def getParamValue(key,default=None): index = self.INDEXES.get(indexId) if not index: - self.log("Preparing index") + self.log("Loading index") index_vectors = [] index_content = [] dtype = None @@ -160,6 +188,7 @@ def getParamValue(key,default=None): continue [dtype,shape] = await self.deserialize(jin,index_vectors ,index_content) + self.log("Preparing index") index_vectors = np.array(index_vectors) if normalize and dtype == "float32": faiss.normalize_L2(index_vectors)
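
Note on the new blob layout: deserializeFromBlob() in main.py now reads two length-prefixed streams, "sentences.bin" and "embeddings.bin", with every integer encoded as 4 bytes big-endian (see BlobWriter.writeInt / BlobReader.readInt). The producing side is not shown in this patch, so the following is a minimal writer-side sketch built on the new openWriteStream() API, assuming it sits on the same JobRunner subclass as deserializeFromBlob(); serializeToBlob, sentences (a list of str) and vectors (a list of numpy arrays) are illustrative names, not part of the patch.

import numpy as np

async def serializeToBlob(self, url, sentences, vectors):
    blobDisk = await self.openStorage(url)

    # sentences.bin: [nSentences] then, per sentence, [byteLength][utf-8 bytes]
    sentencesOut = await blobDisk.openWriteStream("sentences.bin")
    await sentencesOut.writeInt(len(sentences))
    for s in sentences:
        data = s.encode("utf-8")
        await sentencesOut.writeInt(len(data))
        await sentencesOut.write(data)
    await sentencesOut.close()  # signals end of stream and awaits diskWriteFile

    # embeddings.bin: [nEmbeddings] then, per vector,
    # [lenShape][shape ints...][lenDtype][dtype][lenBytes][raw C-order bytes]
    embeddingsOut = await blobDisk.openWriteStream("embeddings.bin")
    await embeddingsOut.writeInt(len(vectors))
    for v in vectors:
        await embeddingsOut.writeInt(len(v.shape))
        for dim in v.shape:
            await embeddingsOut.writeInt(int(dim))
        dtype = str(v.dtype).encode("utf-8")
        await embeddingsOut.writeInt(len(dtype))
        await embeddingsOut.write(dtype)
        raw = v.tobytes()
        await embeddingsOut.writeInt(len(raw))
        await embeddingsOut.write(raw)
    await embeddingsOut.close()

    await blobDisk.close()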
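
Note on the local cache: cacheSet() and cacheGet() now take a local=True flag that bypasses the pool RPC and pickles the value under CACHE_PATH (by default a cache/ directory next to OpenAgentsNode.py). Below is a minimal usage sketch, assuming it runs inside a JobRunner method; the key and payload are illustrative. expireAt is compared against time.time()*1000, so it is a timestamp in milliseconds, and note that the local branch of cacheSet() does not return a success flag.

import time

async def exampleLocalCache(self):
    # Store a pickled value on the node's local disk, versioned, expiring in one hour.
    await self.cacheSet(
        "index-meta",
        {"rows": 1234, "dtype": "float32"},
        version=1,
        expireAt=int((time.time() + 3600) * 1000),  # milliseconds
        local=True,
    )

    # Returns None if the entry is missing, the version differs, or it has expired.
    meta = await self.cacheGet("index-meta", lastVersion=1, local=True)
    return meta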
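
Note on loadEmbeddingsFromBlobstore(): the new helper in main.py fills pre-sized output lists at index i instead of appending, which suggests it is meant to load the older per-file ".embeddings" layout concurrently. A possible driver is sketched below, assuming the caller already holds a BlobStorage and a list of ".embeddings" paths (e.g. from blobStorage.list()); loadAllEmbeddings is an illustrative name, not part of the patch.

import asyncio

async def loadAllEmbeddings(self, blobStorage, files):
    out_vectors = [None] * len(files)
    out_content = [None] * len(files)
    # One task per file; each task writes its results at its own index.
    await asyncio.gather(*[
        self.loadEmbeddingsFromBlobstore(i, blobStorage, f, out_vectors, out_content)
        for i, f in enumerate(files)
    ])
    return out_vectors, out_content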