Skip to content

Commit

Permalink
Adding CTC compatible models
Browse files Browse the repository at this point in the history
  • Loading branch information
tulasiram58827 committed Dec 21, 2020
1 parent 5d20eee commit a601cc0
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 52 deletions.
57 changes: 5 additions & 52 deletions colabs/KERAS_OCR_TFLITE.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
Expand All @@ -38,23 +38,7 @@
"id": "5Mme4MUCxVM8",
"outputId": "df988251-c816-489b-fa43-58a1096500a3"
},
"outputs": [
{
"data": {
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
},
"text/plain": [
"'2.5.0-dev20201128'"
]
},
"execution_count": 2,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"import typing\n",
"import string\n",
Expand Down Expand Up @@ -404,8 +388,8 @@
" activation='softmax',\n",
" name='fc_12')(x)\n",
" x = keras.layers.Lambda(lambda x: x[:, rnn_steps_to_discard:])(x)\n",
" prediction_model = keras.models.Model(inputs=inputs, outputs=CTCDecoder()(model.output))\n",
" model = keras.models.Model(inputs=inputs, outputs=x)\n",
" prediction_model = keras.models.Model(inputs=inputs, outputs=CTCDecoder()(model.output))\n",
" return model, prediction_model"
]
},
Expand Down Expand Up @@ -611,9 +595,7 @@
"\n",
"Refer to this [issue](https://github.com/tensorflow/tensorflow/issues/33494) regarding CTC decoder support in TFLite. \n",
"\n",
"This is the code for the [CTC DECODER](https://colab.research.google.com/github/tulasiram58827/ocr_tflite/blob/main/colabs/KERAS_OCR_TFLITE.ipynb#scrollTo=_rdJyCXo2Xzs). By default it is a greedy decoder; a beam search decoder can be used instead by specifying the corresponding parameter in the `ctc_decode` function.\n",
"\n",
"FYI: I am also working on porting the Beam Search CTC Decoder to a low-level language so that the entire OCR pipeline can be ported and used as an offline application together with EAST/CRAFT."
"**Update**: The CTC Decoder is now supported in TFLite by enabling Built-in-Ops in TensorFlow 2.4. Thanks to the TensorFlow team for the support."
]
},
{
Expand Down Expand Up @@ -794,29 +776,6 @@
"## TFLite Inference"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"id": "5bxtGxQEf0LJ"
},
"outputs": [],
"source": [
"# Code for CTC Decoder \n",
"\n",
"def decoder(y_pred):\n",
" input_shape = tf.keras.backend.shape(y_pred)\n",
" input_length = tf.ones(shape=input_shape[0]) * tf.keras.backend.cast(\n",
" input_shape[1], 'float32')\n",
" # You can turn on beam search decoding using greedy=False and also play with beam_width parameter.\n",
" unpadded = tf.keras.backend.ctc_decode(y_pred, input_length)[0][0]\n",
" unpadded_shape = tf.keras.backend.shape(unpadded)\n",
" padded = tf.pad(unpadded,\n",
" paddings=[[0, 0], [0, input_shape[1] - unpadded_shape[1]]],\n",
" constant_values=-1)\n",
" return padded"
]
},
{
"cell_type": "code",
"execution_count": 25,
Expand Down Expand Up @@ -895,8 +854,6 @@
"source": [
"# Running Dynamic Range Quantization\n",
"tflite_output = run_tflite_model(image_path, 'dr')\n",
"# Running decoder on TFLite Output\n",
"# decoded = decoder(tflite_output)\n",
"final_output = \"\".join(alphabets[index] for index in tflite_output[0] if index not in [blank_index, -1])\n",
"print(final_output)\n",
"cv2_imshow(cv2.imread(image_path))"
Expand Down Expand Up @@ -937,8 +894,6 @@
"source": [
"# Running Float16 Quantization\n",
"tflite_output = run_tflite_model(image_path, 'float16')\n",
"# Running decoder on TFLite Output\n",
"# decoded = decoder(tflite_output)\n",
"final_output = \"\".join(alphabets[index] for index in tflite_output[0] if index not in [blank_index, -1])\n",
"print(final_output)\n",
"cv2_imshow(cv2.imread(image_path))"
Expand All @@ -954,8 +909,6 @@
"source": [
"# Running Integer Quantization\n",
"tflite_output = run_tflite_model(image_path, 'int8')\n",
"# Running decoder on TFLite Output\n",
"decoded = decoder(tflite_output)\n",
"final_output = \"\".join(alphabets[index] for index in decoded[0] if index not in [blank_index, -1])\n",
"print(final_output)\n",
"cv2_imshow(cv2.imread(image_path))"
Expand Down Expand Up @@ -1020,7 +973,7 @@
"id": "1g_bWIW5gvni"
},
"source": [
"**The above benchmarks were measured on a Redmi K20 Pro.**"
"**The above benchmarks were measured on a Redmi K20 Pro with 4 threads.**"
]
}
],
Expand Down
Binary file added models/keras_ocr_dr_ctc.tflite
Binary file not shown.
Binary file added models/keras_ocr_float16_ctc.tflite
Binary file not shown.

0 comments on commit a601cc0

Please sign in to comment.