Skip to content

Commit 9aed350

Browse files
authored
Update main.py
可返回json格式结果
1 parent 1c7cbd0 commit 9aed350

1 file changed

Lines changed: 176 additions & 95 deletions

File tree

wired_table_rec/main.py

Lines changed: 176 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,16 @@
66
import logging
77
import time
88
import traceback
9-
from dataclasses import dataclass, asdict
10-
from enum import Enum
119
from pathlib import Path
12-
from typing import List, Optional, Union, Dict, Any
10+
from typing import List, Optional, Tuple, Union, Dict, Any
1311
import numpy as np
1412
import cv2
1513

16-
from wired_table_rec.table_structure_cycle_center_net import TSRCycleCenterNet
17-
from wired_table_rec.table_structure_unet import TSRUnet
18-
from wired_table_rec.utils.download_model import DownloadModel
14+
from wired_table_rec.table_line_rec import TableLineRecognition
15+
from wired_table_rec.table_line_rec_plus import TableLineRecognitionPlus
1916
from .table_recover import TableRecover
20-
from .utils.utils import InputType, LoadImage
21-
from wired_table_rec.utils.utils_table_recover import (
17+
from .utils import InputType, LoadImage
18+
from .utils_table_recover import (
2219
match_ocr_cell,
2320
plot_html_table,
2421
box_4_2_poly_to_box_4_1,
@@ -27,73 +24,54 @@
2724
gather_ocr_list_by_row,
2825
)
2926

30-
31-
class ModelType(Enum):
32-
CYCLE_CENTER_NET = "cycle_center_net"
33-
UNET = "unet"
34-
35-
36-
ROOT_URL = "https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/"
37-
KEY_TO_MODEL_URL = {
38-
ModelType.CYCLE_CENTER_NET.value: f"{ROOT_URL}/cycle_center_net.onnx",
39-
ModelType.UNET.value: f"{ROOT_URL}/unet.onnx",
40-
}
41-
42-
43-
@dataclass
44-
class WiredTableInput:
45-
model_type: Optional[str] = ModelType.UNET.value
46-
model_path: Union[str, Path, None, Dict[str, str]] = None
47-
use_cuda: bool = False
48-
device: str = "cpu"
49-
50-
51-
@dataclass
52-
class WiredTableOutput:
53-
pred_html: Optional[str] = None
54-
cell_bboxes: Optional[np.ndarray] = None
55-
logic_points: Optional[np.ndarray] = None
56-
elapse: Optional[float] = None
27+
cur_dir = Path(__file__).resolve().parent
28+
default_model_path = cur_dir / "models" / "cycle_center_net_v1.onnx"
29+
default_model_path_v2 = cur_dir / "models" / "cycle_center_net_v2.onnx"
5730

5831

5932
class WiredTableRecognition:
60-
def __init__(self, config: WiredTableInput):
61-
self.model_type = config.model_type
62-
if self.model_type not in KEY_TO_MODEL_URL:
63-
model_list = ",".join(KEY_TO_MODEL_URL)
64-
raise ValueError(
65-
f"{self.model_type} is not supported. The currently supported models are {model_list}."
66-
)
67-
68-
config.model_path = self.get_model_path(config.model_type, config.model_path)
69-
if self.model_type == ModelType.CYCLE_CENTER_NET.value:
70-
self.table_structure = TSRCycleCenterNet(asdict(config))
71-
else:
72-
self.table_structure = TSRUnet(asdict(config))
73-
33+
def __init__(self, table_model_path: Union[str, Path] = None, version="v2"):
    """Set up the image loader, line-recognition model, table recoverer and OCR.

    Args:
        table_model_path: Path of the ONNX model to load. When falsy, the
            bundled default for the selected ``version`` is used.
        version: ``"v2"`` selects ``TableLineRecognitionPlus``; any other
            value selects ``TableLineRecognition``.
    """
    self.load_img = LoadImage()

    # Pick the recognizer class together with its bundled fallback model.
    if version == "v2":
        recognizer_cls, fallback_path = TableLineRecognitionPlus, default_model_path_v2
    else:
        recognizer_cls, fallback_path = TableLineRecognition, default_model_path
    chosen_path = table_model_path or fallback_path
    self.table_line_rec = recognizer_cls(str(chosen_path))

    self.table_recover = TableRecover()

    # OCR is optional: without rapidocr_onnxruntime the attribute stays None.
    try:
        ocr_module = importlib.import_module("rapidocr_onnxruntime")
        self.ocr = ocr_module.RapidOCR()
    except ModuleNotFoundError:
        self.ocr = None
48+
7849
def __call__(
7950
self,
8051
img: InputType,
8152
ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None,
8253
**kwargs,
83-
) -> WiredTableOutput:
54+
) -> Tuple[str, float, Any, Any, Any]:
55+
if self.ocr is None and ocr_result is None:
56+
raise ValueError(
57+
"One of two conditions must be met: ocr_result is not empty, or rapidocr_onnxruntime is installed."
58+
)
59+
8460
s = time.perf_counter()
61+
rec_again = True
8562
need_ocr = True
8663
col_threshold = 15
8764
row_threshold = 10
8865
if kwargs:
66+
rec_again = kwargs.get("rec_again", True)
8967
need_ocr = kwargs.get("need_ocr", True)
9068
col_threshold = kwargs.get("col_threshold", 15)
9169
row_threshold = kwargs.get("row_threshold", 10)
9270
img = self.load_img(img)
93-
polygons, rotated_polygons = self.table_structure(img, **kwargs)
71+
polygons, rotated_polygons = self.table_line_rec(img, **kwargs)
9472
if polygons is None:
9573
logging.warning("polygons is None.")
96-
return WiredTableOutput("", None, None, 0.0)
74+
return "", 0.0, None, None, None
9775

9876
try:
9977
table_res, logi_points = self.table_recover(
@@ -108,34 +86,52 @@ def __call__(
10886
sorted_polygons, idx_list = sorted_ocr_boxes(
10987
[box_4_2_poly_to_box_4_1(box) for box in polygons]
11088
)
111-
return WiredTableOutput(
89+
return (
11290
"",
91+
time.perf_counter() - s,
11392
sorted_polygons,
11493
logi_points[idx_list],
115-
time.perf_counter() - s,
94+
[],
11695
)
96+
if ocr_result is None and need_ocr:
97+
ocr_result, _ = self.ocr(img)
11798
cell_box_det_map, not_match_orc_boxes = match_ocr_cell(ocr_result, polygons)
11899
# 如果有识别框没有ocr结果,直接进行rec补充
119-
cell_box_det_map = self.fill_blank_rec(img, polygons, cell_box_det_map)
100+
cell_box_det_map = self.re_rec(img, polygons, cell_box_det_map, rec_again)
120101
# 转换为中间格式,修正识别框坐标,将物理识别框,逻辑识别框,ocr识别框整合为dict,方便后续处理
121-
t_rec_ocr_list = self.transform_res(cell_box_det_map, polygons, logi_points)
102+
t_rec_ocr_list_dict = self.transform_res(cell_box_det_map, polygons, logi_points)
103+
# 第一行或者第一列为空时,调整代码
104+
#adjust_dict = self.adjust_table_cells(t_rec_ocr_list_dict)
105+
adjust_dict = self.process_ocr_result(t_rec_ocr_list_dict)
122106
# 将每个单元格中的ocr识别结果排序和同行合并,输出的html能完整保留文字的换行格式
123-
t_rec_ocr_list = self.sort_and_gather_ocr_res(t_rec_ocr_list)
107+
t_rec_ocr_list = self.sort_and_gather_ocr_res(t_rec_ocr_list_dict)
124108
# cell_box_map =
125109
logi_points = [t_box_ocr["t_logic_box"] for t_box_ocr in t_rec_ocr_list]
126110
cell_box_det_map = {
127111
i: [ocr_box_and_text[1] for ocr_box_and_text in t_box_ocr["t_ocr_res"]]
128112
for i, t_box_ocr in enumerate(t_rec_ocr_list)
129113
}
130-
pred_html = plot_html_table(logi_points, cell_box_det_map)
131-
polygons = np.array(polygons).reshape(-1, 8)
132-
logi_points = np.array(logi_points)
133-
elapse = time.perf_counter() - s
114+
table_str = plot_html_table(logi_points, cell_box_det_map)
115+
ocr_boxes_res = [
116+
box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_result
117+
]
118+
sorted_ocr_boxes_res, _ = sorted_ocr_boxes(ocr_boxes_res)
119+
sorted_polygons = [box_4_2_poly_to_box_4_1(box) for box in polygons]
120+
sorted_logi_points = logi_points
121+
table_elapse = time.perf_counter() - s
134122

135123
except Exception:
136124
logging.warning(traceback.format_exc())
137-
return WiredTableOutput("", None, None, 0.0)
138-
return WiredTableOutput(pred_html, polygons, logi_points, elapse)
125+
return "", 0.0, None, None, None
126+
return (
127+
table_str,
128+
table_elapse,
129+
sorted_polygons,
130+
sorted_logi_points,
131+
sorted_ocr_boxes_res,
132+
adjust_dict
133+
134+
)
139135

140136
def transform_res(
141137
self,
@@ -166,6 +162,102 @@ def transform_res(
166162
res.append(dict_res)
167163
return res
168164

165+
def process_ocr_result(self, ocr_result):
    """Drop an all-empty first row and/or first column and re-index the rest.

    Each entry of ``ocr_result`` is a dict with a ``"t_logic_box"`` of the
    form ``[row_start, row_end, col_start, col_end]`` and a ``"t_ocr_res"``
    list of OCR fragments whose element ``1`` is the recognised text.

    The first row is removed only when at least one cell lies entirely in
    row 0 and every such cell is blank; the first column is handled the
    same way afterwards. A cell counts as blank when it has no OCR
    fragments or none of its fragments carries text.

    Note:
        The ``"t_logic_box"`` of the retained entries is updated in place,
        so the dicts passed in are modified.

    Args:
        ocr_result: List of cell dicts as described above.

    Returns:
        A list holding the retained cell dicts (the same objects that were
        passed in), with row/column indices shifted where a leading
        row/column was removed.
    """

    def is_blank(entry):
        # No fragments at all, or only fragments with empty text.
        return not any(frag[1] for frag in entry.get("t_ocr_res") or [])

    def drop_leading(cells, start_idx, end_idx):
        # Cells that sit entirely inside index 0 of this axis.
        def in_leading(entry):
            box = entry["t_logic_box"]
            return box[start_idx] == 0 and box[end_idx] == 0

        leading = [entry for entry in cells if in_leading(entry)]
        # Nothing to remove when the leading line has no cells of its own
        # or when any of them contains text.
        if not leading or not all(is_blank(entry) for entry in leading):
            return cells

        remaining = [entry for entry in cells if not in_leading(entry)]
        for entry in remaining:
            box = entry["t_logic_box"]
            # A cell spanning out of the removed line keeps starting at 0
            # instead of going negative.
            box[start_idx] = max(box[start_idx] - 1, 0)
            box[end_idx] -= 1
        return remaining

    ocr_result = drop_leading(ocr_result, 0, 1)  # rows
    ocr_result = drop_leading(ocr_result, 2, 3)  # columns
    return ocr_result
195+
196+
def adjust_table_cells(self, t_rec_ocr_list_dict):
    """Strip a blank leading entry/fragment and shift logical indices.

    The input is a list of dicts of the form::

        {
            "t_box": [xmin, ymin, xmax, ymax],
            "t_logic_box": [row_start, row_end, col_start, col_end],
            "t_ocr_res": [[box, text], ...],
        }

    The first list entry is dropped when all of its OCR fragments have
    empty text. The first OCR fragment of every entry is dropped when every
    entry has fragments and each entry's first fragment has empty text.
    Row indices are decremented only if the first entry was dropped, and
    column indices only if the first fragments were dropped.

    NOTE(review): this treats each list entry as a "row" and each OCR
    fragment as a "column", which may not match how the input list is
    built elsewhere in this file -- confirm before enabling the call.

    Args:
        t_rec_ocr_list_dict: List of cell dicts as described above.

    Returns:
        A list with one inner list per retained entry; each inner list
        holds one dict per retained fragment with keys ``"t_box"``,
        ``"t_logic_box"`` and ``"t_ocr_res"`` (the single fragment).
        Entries left with no fragments are omitted. An empty input yields
        an empty list.
    """
    # Guard: the first-entry check below indexes element 0.
    if not t_rec_ocr_list_dict:
        return []

    # Decide whether the leading entry / leading fragments are blank.
    remove_first_row = all(
        cell and not cell[1]
        for cell in t_rec_ocr_list_dict[0].get("t_ocr_res", [])
    )
    remove_first_col = all(
        row.get("t_ocr_res") and not row["t_ocr_res"][0][1]
        for row in t_rec_ocr_list_dict
    )

    # Indices move only along an axis on which something was removed.
    row_shift = 1 if remove_first_row else 0
    col_shift = 1 if remove_first_col else 0

    adjusted_result = []
    for i, row in enumerate(t_rec_ocr_list_dict):
        if remove_first_row and i == 0:
            continue

        logic_box = row["t_logic_box"]
        adjusted_row = []
        for j, cell in enumerate(row.get("t_ocr_res", [])):
            if remove_first_col and j == 0:
                continue
            adjusted_row.append(
                {
                    "t_box": row.get("t_box"),
                    "t_logic_box": [
                        logic_box[0] - row_shift,
                        logic_box[1] - row_shift,
                        logic_box[2] - col_shift,
                        logic_box[3] - col_shift,
                    ],
                    "t_ocr_res": cell,
                }
            )

        if adjusted_row:
            adjusted_result.append(adjusted_row)

    return adjusted_result
260+
169261
def sort_and_gather_ocr_res(self, res):
170262
for i, dict_res in enumerate(res):
171263
_, sorted_idx = sorted_ocr_boxes(
@@ -177,19 +269,30 @@ def sort_and_gather_ocr_res(self, res):
177269
)
178270
return res
179271

180-
def fill_blank_rec(
272+
def re_rec(
    self,
    img: np.ndarray,
    sorted_polygons: np.ndarray,
    cell_box_map: Dict[int, List[str]],
    rec_again=True,
) -> Dict[int, List[Any]]:
    """Fill in every cell polygon that has no OCR entry yet.

    For each index ``i`` of ``sorted_polygons`` whose ``cell_box_map`` entry
    is missing or empty, a single-element entry ``[[box, text, score]]`` is
    stored. When ``rec_again`` is true and an OCR engine is available, the
    polygon is cropped from ``img``, padded with a white border and sent
    through recognition only (no detection); the recognised strings are
    concatenated and the lowest score is kept. Otherwise, or when
    recognition yields nothing, the entry is ``[[box, "", 1]]``.

    Args:
        img: Source image the polygons refer to.
        sorted_polygons: Cell polygons; iterated along the first axis.
        cell_box_map: Mapping from polygon index to its OCR entries.
            Updated in place.
        rec_again: Whether to run recognition on the unmatched cells.

    Returns:
        The same ``cell_box_map`` object, with all missing cells filled.
    """
    # Recognition needs both the caller's consent and an OCR engine; the
    # engine is None when rapidocr_onnxruntime is not installed.
    can_recognise = rec_again and getattr(self, "ocr", None) is not None

    for i in range(sorted_polygons.shape[0]):
        if cell_box_map.get(i):
            continue

        box = sorted_polygons[i]
        text, score = "", 1
        if can_recognise:
            crop_img = get_rotate_crop_image(img, box)
            # Pad with white so text touching the cell border is not clipped.
            pad_img = cv2.copyMakeBorder(
                crop_img, 5, 5, 100, 100, cv2.BORDER_CONSTANT, value=(255, 255, 255)
            )
            rec_res, _ = self.ocr(pad_img, use_det=False, use_cls=True, use_rec=True)
            # An empty/None result (presumably "nothing recognised" --
            # TODO confirm against the OCR engine) leaves the cell blank
            # instead of raising and aborting the whole table.
            if rec_res:
                text = "".join(rec[0] for rec in rec_res)
                score = min(rec[1] for rec in rec_res)
        cell_box_map[i] = [[box, text, score]]
    return cell_box_map
194297

195298
def re_rec_high_precise(
@@ -222,46 +325,24 @@ def re_rec_high_precise(
222325
]
223326
return cell_box_map
224327

225-
@staticmethod
226-
def get_model_path(
227-
model_type: str, model_path: Union[str, Path, None]
228-
) -> Union[str, Dict[str, str]]:
229-
if model_path is not None:
230-
return model_path
231-
232-
model_url = KEY_TO_MODEL_URL.get(model_type, None)
233-
if isinstance(model_url, str):
234-
model_path = DownloadModel.download(model_url)
235-
return model_path
236-
237-
if isinstance(model_url, dict):
238-
model_paths = {}
239-
for k, url in model_url.items():
240-
model_paths[k] = DownloadModel.download(
241-
url, save_model_name=f"{model_type}_{Path(url).name}"
242-
)
243-
return model_paths
244-
245-
raise ValueError(f"Model URL: {type(model_url)} is not between str and dict.")
246-
247328

248329
def main():
    """Command-line entry point: recognise the table in one image and print it.

    Parses ``-img/--img_path``, runs RapidOCR on the image, feeds the OCR
    result to ``WiredTableRecognition`` and prints the table string followed
    by the elapsed time.

    Raises:
        ModuleNotFoundError: If ``rapidocr_onnxruntime`` is not installed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-img", "--img_path", type=str, required=True)
    args = parser.parse_args()

    try:
        ocr_engine = importlib.import_module("rapidocr_onnxruntime").RapidOCR()
    except ModuleNotFoundError as exc:
        raise ModuleNotFoundError(
            "Please install the rapidocr_onnxruntime by pip install rapidocr_onnxruntime."
        ) from exc

    table_rec = WiredTableRecognition()
    ocr_result, _ = ocr_engine(args.img_path)
    # The recogniser returns more than two values (table string, elapsed
    # time, then box/logic data); only the first two are printed here, so
    # the rest is discarded rather than unpacked into exactly two names.
    table_str, elapse, *_ = table_rec(args.img_path, ocr_result)
    print(table_str)
    print(f"cost: {elapse:.5f}")
265346

266347

267348
if __name__ == "__main__":

0 commit comments

Comments
 (0)