diff --git a/PPOCRLabel/gen_ocr_train_val_test.py b/PPOCRLabel/gen_ocr_train_val_test.py index 03ae566c6e..a48053fa87 100644 --- a/PPOCRLabel/gen_ocr_train_val_test.py +++ b/PPOCRLabel/gen_ocr_train_val_test.py @@ -17,48 +17,43 @@ def isCreateOrDeleteFolder(path, flag): return flagAbsPath -def splitTrainVal(root, absTrainRootPath, absValRootPath, absTestRootPath, trainTxt, valTxt, testTxt, flag): - # 按照指定的比例划分训练集、验证集、测试集 - dataAbsPath = os.path.abspath(root) - - if flag == "det": - labelFilePath = os.path.join(dataAbsPath, args.detLabelFileName) - elif flag == "rec": - labelFilePath = os.path.join(dataAbsPath, args.recLabelFileName) - - labelFileRead = open(labelFilePath, "r", encoding="UTF-8") - labelFileContent = labelFileRead.readlines() - random.shuffle(labelFileContent) - labelRecordLen = len(labelFileContent) - - for index, labelRecordInfo in enumerate(labelFileContent): - imageRelativePath = labelRecordInfo.split('\t')[0] - imageLabel = labelRecordInfo.split('\t')[1] - imageName = os.path.basename(imageRelativePath) - - if flag == "det": - imagePath = os.path.join(dataAbsPath, imageName) - elif flag == "rec": - imagePath = os.path.join(dataAbsPath, "{}\\{}".format(args.recImageDirName, imageName)) - - # 按预设的比例划分训练集、验证集、测试集 - trainValTestRatio = args.trainValTestRatio.split(":") - trainRatio = eval(trainValTestRatio[0]) / 10 - valRatio = trainRatio + eval(trainValTestRatio[1]) / 10 - curRatio = index / labelRecordLen - - if curRatio < trainRatio: - imageCopyPath = os.path.join(absTrainRootPath, imageName) - shutil.copy(imagePath, imageCopyPath) - trainTxt.write("{}\t{}".format(imageCopyPath, imageLabel)) - elif curRatio >= trainRatio and curRatio < valRatio: - imageCopyPath = os.path.join(absValRootPath, imageName) - shutil.copy(imagePath, imageCopyPath) - valTxt.write("{}\t{}".format(imageCopyPath, imageLabel)) - else: - imageCopyPath = os.path.join(absTestRootPath, imageName) - shutil.copy(imagePath, imageCopyPath) - testTxt.write("{}\t{}".format(imageCopyPath, imageLabel)) +def splitTrainVal(root, abs_train_root_path, abs_val_root_path, abs_test_root_path, train_txt, val_txt, test_txt, flag): + + data_abs_path = os.path.abspath(root) + label_file_name = args.detLabelFileName if flag == "det" else args.recLabelFileName + label_file_path = os.path.join(data_abs_path, label_file_name) + + with open(label_file_path, "r", encoding="UTF-8") as label_file: + label_file_content = label_file.readlines() + random.shuffle(label_file_content) + label_record_len = len(label_file_content) + + for index, label_record_info in enumerate(label_file_content): + image_relative_path, image_label = label_record_info.split('\t') + image_name = os.path.basename(image_relative_path) + + if flag == "det": + image_path = os.path.join(data_abs_path, image_name) + elif flag == "rec": + image_path = os.path.join(data_abs_path, args.recImageDirName, image_name) + + train_val_test_ratio = args.trainValTestRatio.split(":") + train_ratio = eval(train_val_test_ratio[0]) / 10 + val_ratio = train_ratio + eval(train_val_test_ratio[1]) / 10 + cur_ratio = index / label_record_len + + if cur_ratio < train_ratio: + image_copy_path = os.path.join(abs_train_root_path, image_name) + shutil.copy(image_path, image_copy_path) + train_txt.write("{}\t{}\n".format(image_copy_path, image_label)) + elif cur_ratio >= train_ratio and cur_ratio < val_ratio: + image_copy_path = os.path.join(abs_val_root_path, image_name) + shutil.copy(image_path, image_copy_path) + val_txt.write("{}\t{}\n".format(image_copy_path, image_label)) + else: + image_copy_path = os.path.join(abs_test_root_path, image_name) + shutil.copy(image_path, image_copy_path) + test_txt.write("{}\t{}\n".format(image_copy_path, image_label)) # 删掉存在的文件 @@ -148,4 +143,4 @@ def genDetRecTrainVal(args): help="the name of the folder where the cropped recognition dataset is located" ) args = parser.parse_args() - genDetRecTrainVal(args) + genDetRecTrainVal(args) \ No newline at end of file