Merge branch 'dygraph' into dygraph
shiyutang authored Aug 21, 2023
2 parents f893a78 + 2bd552c commit 6c6de3f
Showing 5 changed files with 32 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -32,4 +32,4 @@
         description: Format files with ClangFormat
         entry: bash .clang_format.hook -i
         language: system
-        files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
+        files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
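For reference, the `files` pattern above restricts the ClangFormat hook to C/C++/CUDA/protobuf sources. A minimal sketch of what the same regex matches (the file names below are invented for illustration, not taken from the repository):

import re

# Same pattern as the hook's `files` field in .pre-commit-config.yaml.
CLANG_FORMAT_FILES = re.compile(r'\.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$')

# Invented file names, purely for illustration.
for name in ['postprocess_op.cpp', 'kernel.cu', 'utility.h', 'ocr.proto', 'paddleocr.py']:
    print(name, bool(CLANG_FORMAT_FILES.search(name)))
# Only paddleocr.py falls outside the hook's scope.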
16 changes: 16 additions & 0 deletions paddleocr.py
@@ -408,6 +408,7 @@ def parse_args(mMain=True):
     parser.add_argument("--det", type=str2bool, default=True)
     parser.add_argument("--rec", type=str2bool, default=True)
     parser.add_argument("--type", type=str, default='ocr')
+    parser.add_argument("--savefile", type=str2bool, default=False)
     parser.add_argument(
         "--ocr_version",
         type=str,
@@ -794,10 +795,25 @@ def main():
                 alpha_color=args.alphacolor
             )
             if result is not None:
+                lines = []
                 for idx in range(len(result)):
                     res = result[idx]
                     for line in res:
                         logger.info(line)
+                        val = '['
+                        for box in line[0]:
+                            val += str(box[0]) + ',' + str(box[1]) + ','
+
+                        val = val[:-1]
+                        val += '],' + line[1][0] + ',' + str(line[1][1]) + '\n'
+                        lines.append(val)
+                if args.savefile:
+                    if os.path.exists(args.output) is False:
+                        os.mkdir(args.output)
+                    outfile = args.output + '/' + img_name + '.txt'
+                    with open(outfile,'w',encoding='utf-8') as f:
+                        f.writelines(lines)
+
         elif args.type == 'structure':
             img, flag_gif, flag_pdf = check_and_read(img_path)
             if not flag_gif and not flag_pdf:
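In effect, the new `--savefile` flag writes each image's OCR result to a `<image name>.txt` file under the `--output` directory, one line per detected text box. A minimal sketch of the equivalent behaviour through the Python API (the image path and output directory below are assumptions for illustration, not values from the commit):

import os
from paddleocr import PaddleOCR

# Assumed input image and output directory, for illustration only.
img_path = './doc/imgs_en/img_12.jpg'
output_dir = './output'

engine = PaddleOCR(use_angle_cls=True, lang='en')
result = engine.ocr(img_path)

# Mirror the commit's per-line format: "[x1,y1,...,x4,y4],text,score".
lines = []
for res in result:
    for box, (text, score) in res:
        coords = ','.join(str(v) for pt in box for v in pt)
        lines.append('[{}],{},{}\n'.format(coords, text, score))

os.makedirs(output_dir, exist_ok=True)
img_name = os.path.basename(img_path).split('.')[0]
with open(os.path.join(output_dir, img_name + '.txt'), 'w', encoding='utf-8') as f:
    f.writelines(lines)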
10 changes: 8 additions & 2 deletions ppocr/data/simple_dataset.py
@@ -61,9 +61,15 @@ def __init__(self, config, mode, logger, seed=None):
     def set_epoch_as_seed(self, seed, dataset_config):
         if self.mode == 'train':
             try:
-                dataset_config['transforms'][5]['MakeBorderMap'][
+                border_map_id = [index
+                    for index, dictionary in enumerate(dataset_config['transforms'])
+                    if 'MakeBorderMap' in dictionary][0]
+                shrink_map_id = [index
+                    for index, dictionary in enumerate(dataset_config['transforms'])
+                    if 'MakeShrinkMap' in dictionary][0]
+                dataset_config['transforms'][border_map_id]['MakeBorderMap'][
                     'epoch'] = seed if seed is not None else 0
-                dataset_config['transforms'][6]['MakeShrinkMap'][
+                dataset_config['transforms'][shrink_map_id]['MakeShrinkMap'][
                     'epoch'] = seed if seed is not None else 0
             except Exception as E:
                 print(E)
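The point of this change is that `MakeBorderMap` and `MakeShrinkMap` are no longer assumed to sit at fixed positions 5 and 6 of the transform list, which breaks as soon as a config reorders or inserts transforms. A short sketch with a made-up detection-training transform list (not a real config from the repo) showing how the new lookup behaves:

# Made-up transforms list for illustration; real ones come from the dataset YAML config.
transforms = [
    {'DecodeImage': {'img_mode': 'BGR'}},
    {'DetLabelEncode': None},
    {'IaaAugment': {'augmenter_args': []}},
    {'EastRandomCropData': {'size': [640, 640]}},
    {'MakeBorderMap': {'shrink_ratio': 0.4}},
    {'MakeShrinkMap': {'shrink_ratio': 0.4}},
    {'NormalizeImage': {'scale': '1./255.'}},
]

seed = 7  # the epoch number passed in as the seed

# Same lookup the commit introduces: locate the ops by name, not by position.
border_map_id = [i for i, d in enumerate(transforms) if 'MakeBorderMap' in d][0]
shrink_map_id = [i for i, d in enumerate(transforms) if 'MakeShrinkMap' in d][0]
transforms[border_map_id]['MakeBorderMap']['epoch'] = seed if seed is not None else 0
transforms[shrink_map_id]['MakeShrinkMap']['epoch'] = seed if seed is not None else 0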
13 changes: 3 additions & 10 deletions ppocr/modeling/backbones/vqa_layoutlm.py
@@ -54,18 +54,11 @@ def __init__(self,
         if checkpoints is not None: # load the trained model
             self.model = model_class.from_pretrained(checkpoints)
         else: # load the pretrained-model
-            pretrained_model_name = pretrained_model_dict[base_model_class][
-                mode]
-            if pretrained is True:
-                base_model = base_model_class.from_pretrained(
-                    pretrained_model_name)
-            else:
-                base_model = base_model_class.from_pretrained(pretrained)
+            pretrained_model_name = pretrained_model_dict[base_model_class][mode]
             if type == "ser":
-                self.model = model_class(
-                    base_model, num_classes=kwargs["num_classes"], dropout=None)
+                self.model = model_class.from_pretrained(pretrained_model_name, num_classes=kwargs["num_classes"], dropout=0)
             else:
-                self.model = model_class(base_model, dropout=None)
+                self.model = model_class.from_pretrained(pretrained_model_name, dropout=0)
         self.out_channels = 1
         self.use_visual_backbone = True

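With this change the backbone no longer instantiates a base encoder and then wraps it with the task head; the task model is loaded in one step via `from_pretrained`. A hedged sketch of the new-style call for the SER branch (the pretrained name and label count below are assumptions for illustration, not values from the diff):

from paddlenlp.transformers import LayoutXLMForTokenClassification

# Assumed pretrained name and class count, mirroring the "ser" branch above.
ser_model = LayoutXLMForTokenClassification.from_pretrained(
    'layoutxlm-base-uncased', num_classes=7, dropout=0)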
5 changes: 4 additions & 1 deletion tools/infer/utility.py
@@ -189,7 +189,10 @@ def create_predictor(args, mode, logger):
         if not os.path.exists(model_file_path):
             raise ValueError("not find model file path {}".format(
                 model_file_path))
-        sess = ort.InferenceSession(model_file_path)
+        if args.use_gpu:
+            sess = ort.InferenceSession(model_file_path, providers=['CUDAExecutionProvider'])
+        else:
+            sess = ort.InferenceSession(model_file_path)
         return sess, sess.get_inputs()[0], None, None

     else:
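The commit simply hard-codes `CUDAExecutionProvider` whenever `--use_gpu` is set. A slightly more defensive variant (a sketch only, not what the commit does, using a hypothetical helper name) would fall back to CPU when the CUDA provider is not available in the installed onnxruntime build:

import onnxruntime as ort

def create_onnx_session(model_file_path, use_gpu):
    # Hypothetical helper, not part of tools/infer/utility.py.
    available = ort.get_available_providers()
    if use_gpu and 'CUDAExecutionProvider' in available:
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    else:
        providers = ['CPUExecutionProvider']
    return ort.InferenceSession(model_file_path, providers=providers)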
