type
status
date
slug
summary
tags
category
icon
password
 
paper
github
个人博客位置

1 Main Idea

SLANet 是一个轻量级的表格结构识别模型。它将表格结构识别任务建模为序列标注,以自回归的方式预测表格的html序列和单元格位置。

2 Method

2.1 整体架构

SLANet包含3个组件:backbone,neck,head。
notion image
  • Backbone: SLANet采用轻量级backbonePP-LCNet作为体征提取器,提取多尺度(不同stage)的特征。输入图片尺寸为RB×3×H×W\mathbb{R}^{B \times 3 \times H \times W},经过PP-LCNet 后,输出的多尺度特征:
RB×64×H4×W4RB×128×H8×W8RB×256×H16×W16RB×512×H32×W32\begin{matrix} \mathbb{R}^{B\times 64 \times \frac{H}{4} \times \frac{W}{4}} \\ \mathbb{R}^{B\times 128 \times \frac{H}{8}\times \frac{W}{8}} \\ \mathbb{R}^{B\times 256 \times \frac{H}{16}\times \frac{W}{16}} \\ \mathbb{R}^{B\times 512 \times \frac{H}{32}\times \frac{W}{32}} \end{matrix}
  • Neck : SLANet 采用CSP-PAN 将backbone提取的多尺度特征进行融合。输入多尺度特征[RB×64×H4×W4,RB×128×H8×W8,RB×256×H16×W16,RB×512×H32×W32] [\mathbb{R}^{B\times 64 \times \frac{H}{4} \times \frac{W}{4}} , \mathbb{R}^{B\times 128 \times \frac{H}{8}\times \frac{W}{8}} , \mathbb{R}^{B\times 256 \times \frac{H}{16}\times \frac{W}{16}} , \mathbb{R}^{B\times 512 \times \frac{H}{32}\times \frac{W}{32}}] 输出融合后的特征尺度为R1×96×H32×W32\mathbb{R}^{1\times 96 \times \frac{H}{32}\times \frac{W}{32}}
  • Head : SLANet 采用GRU+Attention的形式来递归生成表格的html序列和单元格位置。
    • notion image
import torch import torch.nn as nn import torch.nn.functional as F import numpy as np class AttentionGRUCell(nn.Module): def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False): super(AttentionGRUCell, self).__init__() self.i2h = nn.Linear(input_size, hidden_size, bias=False) self.h2h = nn.Linear(hidden_size, hidden_size) self.score = nn.Linear(hidden_size, 1, bias=False) self.rnn = nn.GRUCell( input_size=input_size + num_embeddings, hidden_size=hidden_size) self.hidden_size = hidden_size def forward(self, prev_hidden, batch_H, char_onehots): batch_H_proj = self.i2h(batch_H) prev_hidden_proj = torch.unsqueeze(self.h2h(prev_hidden), dim=1) res = torch.add(batch_H_proj, prev_hidden_proj) res = torch.tanh(res) e = self.score(res) alpha = F.softmax(e, dim=1) alpha = alpha.permute(0, 2, 1) # alpha = torch.transpose(alpha, [0, 2, 1]) context = torch.squeeze(torch.matmul(alpha, batch_H), dim=1) concat_context = torch.concat([context, char_onehots], 1) cur_hidden = self.rnn(concat_context, prev_hidden) return (cur_hidden, cur_hidden), alpha class SLAHead(nn.Module): def __init__(self, in_channels, hidden_size, out_channels=30, max_text_length=500, loc_reg_num=4, fc_decay=0.0, **kwargs): """ @param in_channels: input shape @param hidden_size: hidden_size for RNN and Embedding @param out_channels: num_classes to rec @param max_text_length: max text pred """ super().__init__() in_channels = in_channels[-1] self.hidden_size = hidden_size self.max_text_length = max_text_length self.emb = self._char_to_onehot self.num_embeddings = out_channels self.loc_reg_num = loc_reg_num # structure self.structure_attention_cell = AttentionGRUCell( in_channels, hidden_size, self.num_embeddings) self.structure_generator = nn.Sequential( nn.Linear(self.hidden_size, self.hidden_size), nn.Linear(hidden_size, out_channels) ) # loc self.loc_generator = nn.Sequential( nn.Linear(self.hidden_size, self.hidden_size), nn.Linear(self.hidden_size, loc_reg_num), nn.Sigmoid()) def forward(self, inputs, targets=None): fea = inputs[-1] batch_size = fea.shape[0] # reshape fea = torch.reshape(fea, [fea.shape[0], fea.shape[1], -1]) # fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels) fea = fea.permute(0, 2, 1) hidden = torch.zeros((batch_size, self.hidden_size)) structure_preds = torch.zeros( (batch_size, self.max_text_length + 1, self.num_embeddings)) loc_preds = torch.zeros( (batch_size, self.max_text_length + 1, self.loc_reg_num)) structure_preds.requires_grad = False loc_preds.requires_grad = False if self.training and targets is not None: structure = targets[0] for i in range(self.max_text_length + 1): hidden, structure_step, loc_step = self._decode(structure[:, i], fea, hidden) structure_preds[:, i, :] = structure_step loc_preds[:, i, :] = loc_step else: pre_chars = torch.zeros(batch_size, dtype=torch.int64) max_text_length = torch.tensor(self.max_text_length) # for export loc_step, structure_step = None, None for i in range(max_text_length + 1): hidden, structure_step, loc_step = self._decode(pre_chars, fea, hidden) pre_chars = structure_step.argmax(dim=1) structure_preds[:, i, :] = structure_step loc_preds[:, i, :] = loc_step if not self.training: structure_preds = F.softmax(structure_preds, dim=-1) return {'structure_probs': structure_preds, 'loc_preds': loc_preds} def _decode(self, pre_chars, features, hidden): """ Predict table label and coordinates for each step @param pre_chars: Table label in previous step @param features: @param hidden: hidden status in previous step @return: """ emb_feature = self.emb(pre_chars) # output shape is b * self.hidden_size (output, hidden), alpha = self.structure_attention_cell( hidden, features, emb_feature) # structure structure_step = self.structure_generator(output) # loc loc_step = self.loc_generator(output) return hidden, structure_step, loc_step def _char_to_onehot(self, input_char): input_ont_hot = F.one_hot(input_char, self.num_embeddings) return input_ont_hot
 

2.2 SLANet损失函数

从模型架构可见,SLANet有两个预测分支,分别预测表格的HTML序列和对应单元格的位置。HTML序列以分类任务的交叉熵损失进行监督,单元格位置以回归任务的smooth-L1损失训练监督。二者联合训练,总loss为二者的加权叠加。
注意:仅计算HTML序列为单元格处的单元格回归损失。代码中的实现是,仅当序列预测为 '<td>', '<td', '<td></td>' 是计算回归损失。

2.3 SLANet的标签体系介绍

SLANet的标签体系主要参考TableMaster。英文表格标签体系如下表所示。值得说明的点
  • 引入<td></td> 来模拟空白单元格
  • 引入<td 用于和rowspan,colpan来拼接,便于处理合并单元格的情况。
label id
label name
说明
0
sos
sequence起始符
1
<thead>
表格标题标记(始)
2
<tr>
表格行标记(始)
3
</td>
表格单元格标记(终)
4
</tr>
表格行标记(始)
5
</thead>
格标题标记(终)
6
<tbody>
组合 HTML 表格的主体内容(始)
7
</tbody>
组合 HTML 表格的主体内容(终)
8
<td
表格单元格标记(始),少右尖括号是为了拼接跨行,或跨列
9
colspan=\"5\
单元格跨5列
10
>
11
colspan=\"2\
单元格跨2列
12
colspan=\"3\
单元格跨3列
13
rowspan=\"2\
单元格跨2行
14
colspan=\"4\
单元格跨4列
15
colspan=\"6\
单元格跨6列
16
rowspan=\"3\
单元格跨3行
17
colspan=\"9\
单元格跨9列
18
colspan=\"10\
单元格跨10列
19
colspan=\"7\
单元格跨7列
20
rowspan=\"4\
单元格跨4行
21
rowspan=\"5\
单元格跨5行
22
rowspan=\"9\
单元格跨9行
23
colspan=\"8\
单元格跨8列
24
rowspan=\"8\
单元格跨8行
25
rowspan=\"6\
单元格跨6行
26
rowspan=\"7\
单元格跨7行
27
rowspan=\"10\
单元格跨10行
28
<td></td>
空白单元格
29
eos
sequence终止符

2.4 如何合并OCR的结果做表格结构还原

从前文可知,SLANet的识别结果只有表格的HTML序列信息,和单元格的坐标信息。完整的表格结构还原还需要单元格内部的文本信息。因此需要额外引入一个OCR引擎来识别表格内部的文本位置和文本标记。最后将他们进行匹配。
举个🌰
下图为SLANet的识别结果(包含表格的HTML序列和单元格坐标)
notion image
下图为OCR引擎的识别结果(包含表格中的文本框坐标,识别结果)
notion image
最后通过匹配算法将两个结果进行结合,最终得到完整的表格结果还原结果。
notion image
匹配算法的核心思路:将OCR的文本框分配给单元格。分配机制:先分配iou 大的,剩余的就近分配。具体可见代码。

3 小结

SLANet 是一个轻量级的表格结构识别模型,将表格结构识别任务建模为序列标注任务,通过自回归方式预测表格的HTML序列和单元格位置。
模型由三部分组成:Backbone采用轻量级 PP-LCNet 提取多尺度特征。Neck使用 CSP-PAN 融合多尺度特征。Head:基于 GRU 和 Attention 机制递归生成HTML序列和单元格位置。
损失函数结合了HTML序列的分类损失(交叉熵)和单元格位置的回归损失(Smooth-L1),联合训练优化。标签体系参考 TableMaster,引入特殊标签(如 <td></td> 和 <td)处理空白单元格和合并单元格情况。
最后介绍了完整的表格结构还原链路。
 
Supervised Contrastive LearningFeeling after reading The Little Prince 
Loading...
莫叶何竹🍀
莫叶何竹🍀
非淡泊无以明志,非宁静无以致远
最新发布
多模态模型如何处理任意分辨率输入——Tiling与Packing技术详解
2025-5-24
多模态模型如何处理任意分辨率输入——Tiling与Packing技术详解(part2)
2025-5-24
Attention Free Transformer(AFT)技术小结
2025-4-15
BLIP 小结
2025-4-13
BLIP系列文章小结(BLIP, BLIP-2, InstructBLIP)
2025-4-13
Nougat 深度剖析
2025-3-18
hexo