diff --git a/src/OpenAgentsNode.py b/src/OpenAgentsNode.py
index a92215d..112eb6d 100644
--- a/src/OpenAgentsNode.py
+++ b/src/OpenAgentsNode.py
@@ -1,4 +1,6 @@
 import grpc
+import logging
+
 from openagents_grpc_proto import rpc_pb2_grpc
 from openagents_grpc_proto import rpc_pb2
 import time
@@ -8,8 +10,20 @@
 import asyncio
 import pickle
 import queue
+import base64
 import concurrent
-
+from threading import Condition
+import requests
+
+def cnvLogLevel(logLevel):
+    if logLevel == "debug": return logging.DEBUG
+    if logLevel == "info": return logging.INFO
+    if logLevel == "warn": return logging.WARNING
+    if logLevel == "error": return logging.ERROR
+    if logLevel == "fine": return logging.DEBUG
+    if logLevel == "finer": return logging.DEBUG
+    if logLevel == "finest": return logging.DEBUG
+    return logging.DEBUG
 class BlobWriter :
     def __init__(self,writeQueue,res ):
         self.writeQueue = writeQueue
@@ -137,18 +151,22 @@ def getUrl(self):
         return self.url

 class JobRunner:
-    _filters = None
-    _node = None
-    _job = None
-    _disksByUrl = {}
-    _disksById = {}
-    _diskByName = {}
-    _template = None
-    _meta = None
-    _sockets = None
-    _nextAnnouncementTimestamp = 0
-    cachePath = None
+
     def __init__(self, filters, meta, template, sockets):
+        self._filters = None
+        self._node = None
+        self._job = None
+        self._disksByUrl = {}
+        self._disksById = {}
+        self._diskByName = {}
+        self._template = None
+        self._meta = None
+        self._sockets = None
+        self._nextAnnouncementTimestamp = 0
+        self.cachePath = None
+        self.logger = None
+
+        self.logger = Logger("JobRunner", self)
         self._filters = filters

         if not isinstance(meta, str):
@@ -168,6 +186,9 @@ def __init__(self, filters, meta, template, sockets):

         if not os.path.exists(self.cachePath):
             os.makedirs(self.cachePath)
+
+    def getLogger(self,name=None):
+        return self.logger

     async def cacheSet(self, path, value, version=0, expireAt=0, local=False):
         try:
@@ -194,7 +215,7 @@ def write_data():
             res=await client.cacheSet(write_data())
             return res.success
         except Exception as e:
-            print("Error setting cache "+str(e))
+            self.getLogger().error("Error setting cache "+str(e))
             return False


@@ -222,7 +243,7 @@ async def cacheGet(self, path, lastVersion = 0, local=False):
                 bytesOut.extend(chunk.data)
             return pickle.loads(bytesOut)
         except Exception as e:
-            print("Error getting cache "+str(e))
+            self.getLogger().error("Error getting cache "+str(e))
             return None

     def _setNode(self, node):
@@ -235,8 +256,7 @@ def log(self, message):
         if self._job: message+=" for job "+self._job.id
         if self._node:
             self._node.log(message, self._job.id if self._job else None)
-        else:
-            print(message)
+

     async def openStorage(self, url):
         if url in self._disksByUrl:
@@ -289,19 +309,23 @@ async def run(self, job):
         pass

 class OpenAgentsNode:
-    nextNodeAnnounce = 0
-    nodeName = ""
-    nodeIcon = ""
-    nodeDescription = ""
-    channel = None
-    rpcClient = None
-    runners=[]
-    poolAddress = None
-    poolPort = None
-    failedJobsTracker = []
-    isLooping = False
+
     def __init__(self, nameOrMeta=None, icon=None, description=None):
+        self.nextNodeAnnounce = 0
+        self.nodeName = ""
+        self.nodeIcon = ""
+        self.nodeDescription = ""
+        self.channel = None
+        self.rpcClient = None
+        self.runners=[]
+        self.poolAddress = None
+        self.poolPort = None
+        self.failedJobsTracker = []
+        self.isLooping = False
+        self.logger = None
+
         name = ""
         if isinstance(nameOrMeta, str):
             name = nameOrMeta
         else :
@@ -313,19 +337,24 @@ def __init__(self, nameOrMeta=None, icon=None, description=None):
         self.nodeDescription = description or os.getenv('NODE_DESCRIPTION', "")
         self.channel = None
         self.rpcClient = None
+        self.logger = Logger(name)

     def registerRunner(self, runner):
+        runner.logger=self.logger
         self.runners.append(runner)

+    def getLogger(self):
+        return self.logger

     def getClient(self):
         if self.channel is None or self.channel._channel.check_connectivity_state(True) == grpc.ChannelConnectivity.SHUTDOWN:
             if self.channel is not None:
                 try:
+                    self.getLogger().info("Closing channel")
                     self.channel.close()
                 except Exception as e:
-                    print("Error closing channel "+str(e))
-            print("Connect to "+self.poolAddress+":"+str(self.poolPort)+" with ssl "+str(self.poolSsl))
+                    self.getLogger().error("Error closing channel "+str(e))
+            self.getLogger().info("Connect to "+self.poolAddress+":"+str(self.poolPort)+" with ssl "+str(self.poolSsl))
             options=[
                 # 20 MB
@@ -340,14 +369,12 @@ def getClient(self):
         self.rpcClient = rpc_pb2_grpc.PoolConnectorStub(self.channel)
         return self.rpcClient

-    async def _log(self, message, jobId=None):
+    async def _logToJob(self, message, jobId=None):
         await self.getClient().logForJob(rpc_pb2.RpcJobLog(jobId=jobId, log=message))

     def log(self,message, jobId=None):
-        print(message)
         if jobId:
-            #self.getClient().logForJob(rpc_pb2.RpcJobLog(jobId=jobId, log=message))
-            asyncio.create_task(self._log(message, jobId))
+            asyncio.create_task(self._logToJob(message, jobId))

     async def _acceptJob(self, jobId):
         await self.getClient().acceptJob(rpc_pb2.RpcAcceptJob(jobId=jobId))
@@ -482,4 +509,136 @@ async def run(self, poolAddress=None, poolPort=None, poolSsl=False):
         while True:
             await self.executePendingJob()
             await asyncio.sleep(1000.0/1000.0)
-    
\ No newline at end of file
+
+class OpenObserveLogger:
+
+    def __init__(self, options):
+        self.options = options
+        self.batchSize = self.options["batchSize"]
+        self.flushInterval = self.options["flushInterval"]
+        if not self.flushInterval:
+            self.flushInterval = 5000
+        if not self.batchSize:
+            self.batchSize = 21
+        self.buffer = queue.Queue()
+        self.wait = Condition()
+        self.flushThread = concurrent.futures.ThreadPoolExecutor(max_workers=1)
+        self.flushThread.submit(self.flushLoop)
+
+    def log(self, level, message, timestamp=None):
+        log_entry = {
+            'level': level,
+            '_timestamp': timestamp or int(time.time()*1000),
+            'message': message
+        }
+        self.buffer.put(log_entry)
+        if self.buffer.qsize() >= self.batchSize:
+            with self.wait:
+                self.wait.notify_all()
+
+    def flushLoop(self):
+        while True:
+            with self.wait:
+                self.wait.wait(self.flushInterval/1000)
+
+            batch = []
+            while not self.buffer.empty():
+                batch.append(self.buffer.get())
+
+            try:
+                url = self.options["baseUrl"]+"/api/"+self.options["org"]+"/"+self.options["stream"]+"/_json"
+
+                basicAuth = self.options["auth"]
+                if not isinstance(basicAuth, str):
+                    if "username" in basicAuth and "password" in basicAuth:
+                        basicAuth = basicAuth["username"]+":"+basicAuth["password"]
+                basicAuth = base64.b64encode(basicAuth.encode()).decode()
+
+                headers = {
+                    'Content-Type': 'application/json',
+                    "Authorization": "Basic "+basicAuth if basicAuth else None
+                }
+
+                res = requests.post(url, headers=headers, json=batch)
+                if res.status_code != 200:
+                    print("Error flushing log "+str(res.status_code))
+
+            except Exception as e:
+                print("Error flushing log "+str(e))
+
+
+class Logger :
+
+    def __init__(self, name, runner=None):
+        self.name=name or "main"
+        self.runner=runner
+        self.logger=None
+        self.logLevel=None
+        self.oobsLogger=None
+        self.logLevel = os.getenv('LOG_LEVEL', "debug")
+        oobsEndPoint = os.getenv('OPENOBSERVE_ENDPOINT', None)
+        self.oobsLogLevel = os.getenv('OPENOBSERVE_LOGLEVEL', self.logLevel)
+        if oobsEndPoint:
+            self.oobsLogger = OpenObserveLogger({
+                "baseUrl": oobsEndPoint,
+                "org": os.getenv('OPENOBSERVE_ORG', "default"),
+                "stream": os.getenv('OPENOBSERVE_STREAM', "default"),
+                "auth": os.getenv('OPENOBSERVE_BASICAUTH', None) or {
+                    "username": os.getenv('OPENOBSERVE_USERNAME', None),
+                    "password": os.getenv('OPENOBSERVE_PASSWORD', None)
+                },
+                "batchSize": int(os.getenv('OPENOBSERVE_BATCHSIZE', 21)),
+                "flushInterval": int(os.getenv('OPENOBSERVE_FLUSHINTERVAL', 0)),
+            })
+
+    def levelToValue(self, level):
+        if level == "error": return 8
+        if level == "warn": return 7
+        if level == "info": return 6
+        if level == "verbose": return 5
+        if level == "debug": return 4
+        if level == "fine": return 3
+        if level == "finer": return 2
+        if level == "finest": return 1
+        return level
+
+    def _log(self, level, message):
+        levelV=self.levelToValue(level)
+        if levelV >= self.levelToValue(self.logLevel):
+            date = time.strftime("%Y-%m-%d %H:%M:%S")
+            print(date+" ["+self.name+"] : "+level+" : "+message)
+            if self.oobsLogger and levelV >= self.levelToValue(self.oobsLogLevel):
+                self.oobsLogger.log(level, message)
+
+    def log(self, *args):
+        self._log("debug", " ".join([str(x) for x in args]))
+
+    def info(self, *args):
+        self._log("info", " ".join([str(x) for x in args]))
+
+    def warn(self, *args):
+        self._log("warn", " ".join([str(x) for x in args]))
+
+    def error(self, *args):
+        self._log("error", " ".join([str(x) for x in args]))
+
+    def debug(self, *args):
+        self._log("debug", " ".join([str(x) for x in args]))
+
+    def fine(self, *args):
+        self._log("fine", " ".join([str(x) for x in args]))
+
+    def finer(self, *args):
+        self._log("finer", " ".join([str(x) for x in args]))
+
+    def finest(self, *args):
+        self._log("finest", " ".join([str(x) for x in args]))
+
+    
\ No newline at end of file
diff --git a/src/main.py b/src/main.py
index 7df6a22..41d1c69 100644
--- a/src/main.py
+++ b/src/main.py
@@ -13,18 +13,21 @@
 import gc

 class Runner (JobRunner):
-    INDEXES={}
-    SEARCH_QUEUE = []
-    MAX_MEMORY_CACHE_GB = 1
+
     def __init__(self, filters, meta, template, sockets):
         super().__init__(filters, meta, template, sockets)
+        self.INDEXES={}
+        self.SEARCH_QUEUE = []
+        self.MAX_MEMORY_CACHE_GB = 1
+        self.MAX_MEMORY_CACHE_GB = float(os.getenv('MAX_MEMORY_CACHE_GB', self.MAX_MEMORY_CACHE_GB))
+        self.getLogger().info("Starting search node")

     async def deserializeFromBlob(self, url, out_vectors , out_content):
         blobDisk = await self.openStorage( url)
-        self.log("Reading embeddings from "+url)
+        self.getLogger().log("Reading embeddings from "+url)
         # Find embeddings files
         sentencesIn = await blobDisk.openReadStream("sentences.bin")
@@ -35,6 +38,7 @@ async def deserializeFromBlob(self, url, out_vectors , out_content):

         nSentences = await sentencesIn.readInt()
         for i in range(nSentences):
+            self.getLogger().log("Reading sentence "+str(i))
             lenSentence = await sentencesIn.readInt()
             sentence = await sentencesIn.read(lenSentence)
             sentence=sentence.decode()
@@ -42,6 +46,7 @@
         nEmbeddings = await embeddingsIn.readInt()
         for i in range(nEmbeddings):
+            self.getLogger().log("Reading embeddings "+str(i))
             shape = []
             lenShape = await embeddingsIn.readInt()
             for j in range(lenShape):
@@ -61,6 +66,7 @@
         return [dtype,shape]

     async def deserializeFromJSON( self, data, out_vectors ,out_content):
+        self.getLogger().log("Reading embeddings from JSON")
         dtype=None
         shape=None
         data=json.loads(data)
@@ -115,7 +121,7 @@ async def loop(self ):
         if len(flattern_queries) == 0:
             return
-        self.log("Searching "+str(len(flattern_queries))+" queries")
+        self.getLogger().info("Searching "+str(len(flattern_queries))+" queries")
         flattern_queries=np.array(flattern_queries)
         distances, indices = faiss_index.search(flattern_queries, top_k)
         for i in range(len(queue)):
@@ -146,13 +152,13 @@ def getParamValue(key,default=None):
             if marker != "query":
                 indexId += jin.data
         if len(indexId) == 0:
-            self.log("No index")
+            self.getLogger().log("No index")
             return json.dumps([])
         indexId=hashlib.sha256(indexId.encode()).hexdigest()

         index = self.INDEXES.get(indexId)
         if not index:
-            self.log("Loading index")
+            self.getLogger().info("Loading index")
             index_vectors = []
             index_content = []
             dtype = None
@@ -162,24 +168,26 @@ def getParamValue(key,default=None):
                     continue
                 [dtype,shape] = await self.deserialize(jin,index_vectors ,index_content)
-            self.log("Preparing index")
+            self.getLogger().info("Preparing index")
             index_vectors = np.array(index_vectors)
             if normalize and dtype == "float32":
                 faiss.normalize_L2(index_vectors)

             # Create faiss index
-            self.log("Creating faiss index")
+            self.getLogger().info("Creating faiss index")
             faiss_index = faiss.IndexFlatL2(shape[0])
             faiss_index.add(index_vectors)
+            self.getLogger().log("Counting memory usage")
             indexSizeGB = faiss_index.ntotal * shape[0] * 4 / 1024 / 1024 / 1024
             index = [faiss_index, time.time(), index_content, indexSizeGB]
             self.INDEXES[indexId] = index
-            
+
+            self.getLogger().log("Dropping oldest indexes if out of memory limit")
             # drop oldest index if out of memory limit
             totalSize = sum([x[3] for x in self.INDEXES.values()])
-            while totalSize > self.MAX_MEMORY_CACHE_GB:
+            while totalSize > self.MAX_MEMORY_CACHE_GB and len(self.INDEXES) > 1:
                 oldest = min(self.INDEXES.values(), key=lambda x: x[1])
-                self.log("Max cache size reached. Dropping oldest index.")
+                self.getLogger().log("Max cache size reached. Dropping oldest index.")
                 del self.INDEXES[oldest]
                 totalSize -= oldest[3]
                 gc.collect()

@@ -187,26 +195,31 @@ def getParamValue(key,default=None):
         else:
-            self.log("Index already loaded")
+            self.getLogger().info("Index already loaded")
             index[1] = time.time()

+        self.getLogger().log("Preparing queries")
         queries = []
         for jin in job.input:
             if jin.marker == "query":
+                self.getLogger().log("Preparing query")
                 searches_vectors = []
                 searches_content = []
                 [dtype,shape] = await self.deserialize(jin, searches_vectors, searches_content)
                 searches_vectors = np.array(searches_vectors)
                 if normalize and dtype == "float32":
+                    self.getLogger().log("Normalizing")
                     faiss.normalize_L2(searches_vectors)
                 queries=searches_vectors

+        queries = [ x for x in queries if len(x) > 0]
+
         if len(queries) == 0 :
-            self.log("No queries")
+            self.getLogger().log("No queries")
             return json.dumps([])

         # Search faiss index
-        self.log("Searching")
+        self.getLogger().info("Searching")
         search = next((x for x in self.SEARCH_QUEUE if x["indexId"] == indexId), None)
         if not search:
             search = {
@@ -220,7 +233,7 @@ def getParamValue(key,default=None):
         future = asyncio.Future()
         def callback(distances, indices):
             # Get content for each search query and sort by score
-            self.log("Retrieving content from index")
+            self.getLogger().info("Retrieving content from index")
             output_per_search = []
             index_content = index[2]
             for i in range(len(indices)):
@@ -231,6 +244,7 @@ def callback(distances, indices):
                 output_per_search[i] = sorted( output_per_search[i], key=lambda x: x["score"], reverse=False)

            # Merge results from all searches
+            self.getLogger().info("Merging search results")
            output = []
            i=0
            while len(output) < len(output_per_search)*top_k:
@@ -240,6 +254,7 @@ def callback(distances, indices):
                i+=1

            # Remove duplicates
+            self.getLogger().info("Deduplicating")
            dedup = []
            dedup_ids=[]
            for o in output:
@@ -249,9 +264,11 @@ def callback(distances, indices):
            output = dedup

            # truncate output
+            output = output[:min(top_k, len(output))]
            future.set_result(output)

+        self.getLogger().info("Waiting for search results")
        queue.append([
            queries,
            top_k,
@@ -260,6 +277,7 @@ def callback(distances, indices):

        output = await future
        # Serialize output and return
+        self.getLogger().info("Output ready")
        return json.dumps(output)
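
Usage sketch for the new Logger (illustrative, not part of the patch): the environment variable names below are the ones read in Logger.__init__ in src/OpenAgentsNode.py; the endpoint and credentials are placeholder values, and the import assumes the module's dependencies (grpc, requests, openagents_grpc_proto) are installed so that OpenAgentsNode imports cleanly.

import os

# Console verbosity and optional OpenObserve shipping are configured via env vars.
os.environ["LOG_LEVEL"] = "info"
os.environ["OPENOBSERVE_ENDPOINT"] = "http://localhost:5080"   # placeholder endpoint
os.environ["OPENOBSERVE_ORG"] = "default"
os.environ["OPENOBSERVE_STREAM"] = "default"
os.environ["OPENOBSERVE_USERNAME"] = "user@example.com"        # placeholder credentials
os.environ["OPENOBSERVE_PASSWORD"] = "secret"
os.environ["OPENOBSERVE_BATCHSIZE"] = "21"
os.environ["OPENOBSERVE_FLUSHINTERVAL"] = "5000"               # milliseconds; 0 falls back to 5000

from OpenAgentsNode import Logger

logger = Logger("example-runner")
logger.info("node started")           # printed (info >= info threshold) and queued for OpenObserve
logger.debug("verbose detail")        # suppressed: debug is below the "info" threshold
logger.error("something went wrong")  # printed and queued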