diff --git a/src/bisheng-langchain/bisheng_langchain/embeddings/host_embedding.py b/src/bisheng-langchain/bisheng_langchain/embeddings/host_embedding.py index 5655b8c34..05ee82b8d 100644 --- a/src/bisheng-langchain/bisheng_langchain/embeddings/host_embedding.py +++ b/src/bisheng-langchain/bisheng_langchain/embeddings/host_embedding.py @@ -101,16 +101,31 @@ def embed(self, texts: List[str], **kwargs) -> List[List[float]]: if self.verbose: print('payload', inp) + max_text_to_split = 200 outp = None - try: - outp = self.client(url=self.url_ep, json=inp, timeout=self.request_timeout).json() - except requests.exceptions.Timeout: - raise Exception(f'timeout in host embedding infer, url=[{self.url_ep}]') - except Exception as e: - raise Exception(f'exception in host embedding infer: [{e}]') - - if outp['status_code'] != 200: - raise ValueError(f"API returned an error: {outp['status_message']}") + + start_index = 0 + len_text = len(texts) + while start_index < len_text: + inp_local = { + 'texts':texts[start_index:min(start_index + max_text_to_split, len_text)], + 'model':self.model, + 'type':emb_type + } + try: + outp_single = self.client(url=self.url_ep, json=inp_local, timeout=self.request_timeout).json() + if outp is None: + outp = outp_single + else: + outp['embeddings'] += outp_single['embeddings'] + except requests.exceptions.Timeout: + raise Exception(f'timeout in host embedding infer, url=[{self.url_ep}]') + except Exception as e: + raise Exception(f'exception in host embedding infer: [{e}]') + + if outp_single['status_code'] != 200: + raise ValueError(f"API returned an error: {outp['status_message']}") + start_index += max_text_to_split return outp['embeddings'] def embed_documents(self,