Skip to content

Commit

Permalink
refactored splitTrainVal and added multiOS path support (#11069)
Browse files Browse the repository at this point in the history
  • Loading branch information
itasli authored Oct 13, 2023
1 parent d0d77fe commit 2213807
Showing 1 changed file with 38 additions and 43 deletions.
81 changes: 38 additions & 43 deletions PPOCRLabel/gen_ocr_train_val_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,48 +17,43 @@ def isCreateOrDeleteFolder(path, flag):
return flagAbsPath


def splitTrainVal(root, absTrainRootPath, absValRootPath, absTestRootPath, trainTxt, valTxt, testTxt, flag):
    """Split the labelled images under ``root`` into train, validation and test sets.

    The label file is read as tab-separated records (image path, then label),
    shuffled, and each record's image is copied into one of the three target
    folders. A "<copied image path><TAB><label>" line is written to the
    matching list file for every record.

    The label file names, the recognition image folder name and the
    "train:val:test" ratio string are read from the module-level ``args``.
    Each ratio part is divided by 10, so the parts are presumably expected to
    sum to 10 (e.g. "6:2:2").

    Args:
        root: Folder that contains the label file and the images.
        absTrainRootPath: Folder that receives the training images.
        absValRootPath: Folder that receives the validation images.
        absTestRootPath: Folder that receives the test images.
        trainTxt: Writable text file object for the training list.
        valTxt: Writable text file object for the validation list.
        testTxt: Writable text file object for the test list.
        flag: "det" for detection data, "rec" for recognition data.

    Raises:
        ValueError: If ``flag`` is neither "det" nor "rec", or a ratio part is
            not a number.
        IndexError: If a record has no tab-separated label field.
    """
    dataAbsPath = os.path.abspath(root)

    if flag == "det":
        labelFilePath = os.path.join(dataAbsPath, args.detLabelFileName)
        imageDirPath = dataAbsPath
    elif flag == "rec":
        labelFilePath = os.path.join(dataAbsPath, args.recLabelFileName)
        # Joined component by component so the path also resolves outside
        # Windows (a literal "\\" separator only works there).
        imageDirPath = os.path.join(dataAbsPath, args.recImageDirName)
    else:
        raise ValueError("flag must be 'det' or 'rec', got {!r}".format(flag))

    # Close the label file as soon as it is read. Line endings are stripped
    # here and re-added on write, so a final record without a trailing newline
    # cannot be glued to the record written after it. Blank lines are dropped.
    with open(labelFilePath, "r", encoding="UTF-8") as labelFileRead:
        labelFileContent = [line.rstrip("\n") for line in labelFileRead if line.strip()]
    random.shuffle(labelFileContent)
    labelRecordLen = len(labelFileContent)

    # The ratio does not change per record, so parse it once. float() replaces
    # eval(): the parts are plain numbers and need no expression evaluation.
    trainValTestRatio = args.trainValTestRatio.split(":")
    trainRatio = float(trainValTestRatio[0]) / 10
    valRatio = trainRatio + float(trainValTestRatio[1]) / 10

    for index, labelRecordInfo in enumerate(labelFileContent):
        recordFields = labelRecordInfo.split("\t")
        imageName = os.path.basename(recordFields[0])
        imageLabel = recordFields[1]
        imagePath = os.path.join(imageDirPath, imageName)

        # The position in the shuffled list decides which set the record joins.
        curRatio = index / labelRecordLen
        if curRatio < trainRatio:
            targetRootPath, targetTxt = absTrainRootPath, trainTxt
        elif curRatio < valRatio:
            targetRootPath, targetTxt = absValRootPath, valTxt
        else:
            targetRootPath, targetTxt = absTestRootPath, testTxt

        imageCopyPath = os.path.join(targetRootPath, imageName)
        shutil.copy(imagePath, imageCopyPath)
        targetTxt.write("{}\t{}\n".format(imageCopyPath, imageLabel))
def splitTrainVal(root, abs_train_root_path, abs_val_root_path, abs_test_root_path, train_txt, val_txt, test_txt, flag):
    """Shuffle the label records under ``root`` and split them into train/val/test sets.

    Each record in the label file is a line holding an image path and a label
    separated by a tab. For every record the image is copied into one of the
    three target folders and a "<copied image path><TAB><label>" line is
    written to the matching list file.

    The label file names, the recognition image folder name and the
    "train:val:test" ratio string come from the module-level ``args``. Each
    ratio part is divided by 10, so the parts are presumably expected to sum
    to 10 (e.g. "6:2:2").

    Args:
        root: Folder that contains the label file and the images.
        abs_train_root_path: Folder that receives the training images.
        abs_val_root_path: Folder that receives the validation images.
        abs_test_root_path: Folder that receives the test images.
        train_txt: Writable text file object for the training list.
        val_txt: Writable text file object for the validation list.
        test_txt: Writable text file object for the test list.
        flag: "det" selects the detection label file and images stored
            directly in ``root``; any other value is handled as "rec"
            (recognition label file, images in ``args.recImageDirName``).

    Raises:
        ValueError: If a record has no tab separator or a ratio part is not a
            number.
    """
    data_abs_path = os.path.abspath(root)
    label_file_name = args.detLabelFileName if flag == "det" else args.recLabelFileName
    label_file_path = os.path.join(data_abs_path, label_file_name)

    # Keep the image folder choice in step with the label file choice above,
    # so a non-"det" flag can never leave the source path undefined.
    if flag == "det":
        image_dir = data_abs_path
    else:
        image_dir = os.path.join(data_abs_path, args.recImageDirName)

    # Strip the line ending while reading: the label is written back with an
    # explicit "\n", and keeping the original one would put a blank line after
    # every record in the output lists. Blank input lines are dropped.
    with open(label_file_path, "r", encoding="UTF-8") as label_file:
        records = [line.rstrip("\n") for line in label_file if line.strip()]
    random.shuffle(records)
    record_count = len(records)

    # The ratio is the same for every record, so parse it once. float()
    # replaces eval(): the parts are plain numbers and need no expression
    # evaluation.
    ratio_parts = args.trainValTestRatio.split(":")
    train_ratio = float(ratio_parts[0]) / 10
    val_ratio = train_ratio + float(ratio_parts[1]) / 10

    for index, record in enumerate(records):
        # Split on the first tab only, so a label that itself contains a tab
        # is kept whole instead of breaking the unpacking.
        image_relative_path, separator, image_label = record.partition("\t")
        if not separator:
            raise ValueError("label record has no tab separator: {!r}".format(record))
        image_name = os.path.basename(image_relative_path)
        image_path = os.path.join(image_dir, image_name)

        # The position in the shuffled list decides which set the record joins.
        cur_ratio = index / record_count
        if cur_ratio < train_ratio:
            target_root, target_txt = abs_train_root_path, train_txt
        elif cur_ratio < val_ratio:
            target_root, target_txt = abs_val_root_path, val_txt
        else:
            target_root, target_txt = abs_test_root_path, test_txt

        image_copy_path = os.path.join(target_root, image_name)
        shutil.copy(image_path, image_copy_path)
        target_txt.write("{}\t{}\n".format(image_copy_path, image_label))


# 删掉存在的文件
Expand Down Expand Up @@ -148,4 +143,4 @@ def genDetRecTrainVal(args):
help="the name of the folder where the cropped recognition dataset is located"
)
args = parser.parse_args()
genDetRecTrainVal(args)
genDetRecTrainVal(args)

0 comments on commit 2213807

Please sign in to comment.