Skip to content

Commit

Permalink
fix: pass lint
Browse files Browse the repository at this point in the history
Signed-off-by: OxalisCu <2127298698@qq.com>
  • Loading branch information
OxalisCu committed Sep 30, 2024
1 parent 90177cb commit 75f84ba
Showing 1 changed file with 26 additions and 20 deletions.
46 changes: 26 additions & 20 deletions pymilvus/bulk_writer/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
# or implied. See the License for the specific language governing permissions and limitations under
# the License.

import csv
import json
import logging
from pathlib import Path
Expand Down Expand Up @@ -290,27 +289,35 @@ def _persist_parquet(self, local_path: str, **kwargs):

def _persist_csv(self, local_path: str, **kwargs):
sep = self._config.get("sep", ",")
nullkey = self._config.get("nullkey", "")
# nullkey is not supported in csv now

header = list(self._buffer.keys())
df = pd.DataFrame(columns=header)
data = pd.DataFrame(columns=header)
for k, v in self._buffer.items():
field_schema = self._fields[k]
# When using df.to_csv(arr) to write non-scalar data, the repr function is used by default to convert the data to a string.
# if the value of arr is [1.0, 2.0], repr(arr) will change with the type of arr, making things complicated:
# when arr is a list, the output is '[1.0, 2.0]'
# when arr is a tuple, the output is '(1.0, 2.0)'
# when arr is a np.array, the output is '[1.0 2.0]'
# When using df.to_csv(arr) to write non-scalar data,
# the repr function is used to convert the data to a string.
# if the value of arr is [1.0, 2.0], repr(arr) will change with the type of arr:
# when arr is a list, the output is '[1.0, 2.0]'
# when arr is a tuple, the output is '(1.0, 2.0)'
# when arr is a np.array, the output is '[1.0 2.0]'
# we need the output to be '[1.0, 2.0]', consistent with the array format in json
# so 1. either make sure that arr of type (BINARY_VECTOR, FLOAT_VECTOR, FLOAT16_VECTOR, BFLOAT16_VECTOR) is a LIST,
# so 1. either make sure that arr of type
# (BINARY_VECTOR, FLOAT_VECTOR, FLOAT16_VECTOR, BFLOAT16_VECTOR) is a LIST,
# 2. or convert arr into a string using json.dumps(arr) first and then add it to df
# I choose method 2 here
if field_schema.dtype in {DataType.JSON, DataType.ARRAY, DataType.SPARSE_FLOAT_VECTOR, DataType.BINARY_VECTOR, DataType.FLOAT_VECTOR}:
if field_schema.dtype in {
DataType.JSON,
DataType.ARRAY,
DataType.SPARSE_FLOAT_VECTOR,
DataType.BINARY_VECTOR,
DataType.FLOAT_VECTOR,
}:
dt = np.dtype("str")
arr = []
for val in v:
arr.append(json.dumps(val))
df[k] = pd.Series(arr, dtype=dt)
data[k] = pd.Series(arr, dtype=dt)
elif field_schema.dtype in {DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR}:
# special process for float16 vector, the self._buffer stores bytes for
# float16 vector, convert the bytes to float list
Expand All @@ -321,24 +328,23 @@ def _persist_csv(self, local_path: str, **kwargs):
)
arr = []
for val in v:
val = json.dumps(np.frombuffer(val, dtype=dt).tolist())
arr.append(val)
df[k] = pd.Series(arr, dtype=np.dtype("str"))
arr.append(json.dumps(np.frombuffer(val, dtype=dt).tolist()))
data[k] = pd.Series(arr, dtype=np.dtype("str"))
elif field_schema.dtype in {DataType.BOOL}:
dt = np.dtype("str")
arr = ["true" if x else "false" for x in v]
df[k] = pd.Series(arr, dtype=dt)
data[k] = pd.Series(arr, dtype=dt)
elif field_schema.dtype.name in NUMPY_TYPE_CREATOR:
dt = NUMPY_TYPE_CREATOR[field_schema.dtype.name]
df[k] = pd.Series(v, dtype=dt)
data[k] = pd.Series(v, dtype=dt)
else:
df[k] = pd.Series(v)
data[k] = pd.Series(v)

file_path = Path(local_path + ".csv")
try:
df.to_csv(file_path, sep=sep, index=False)
data.to_csv(file_path, sep=sep, index=False)
except Exception as e:
self._throw(f"Failed to persist file {file_path}, error: {e}")

logger.info("Successfully persist file %s, row count: %s", file_path, len(df))
logger.info("Successfully persist file %s, row count: %s", file_path, len(data))
return [str(file_path)]

0 comments on commit 75f84ba

Please sign in to comment.