type
status
date
slug
summary
tags
category
icon
password
paper | |
github | |
个人博客位置 |
1 Main Idea
SLANet
是一个轻量级的表格结构识别模型。它将表格结构识别任务建模为序列标注,以自回归的方式预测表格的html序列和单元格位置。2 Method
2.1 整体架构
SLANet包含3个组件:backbone,neck,head。

Neck
:SLANet
采用CSP-PAN
将backbone提取的多尺度特征进行融合。输入多尺度特征输出融合后的特征尺度为
Head
:SLANet
采用GRU
+Attention
的形式来递归生成表格的html序列和单元格位置。

import torch import torch.nn as nn import torch.nn.functional as F import numpy as np class AttentionGRUCell(nn.Module): def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False): super(AttentionGRUCell, self).__init__() self.i2h = nn.Linear(input_size, hidden_size, bias=False) self.h2h = nn.Linear(hidden_size, hidden_size) self.score = nn.Linear(hidden_size, 1, bias=False) self.rnn = nn.GRUCell( input_size=input_size + num_embeddings, hidden_size=hidden_size) self.hidden_size = hidden_size def forward(self, prev_hidden, batch_H, char_onehots): batch_H_proj = self.i2h(batch_H) prev_hidden_proj = torch.unsqueeze(self.h2h(prev_hidden), dim=1) res = torch.add(batch_H_proj, prev_hidden_proj) res = torch.tanh(res) e = self.score(res) alpha = F.softmax(e, dim=1) alpha = alpha.permute(0, 2, 1) # alpha = torch.transpose(alpha, [0, 2, 1]) context = torch.squeeze(torch.matmul(alpha, batch_H), dim=1) concat_context = torch.concat([context, char_onehots], 1) cur_hidden = self.rnn(concat_context, prev_hidden) return (cur_hidden, cur_hidden), alpha class SLAHead(nn.Module): def __init__(self, in_channels, hidden_size, out_channels=30, max_text_length=500, loc_reg_num=4, fc_decay=0.0, **kwargs): """ @param in_channels: input shape @param hidden_size: hidden_size for RNN and Embedding @param out_channels: num_classes to rec @param max_text_length: max text pred """ super().__init__() in_channels = in_channels[-1] self.hidden_size = hidden_size self.max_text_length = max_text_length self.emb = self._char_to_onehot self.num_embeddings = out_channels self.loc_reg_num = loc_reg_num # structure self.structure_attention_cell = AttentionGRUCell( in_channels, hidden_size, self.num_embeddings) self.structure_generator = nn.Sequential( nn.Linear(self.hidden_size, self.hidden_size), nn.Linear(hidden_size, out_channels) ) # loc self.loc_generator = nn.Sequential( nn.Linear(self.hidden_size, self.hidden_size), nn.Linear(self.hidden_size, loc_reg_num), nn.Sigmoid()) def forward(self, inputs, targets=None): fea = inputs[-1] batch_size = fea.shape[0] # reshape fea = torch.reshape(fea, [fea.shape[0], fea.shape[1], -1]) # fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels) fea = fea.permute(0, 2, 1) hidden = torch.zeros((batch_size, self.hidden_size)) structure_preds = torch.zeros( (batch_size, self.max_text_length + 1, self.num_embeddings)) loc_preds = torch.zeros( (batch_size, self.max_text_length + 1, self.loc_reg_num)) structure_preds.requires_grad = False loc_preds.requires_grad = False if self.training and targets is not None: structure = targets[0] for i in range(self.max_text_length + 1): hidden, structure_step, loc_step = self._decode(structure[:, i], fea, hidden) structure_preds[:, i, :] = structure_step loc_preds[:, i, :] = loc_step else: pre_chars = torch.zeros(batch_size, dtype=torch.int64) max_text_length = torch.tensor(self.max_text_length) # for export loc_step, structure_step = None, None for i in range(max_text_length + 1): hidden, structure_step, loc_step = self._decode(pre_chars, fea, hidden) pre_chars = structure_step.argmax(dim=1) structure_preds[:, i, :] = structure_step loc_preds[:, i, :] = loc_step if not self.training: structure_preds = F.softmax(structure_preds, dim=-1) return {'structure_probs': structure_preds, 'loc_preds': loc_preds} def _decode(self, pre_chars, features, hidden): """ Predict table label and coordinates for each step @param pre_chars: Table label in previous step @param features: @param hidden: hidden status in previous step @return: """ emb_feature = self.emb(pre_chars) # output shape is b * self.hidden_size (output, hidden), alpha = self.structure_attention_cell( hidden, features, emb_feature) # structure structure_step = self.structure_generator(output) # loc loc_step = self.loc_generator(output) return hidden, structure_step, loc_step def _char_to_onehot(self, input_char): input_ont_hot = F.one_hot(input_char, self.num_embeddings) return input_ont_hot
2.2 SLANet损失函数
从模型架构可见,SLANet有两个预测分支,分别预测表格的HTML序列和对应单元格的位置。HTML序列以分类任务的交叉熵损失进行监督,单元格位置以回归任务的smooth-L1损失训练监督。二者联合训练,总loss为二者的加权叠加。
注意:仅计算HTML序列为单元格处的单元格回归损失。代码中的实现是,仅当序列预测为'<td>', '<td', '<td></td>'
是计算回归损失。
2.3 SLANet
的标签体系介绍
SLANet的标签体系主要参考TableMaster。英文表格标签体系如下表所示。值得说明的点
- 引入
<td></td>
来模拟空白单元格
- 引入
<td
用于和rowspan,colpan来拼接,便于处理合并单元格的情况。
label id | label name | 说明 |
0 | sos | sequence起始符 |
1 | <thead> | 表格标题标记(始) |
2 | <tr> | 表格行标记(始) |
3 | </td> | 表格单元格标记(终) |
4 | </tr> | 表格行标记(始) |
5 | </thead> | 格标题标记(终) |
6 | <tbody> | 组合 HTML 表格的主体内容(始) |
7 | </tbody> | 组合 HTML 表格的主体内容(终) |
8 | <td | 表格单元格标记(始),少右尖括号是为了拼接跨行,或跨列 |
9 | colspan=\"5\ | 单元格跨5列 |
10 | > | ㅤ |
11 | colspan=\"2\ | 单元格跨2列 |
12 | colspan=\"3\ | 单元格跨3列 |
13 | rowspan=\"2\ | 单元格跨2行 |
14 | colspan=\"4\ | 单元格跨4列 |
15 | colspan=\"6\ | 单元格跨6列 |
16 | rowspan=\"3\ | 单元格跨3行 |
17 | colspan=\"9\ | 单元格跨9列 |
18 | colspan=\"10\ | 单元格跨10列 |
19 | colspan=\"7\ | 单元格跨7列 |
20 | rowspan=\"4\ | 单元格跨4行 |
21 | rowspan=\"5\ | 单元格跨5行 |
22 | rowspan=\"9\ | 单元格跨9行 |
23 | colspan=\"8\ | 单元格跨8列 |
24 | rowspan=\"8\ | 单元格跨8行 |
25 | rowspan=\"6\ | 单元格跨6行 |
26 | rowspan=\"7\ | 单元格跨7行 |
27 | rowspan=\"10\ | 单元格跨10行 |
28 | <td></td> | 空白单元格 |
29 | eos | sequence终止符 |
2.4 如何合并OCR的结果做表格结构还原
从前文可知,SLANet的识别结果只有表格的HTML序列信息,和单元格的坐标信息。完整的表格结构还原还需要单元格内部的文本信息。因此需要额外引入一个OCR引擎来识别表格内部的文本位置和文本标记。最后将他们进行匹配。
举个🌰
下图为SLANet的识别结果(包含表格的HTML序列和单元格坐标)

下图为OCR引擎的识别结果(包含表格中的文本框坐标,识别结果)

最后通过匹配算法将两个结果进行结合,最终得到完整的表格结果还原结果。

匹配算法的核心思路:将OCR的文本框分配给单元格。分配机制:先分配
iou
大的,剩余的就近分配。具体可见代码。3 小结
SLANet
是一个轻量级的表格结构识别模型,将表格结构识别任务建模为序列标注任务,通过自回归方式预测表格的HTML序列和单元格位置。模型由三部分组成:Backbone采用轻量级
PP-LCNet
提取多尺度特征。Neck使用 CSP-PAN
融合多尺度特征。Head:基于 GRU
和 Attention
机制递归生成HTML序列和单元格位置。损失函数结合了HTML序列的分类损失(交叉熵)和单元格位置的回归损失(Smooth-L1),联合训练优化。标签体系参考
TableMaster
,引入特殊标签(如 <td></td>
和 <td
)处理空白单元格和合并单元格情况。最后介绍了完整的表格结构还原链路。
- 作者:莫叶何竹🍀
- 链接:http://www.myhz0606.com/article/slanet
- 声明:本文采用 CC BY-NC-SA 4.0 许可协议,转载请注明出处。
相关文章