66import logging
77import time
88import traceback
9- from dataclasses import dataclass , asdict
10- from enum import Enum
119from pathlib import Path
12- from typing import List , Optional , Union , Dict , Any
10+ from typing import List , Optional , Tuple , Union , Dict , Any
1311import numpy as np
1412import cv2
1513
16- from wired_table_rec .table_structure_cycle_center_net import TSRCycleCenterNet
17- from wired_table_rec .table_structure_unet import TSRUnet
18- from wired_table_rec .utils .download_model import DownloadModel
14+ from wired_table_rec .table_line_rec import TableLineRecognition
15+ from wired_table_rec .table_line_rec_plus import TableLineRecognitionPlus
1916from .table_recover import TableRecover
20- from .utils . utils import InputType , LoadImage
21- from wired_table_rec . utils .utils_table_recover import (
17+ from .utils import InputType , LoadImage
18+ from .utils_table_recover import (
2219 match_ocr_cell ,
2320 plot_html_table ,
2421 box_4_2_poly_to_box_4_1 ,
2724 gather_ocr_list_by_row ,
2825)
2926
30-
31- class ModelType (Enum ):
32- CYCLE_CENTER_NET = "cycle_center_net"
33- UNET = "unet"
34-
35-
36- ROOT_URL = "https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/"
37- KEY_TO_MODEL_URL = {
38- ModelType .CYCLE_CENTER_NET .value : f"{ ROOT_URL } /cycle_center_net.onnx" ,
39- ModelType .UNET .value : f"{ ROOT_URL } /unet.onnx" ,
40- }
41-
42-
43- @dataclass
44- class WiredTableInput :
45- model_type : Optional [str ] = ModelType .UNET .value
46- model_path : Union [str , Path , None , Dict [str , str ]] = None
47- use_cuda : bool = False
48- device : str = "cpu"
49-
50-
51- @dataclass
52- class WiredTableOutput :
53- pred_html : Optional [str ] = None
54- cell_bboxes : Optional [np .ndarray ] = None
55- logic_points : Optional [np .ndarray ] = None
56- elapse : Optional [float ] = None
27+ cur_dir = Path (__file__ ).resolve ().parent
28+ default_model_path = cur_dir / "models" / "cycle_center_net_v1.onnx"
29+ default_model_path_v2 = cur_dir / "models" / "cycle_center_net_v2.onnx"
5730
5831
5932class WiredTableRecognition :
60- def __init__ (self , config : WiredTableInput ):
61- self .model_type = config .model_type
62- if self .model_type not in KEY_TO_MODEL_URL :
63- model_list = "," .join (KEY_TO_MODEL_URL )
64- raise ValueError (
65- f"{ self .model_type } is not supported. The currently supported models are { model_list } ."
66- )
67-
68- config .model_path = self .get_model_path (config .model_type , config .model_path )
69- if self .model_type == ModelType .CYCLE_CENTER_NET .value :
70- self .table_structure = TSRCycleCenterNet (asdict (config ))
71- else :
72- self .table_structure = TSRUnet (asdict (config ))
73-
33+ def __init__ (self , table_model_path : Union [str , Path ] = None , version = "v2" ):
7434 self .load_img = LoadImage ()
35+ if version == "v2" :
36+ model_path = table_model_path if table_model_path else default_model_path_v2
37+ self .table_line_rec = TableLineRecognitionPlus (str (model_path ))
38+ else :
39+ model_path = table_model_path if table_model_path else default_model_path
40+ self .table_line_rec = TableLineRecognition (str (model_path ))
7541
7642 self .table_recover = TableRecover ()
7743
44+ try :
45+ self .ocr = importlib .import_module ("rapidocr_onnxruntime" ).RapidOCR ()
46+ except ModuleNotFoundError :
47+ self .ocr = None
48+
7849 def __call__ (
7950 self ,
8051 img : InputType ,
8152 ocr_result : Optional [List [Union [List [List [float ]], str , str ]]] = None ,
8253 ** kwargs ,
83- ) -> WiredTableOutput :
54+ ) -> Tuple [str , float , Any , Any , Any ]:
55+ if self .ocr is None and ocr_result is None :
56+ raise ValueError (
57+ "One of two conditions must be met: ocr_result is not empty, or rapidocr_onnxruntime is installed."
58+ )
59+
8460 s = time .perf_counter ()
61+ rec_again = True
8562 need_ocr = True
8663 col_threshold = 15
8764 row_threshold = 10
8865 if kwargs :
66+ rec_again = kwargs .get ("rec_again" , True )
8967 need_ocr = kwargs .get ("need_ocr" , True )
9068 col_threshold = kwargs .get ("col_threshold" , 15 )
9169 row_threshold = kwargs .get ("row_threshold" , 10 )
9270 img = self .load_img (img )
93- polygons , rotated_polygons = self .table_structure (img , ** kwargs )
71+ polygons , rotated_polygons = self .table_line_rec (img , ** kwargs )
9472 if polygons is None :
9573 logging .warning ("polygons is None." )
96- return WiredTableOutput ( "" , None , None , 0.0 )
74+ return "" , 0.0 , None , None , None
9775
9876 try :
9977 table_res , logi_points = self .table_recover (
@@ -108,34 +86,52 @@ def __call__(
10886 sorted_polygons , idx_list = sorted_ocr_boxes (
10987 [box_4_2_poly_to_box_4_1 (box ) for box in polygons ]
11088 )
111- return WiredTableOutput (
89+ return (
11290 "" ,
91+ time .perf_counter () - s ,
11392 sorted_polygons ,
11493 logi_points [idx_list ],
115- time . perf_counter () - s ,
94+ [] ,
11695 )
96+ if ocr_result is None and need_ocr :
97+ ocr_result , _ = self .ocr (img )
11798 cell_box_det_map , not_match_orc_boxes = match_ocr_cell (ocr_result , polygons )
11899 # 如果有识别框没有ocr结果,直接进行rec补充
119- cell_box_det_map = self .fill_blank_rec (img , polygons , cell_box_det_map )
100+ cell_box_det_map = self .re_rec (img , polygons , cell_box_det_map , rec_again )
120101 # 转换为中间格式,修正识别框坐标,将物理识别框,逻辑识别框,ocr识别框整合为dict,方便后续处理
121- t_rec_ocr_list = self .transform_res (cell_box_det_map , polygons , logi_points )
102+ t_rec_ocr_list_dict = self .transform_res (cell_box_det_map , polygons , logi_points )
103+ # 第一行或者第一列为空时,调整代码
104+ #adjust_dict = self.adjust_table_cells(t_rec_ocr_list_dict)
105+ adjust_dict = self .process_ocr_result (t_rec_ocr_list_dict )
122106 # 将每个单元格中的ocr识别结果排序和同行合并,输出的html能完整保留文字的换行格式
123- t_rec_ocr_list = self .sort_and_gather_ocr_res (t_rec_ocr_list )
107+ t_rec_ocr_list = self .sort_and_gather_ocr_res (t_rec_ocr_list_dict )
124108 # cell_box_map =
125109 logi_points = [t_box_ocr ["t_logic_box" ] for t_box_ocr in t_rec_ocr_list ]
126110 cell_box_det_map = {
127111 i : [ocr_box_and_text [1 ] for ocr_box_and_text in t_box_ocr ["t_ocr_res" ]]
128112 for i , t_box_ocr in enumerate (t_rec_ocr_list )
129113 }
130- pred_html = plot_html_table (logi_points , cell_box_det_map )
131- polygons = np .array (polygons ).reshape (- 1 , 8 )
132- logi_points = np .array (logi_points )
133- elapse = time .perf_counter () - s
114+ table_str = plot_html_table (logi_points , cell_box_det_map )
115+ ocr_boxes_res = [
116+ box_4_2_poly_to_box_4_1 (ori_ocr [0 ]) for ori_ocr in ocr_result
117+ ]
118+ sorted_ocr_boxes_res , _ = sorted_ocr_boxes (ocr_boxes_res )
119+ sorted_polygons = [box_4_2_poly_to_box_4_1 (box ) for box in polygons ]
120+ sorted_logi_points = logi_points
121+ table_elapse = time .perf_counter () - s
134122
135123 except Exception :
136124 logging .warning (traceback .format_exc ())
137- return WiredTableOutput ("" , None , None , 0.0 )
138- return WiredTableOutput (pred_html , polygons , logi_points , elapse )
125+ return "" , 0.0 , None , None , None
126+ return (
127+ table_str ,
128+ table_elapse ,
129+ sorted_polygons ,
130+ sorted_logi_points ,
131+ sorted_ocr_boxes_res ,
132+ adjust_dict
133+
134+ )
139135
140136 def transform_res (
141137 self ,
@@ -166,6 +162,102 @@ def transform_res(
166162 res .append (dict_res )
167163 return res
168164
165+ def process_ocr_result (self , ocr_result ):
166+ # 删除第一行的字典,并调整其余字典的行数
167+ first_row_empty = [entry for entry in ocr_result if
168+ entry ['t_logic_box' ][0 ] == 0 and entry ['t_logic_box' ][1 ] == 0 and entry ['t_ocr_res' ][0 ][
169+ 1 ] == '' ]
170+
171+ if len (first_row_empty ) == len (
172+ [entry for entry in ocr_result if entry ['t_logic_box' ][0 ] == 0 and entry ['t_logic_box' ][1 ] == 0 ]):
173+ # 如果第一行的所有单元格都为空,删除第一行
174+ ocr_result = [entry for entry in ocr_result if entry ['t_logic_box' ][0 ] != 0 or entry ['t_logic_box' ][1 ] != 0 ]
175+ # 调整剩余字典的行数
176+ for entry in ocr_result :
177+ entry ['t_logic_box' ][0 ] -= 1
178+ entry ['t_logic_box' ][1 ] -= 1
179+
180+ # 删除第一列的字典,并调整其余字典的列数
181+ first_col_empty = [entry for entry in ocr_result if
182+ entry ['t_logic_box' ][2 ] == 0 and entry ['t_logic_box' ][3 ] == 0 and entry ['t_ocr_res' ][0 ][
183+ 1 ] == '' ]
184+
185+ if len (first_col_empty ) == len (
186+ [entry for entry in ocr_result if entry ['t_logic_box' ][2 ] == 0 and entry ['t_logic_box' ][3 ] == 0 ]):
187+ # 如果第一列的所有单元格都为空,删除第一列
188+ ocr_result = [entry for entry in ocr_result if entry ['t_logic_box' ][2 ] != 0 or entry ['t_logic_box' ][3 ] != 0 ]
189+ # 调整剩余字典的列数
190+ for entry in ocr_result :
191+ entry ['t_logic_box' ][2 ] -= 1
192+ entry ['t_logic_box' ][3 ] -= 1
193+
194+ return ocr_result
195+
196+ def adjust_table_cells (self , t_rec_ocr_list_dict ):
197+ """
198+ 调整表格单元格,去掉第一行和/或第一列的单元格,
199+ 并更新剩余单元格的行列起始和结束位置。
200+
201+ 参数:
202+ t_rec_ocr_list_dict (list): 原始表格单元格识别结果,格式为
203+ [
204+ {
205+ "t_box": [xmin, ymin, xmax, ymax],
206+ "t_logic_box": [row_start, row_end, col_start, col_end],
207+ "t_ocr_res": [[box, text], ...]
208+ },
209+ ...
210+ ]
211+
212+ 返回:
213+ list: 调整后的表格单元格识别结果,格式与输入相同。
214+ """
215+ # 新的结果列表
216+ adjusted_result = []
217+
218+ # 记录是否第一行和第一列的单元格已被删除
219+ remove_first_row = False
220+ remove_first_col = False
221+
222+ # 检查并移除第一行
223+ if all (cell and not cell [1 ] for cell in t_rec_ocr_list_dict [0 ].get ("t_ocr_res" , [])):
224+ remove_first_row = True
225+
226+ # 检查并移除第一列
227+ if all (row .get ("t_ocr_res" ) and not row ["t_ocr_res" ][0 ][1 ] for row in t_rec_ocr_list_dict ):
228+ remove_first_col = True
229+
230+ # 遍历原始结果进行调整
231+ for i , row in enumerate (t_rec_ocr_list_dict ):
232+ adjusted_row = []
233+
234+ # 如果是第一行并且需要删除,跳过这行
235+ if remove_first_row and i == 0 :
236+ continue
237+
238+ for j , cell in enumerate (row .get ("t_ocr_res" , [])):
239+ # 如果是第一列并且需要删除,跳过这一列
240+ if remove_first_col and j == 0 :
241+ continue
242+
243+ # 更新当前单元格的逻辑位置
244+ adjusted_cell = {
245+ "t_box" : row .get ("t_box" ),
246+ "t_logic_box" : [
247+ row ["t_logic_box" ][0 ] - 1 if i > 0 else row ["t_logic_box" ][0 ],
248+ row ["t_logic_box" ][1 ] - 1 if i > 0 else row ["t_logic_box" ][1 ],
249+ row ["t_logic_box" ][2 ] - 1 if j > 0 else row ["t_logic_box" ][2 ],
250+ row ["t_logic_box" ][3 ] - 1 if j > 0 else row ["t_logic_box" ][3 ]
251+ ],
252+ "t_ocr_res" : cell
253+ }
254+ adjusted_row .append (adjusted_cell )
255+
256+ if adjusted_row :
257+ adjusted_result .append (adjusted_row )
258+
259+ return adjusted_result
260+
169261 def sort_and_gather_ocr_res (self , res ):
170262 for i , dict_res in enumerate (res ):
171263 _ , sorted_idx = sorted_ocr_boxes (
@@ -177,19 +269,30 @@ def sort_and_gather_ocr_res(self, res):
177269 )
178270 return res
179271
180- def fill_blank_rec (
272+ def re_rec (
181273 self ,
182274 img : np .ndarray ,
183275 sorted_polygons : np .ndarray ,
184276 cell_box_map : Dict [int , List [str ]],
277+ rec_again = True ,
185278 ) -> Dict [int , List [Any ]]:
186279 """找到poly对应为空的框,尝试将直接将poly框直接送到识别中"""
187280 for i in range (sorted_polygons .shape [0 ]):
188281 if cell_box_map .get (i ):
189282 continue
283+ if not rec_again :
284+ box = sorted_polygons [i ]
285+ cell_box_map [i ] = [[box , "" , 1 ]]
286+ continue
287+ crop_img = get_rotate_crop_image (img , sorted_polygons [i ])
288+ pad_img = cv2 .copyMakeBorder (
289+ crop_img , 5 , 5 , 100 , 100 , cv2 .BORDER_CONSTANT , value = (255 , 255 , 255 )
290+ )
291+ rec_res , _ = self .ocr (pad_img , use_det = False , use_cls = True , use_rec = True )
190292 box = sorted_polygons [i ]
191- cell_box_map [i ] = [[box , "" , 1 ]]
192- continue
293+ text = [rec [0 ] for rec in rec_res ]
294+ scores = [rec [1 ] for rec in rec_res ]
295+ cell_box_map [i ] = [[box , "" .join (text ), min (scores )]]
193296 return cell_box_map
194297
195298 def re_rec_high_precise (
@@ -222,46 +325,24 @@ def re_rec_high_precise(
222325 ]
223326 return cell_box_map
224327
225- @staticmethod
226- def get_model_path (
227- model_type : str , model_path : Union [str , Path , None ]
228- ) -> Union [str , Dict [str , str ]]:
229- if model_path is not None :
230- return model_path
231-
232- model_url = KEY_TO_MODEL_URL .get (model_type , None )
233- if isinstance (model_url , str ):
234- model_path = DownloadModel .download (model_url )
235- return model_path
236-
237- if isinstance (model_url , dict ):
238- model_paths = {}
239- for k , url in model_url .items ():
240- model_paths [k ] = DownloadModel .download (
241- url , save_model_name = f"{ model_type } _{ Path (url ).name } "
242- )
243- return model_paths
244-
245- raise ValueError (f"Model URL: { type (model_url )} is not between str and dict." )
246-
247328
248329def main ():
249330 parser = argparse .ArgumentParser ()
250331 parser .add_argument ("-img" , "--img_path" , type = str , required = True )
251332 args = parser .parse_args ()
252333
253334 try :
254- ocr_engine = importlib .import_module ("rapidocr " ).RapidOCR ()
335+ ocr_engine = importlib .import_module ("rapidocr_onnxruntime " ).RapidOCR ()
255336 except ModuleNotFoundError as exc :
256337 raise ModuleNotFoundError (
257- "Please install the rapidocr by pip install rapidocr ."
338+ "Please install the rapidocr_onnxruntime by pip install rapidocr_onnxruntime ."
258339 ) from exc
259- input_args = WiredTableInput ()
260- table_rec = WiredTableRecognition (input_args )
340+
341+ table_rec = WiredTableRecognition ()
261342 ocr_result , _ = ocr_engine (args .img_path )
262- table_results = table_rec (args .img_path , ocr_result )
263- print (table_results . pred_html )
264- print (f"cost: { table_results . elapse :.5f} " )
343+ table_str , elapse = table_rec (args .img_path , ocr_result )
344+ print (table_str )
345+ print (f"cost: { elapse :.5f} " )
265346
266347
267348if __name__ == "__main__" :
0 commit comments