当前位置:网站首页>Kaggle_NBME NLP比赛Baseline详解(2)
Kaggle_NBME NLP比赛Baseline详解(2)
2022-04-22 05:26:00 【Lyttonkeepgoing】
这部分主要讲代码
分四个部分讲解
1.Dataset设计
2.Model设计
3.Loss函数设计
4.一些QA
1.Dataset设计
目标:每次迭代计算的时候要取出部分数据放入模型
实现的方法:继承torch.utils.data.Dataset
具体转化方法:先定义一个prepare_input的方法将文本转化为tensor
再定义一个create label的方法将annotation的标记转化为tensor
我们碰到nlp问题的时候 一定要搞清楚模型的输入和输出是什么 其实现在都在用bert那套框架 只要理解我们模型的输入和输出长什么样子就可以很好的套用模型的框架来解决问题
因为我们的模型不能一次性把我们的数据加入到我们显存里 所以我们就要一个batch的把数据送进去 完成梯度更新 然后一步一步迭代的过程
我们只需要准备两个函数 prepare_input 和create label train data
def prepare_input(cfg, text, feature_text):
# 调用transformers里tokenizer的方法 然后把text和feature_text一起放到tokenizer里面后就会有一个输出
inputs = cfg.tokenizer(text, feature_text, add_special_tokens=True,
max_length=CFG.max_len, padding='max_length',
return_offsets_mapping=False)
for k, v in inputs.items():
inputs[k] = torch.tensor(v, dtype=torch.long)
return inputs
# 调用transformers里tokenizer的方法 然后把text和feature_text一起放到tokenizer里面后就会有一个输出 输出的话就是 input_ids(对输入的句子转化成数字的结果),attention_mask(这里面全部是1,表示让模型关注里面的所有词),token_types_ids(用来表示返回的数字编码中,哪些属于第一个句子,哪些属于第二个句子 0属于第一句话 1属于第二句话)这几个k 具体可以看这篇文章(yosemite1998) train_data:就是一个带标记的训练数据
###data loading
train = pd.read_csv('你的路径')
train['annotation'] = train['annotation'].apply(ast.literal_eval)
# ast模块就是帮助Python应用来处理抽象的语法解析的。而该模块下的literal_eval()函数:则会判断需要# 计算的内容计算后是不是合法的python类型,如果是则进行运算,否则就不进行运算。
train['location'] = train['location'].apply(ast.literal_eval)
features = pd.read_csv('。。。')
def preprocess_features(features):
features.loc[27, 'feature_text'] = 'Last-Pap-swear-1-year-ago'
return features
features = preprocess_features(features)
patient_notes = pd.read_csv('...')
print(train features patient_notes的shape)
输出结果
train_data
可以看到shape 14300行 6列
这6列表示的意思我们上节讲过了
pn_num -- patient note 病例id 病历号
case_num -- case num用来关联起病人patient note的文本描述和对应症状的文本描述
feature_num 每个病症的id
annotation :patient note中体现相关症状的描述 病例中可能对同一个疾病症状存在多处描述
location: annotation 所在的病例中char级别的位置
feature

patient_notes

然后我们要把这些数据merge一下
train和features patient_notes数据merge 起来
train = train.merge(features, on=['feature_num', 'case_num'], how='left')
train = train.merge(patient_notes, on=['pn_um', 'case_num'], how='left')
train['pn_history'] = train['pn_history'].apply(lambda x:x.replace('dad with recent heart attcak', 'dad with recent heart attack')) #这里看到一个错别字 改一下
完整的train

然后会写一些incorrect annotation的修正
照这样写
train.loc[338, 'annotation'] = ast.literal_eval('[["father heart attack"]]')
train.loc[338, 'location'] = ast.literal_eval('[["764 783"]]')
然后看dataset的完整代码
def prepare_input(cfg, text, feature_text):
# 调用transformers里tokenizer的方法 然后把text和feature_text一起放到tokenizer里面后就会有一个输出
inputs = cfg.tokenizer(text, # text就相当于paragraph
feature_text, # 就相当于question
add_special_tokens=True,
# 是否要把seq cls加到句子里面
max_length=CFG.max_len,
# 354 如果不够354就要padding
padding='max_length', #
return_offsets_mapping=False)
# return_offsets_mapping 会在create label里面用到
for k, v in inputs.items():
inputs[k] = torch.tensor(v, dtype=torch.long)
return inputs
def create_label(cfg, feature, text, annotation_length, location_list, answers):
encoded = cfg.tokenizer(feature, #疾病名称
text, # 病例
add_special_tokens=True,
# 把cls,sep加到输出中
max_length=CFG.max_len,
padding='max_length',
return_offsets_mapping=True )
offset_mapping = encoded['offset_mapping']
squence_ids = encoded.sequence_ids()
ignore_idxes = np.where(np.array(encoded.sequence_ids()) !=1)[0] # 找到不需要算入loss的index
label = np.zeros(len(offset_mapping))
label[ignore_idxes] = -1
for label_ids, location in enumerate(test_location):
loaction = [s.split() for s in location.split(';')]
for loc in location:
answer = test_answers[label_ids]
start_char = int(loc[0])
end_char = start_char+len(answer)
# token start index
token_start_index = 0
while squence_ids[token_start_index] != 1:
token_start_index += 1
# token end index
token_end_index = len(offset_mapping)-1
while squence_ids[tokens_end_index] != 1:
token_end_index -= 1
while token_start_index < len(offset_mapping) and offset_mapping[token_start_index][0] <= start_char:
token_start_index += 1
token_start_index -= 1
while offset_mapping[token_end_index][1] >= end_char:
token_end_index -= 1
token_end_index += 1
label[token_start_index:token_end_index+1] = 1.0
return torch.tensor(label, dtype=torch.float)
class TrainDataset(Dataset):
def __init__(self, cfg, df):
self.cfg = cfg
self.feature_texts = df['features_text'].values
self.pn_historys = df['pn_history'].values
self.annotation_lengths = df['annotation_length'].values
self.locations = df['location'].values
def __len__(self):
return len(self.feature_texts)
def __getitem__(self, item):
inputs = prepare_input(self.cfg,
self.pn_historys[item],
self.feature_texts[item])
label = create_label(self.cfg,
self.pn_historys[item],
self.feature_texts[item]
self.annotation_length[item],
self.location[item])
return inputs, label
token_types_id 0表示第一句话 1表示第二句话
解释一下offsets_mapping: tokenizer 之后 会把一个词拆成两个词或者更多 拆成这种token以后 要 把token映射回去 就是每个token对应的原文中所在的char的位置
(0,0)就是cls or sep token在原文中没有映射
例:CFG.tokenizer.decode(encode['input_ids'][1]) 输出 'Family'
再说一下 sequence_ids
encoded.sequence_ids() 输出为
就是0表示第一句话 1表示第二句话 cls sep padding都为none
通过原文的char级别信息 映射回label上 需要转成和原文输出长度一样的0,1 sequence tensor
0表示没有出现相关疾病的表现
下节继续~
版权声明
本文为[Lyttonkeepgoing]所创,转载请带上原文链接,感谢
https://blog.csdn.net/m0_53292725/article/details/123708878
边栏推荐
- GBase 8s V8.8 SQL 指南:教程-5.2(3)
- Database 13th job transaction management
- [WPF] converter
- Basic concepts of outh2
- Time complexity and space complexity
- One way to disable Google cross domain
- Shenzhen Xishuangbanna
- Batch resolves the IP address of the domain name and opens the web page
- MySQL數據庫第十一次作業-視圖的應用
- [candelastudio edit CDD] - 2.3 - realize the jump between multiple securitylevels of $27 service (UDS diagnosis)
猜你喜欢

Mengxin sees the recruitment of volunteers in the open source community of wedatasphere

【WPF】Popup

Apache poi HSSF operation Excel

13.9.1-PointersOnC-20220421

Studio3t expired activation method / and scripts that reset the use date are not available solution / Studio 3T unlimited activation

The signature of the update package is inconsistent with that of the installed app

Integer源码
![[C] file operation](/img/fd/ddf94b0ffa743f2288f723a263a045.png)
[C] file operation

MySQL数据库第十一次

Social media and fake news in the 2016 election
随机推荐
Reduce the graduation time to before the age of 20, and go to primary school for five years at the age of 5, so as to increase the population
水处理控制系统采用信号隔离器解决因某些现场非电量安装条件的限制问题
基于有限体积法的传热拓扑优化
Sourcetree version backtracking and single change version backtracking
Leetcode 1561. Maximum number of coins you can get
The chain of implicit trust: an analysis of the web third party resources loading
Analysis of database log level: (2022.2.21-2022.2.27)
萌新看过来 | WeDataSphere 开源社区志愿者招募
MySQL数据库第十一次
Cookie injection
Online Tetris with automatic hang-up source code
Send a shutdown command to the LAN computer every 30 seconds
Integer源码
我的创作纪念日
Auto.js 画布设置防锯齿paint.setAntiAlias(true);
Configure security policy on ENSP
Use render texture to display 3D model animation on the UI
IT配电及防火限流式保护器应用及选型
Idea 2021.1 Useful settings
Fundamentals of graphics - depth of field / DOF