# -*- coding: utf-8 -*-
# data_utils.py
# Processes source data and prepares the reader.
import token
import tokenize
import ast
import re
import sys
import pickle
import codecs
from StringIO import StringIO  # Python 2; use io.StringIO on Python 3

from nltk.tokenize import wordpunct_tokenize

sys.path.append("codenn/src")
from sql.SqlTemplate import SqlTemplate

# Heuristic patterns for recovering variable names from unparsable code:
# assignment targets ("a, b = ...") and for-loop targets ("for a, b in ...").
PATTERN_VAR_EQUAL = re.compile(r"(\s*[_a-zA-Z][_a-zA-Z0-9]*\s*)(,\s*[_a-zA-Z][_a-zA-Z0-9]*\s*)*=")
PATTERN_VAR_FOR = re.compile(r"for\s+[_a-zA-Z][_a-zA-Z0-9]*\s*(,\s*[_a-zA-Z][_a-zA-Z0-9]*)*\s+in")
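# Quick sanity check of the two patterns (illustrative sketch, not part of
# the original pipeline):
#   >>> bool(re.match(PATTERN_VAR_EQUAL, "a, b = foo()"))
#   True
#   >>> bool(re.search(PATTERN_VAR_FOR, "for i, v in enumerate(xs):"))
#   True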
def repair_program_io(code):
    """Remove the special interactive-session IO prompts from a program.

    Case 1 (IPython):
        In [n]:
        (   ....:)
        and
        Out [n]:
    Case 2 (plain REPL):
        >>>
        ...

    Args:
        code: a string, the code snippet.
    Returns:
        repaired_code: a string, the repaired code snippet.
        code_list: a list of strings, the code blocks of the original snippet.
            The goal is to maintain all of the original information.
    """

    # regex patterns for case 1
    pattern_case1_in = re.compile(r"In ?\[\d+\]: ?")    # flag1
    pattern_case1_out = re.compile(r"Out ?\[\d+\]: ?")  # flag2
    pattern_case1_cont = re.compile(r"( )+\.+: ?")      # flag3
    # regex patterns for case 2
    pattern_case2_in = re.compile(r">>> ?")             # flag4
    pattern_case2_cont = re.compile(r"\.\.\. ?")        # flag5

    patterns = [pattern_case1_in, pattern_case1_out, pattern_case1_cont,
                pattern_case2_in, pattern_case2_cont]

    lines = code.split("\n")
    lines_flags = [0 for _ in range(len(lines))]
    code_list = []  # a list of strings

    # match patterns
    for line_idx in range(len(lines)):
        line = lines[line_idx]
        for pattern_idx in range(len(patterns)):
            if re.match(patterns[pattern_idx], line):
                lines_flags[line_idx] = pattern_idx + 1
                break
    lines_flags_string = "".join(map(str, lines_flags))

    bool_repaired = False

    # repair
    if lines_flags.count(0) == len(lines_flags):  # no need to repair
        repaired_code = code
        code_list = [code]
        bool_repaired = True
    elif re.match(re.compile("(0*1+3*2*0*)+"), lines_flags_string) or \
            re.match(re.compile("(0*4+5*0*)+"), lines_flags_string):
        repaired_code = ""
        pre_idx = 0
        sub_block = ""
        if lines_flags[0] == 0:
            flag = 0
            while flag == 0:
                repaired_code += lines[pre_idx] + "\n"
                pre_idx += 1
                flag = lines_flags[pre_idx]
            sub_block = repaired_code
            code_list.append(sub_block.strip())
            sub_block = ""  # clean

        for idx in range(pre_idx, len(lines_flags)):
            if lines_flags[idx] != 0:
                repaired_code += re.sub(patterns[lines_flags[idx] - 1], "", lines[idx]) + "\n"
                # clean the sub_block record
                if len(sub_block.strip()) and (idx > 0 and lines_flags[idx - 1] == 0):
                    code_list.append(sub_block.strip())
                    sub_block = ""
                sub_block += re.sub(patterns[lines_flags[idx] - 1], "", lines[idx]) + "\n"
            else:
                if len(sub_block.strip()) and (idx > 0 and lines_flags[idx - 1] != 0):
                    code_list.append(sub_block.strip())
                    sub_block = ""
                sub_block += lines[idx] + "\n"

        # avoid missing the last unit
        if len(sub_block.strip()):
            code_list.append(sub_block.strip())

        if len(repaired_code.strip()) != 0:
            bool_repaired = True

    if not bool_repaired:  # not a typical case; remove only the 0-flag lines after each Out.
        repaired_code = ""
        sub_block = ""
        bool_after_Out = False

        for idx in range(len(lines_flags)):
            if lines_flags[idx] != 0:
                if lines_flags[idx] == 2:
                    bool_after_Out = True
                else:
                    bool_after_Out = False
                repaired_code += re.sub(patterns[lines_flags[idx] - 1], "", lines[idx]) + "\n"
                if len(sub_block.strip()) and (idx > 0 and lines_flags[idx - 1] == 0):
                    code_list.append(sub_block.strip())
                    sub_block = ""
                sub_block += re.sub(patterns[lines_flags[idx] - 1], "", lines[idx]) + "\n"
            else:
                if not bool_after_Out:
                    repaired_code += lines[idx] + "\n"
                if len(sub_block.strip()) and (idx > 0 and lines_flags[idx - 1] != 0):
                    code_list.append(sub_block.strip())
                    sub_block = ""
                sub_block += lines[idx] + "\n"

    return repaired_code, code_list
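# Illustrative sketch (not in the original file): stripping plain-REPL
# prompts; IPython "In [n]:"/"Out[n]:" prompts are handled the same way.
#   >>> repaired, blocks = repair_program_io(">>> x = 1\n... y = x + 1")
#   >>> repaired
#   'x = 1\ny = x + 1\n'
#   >>> blocks
#   ['x = 1\ny = x + 1']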
def get_vars(ast_root):
    """Return the sorted names bound (stored/deleted, not merely read) in an AST."""
    return sorted({node.id for node in ast.walk(ast_root)
                   if isinstance(node, ast.Name) and not isinstance(node.ctx, ast.Load)})
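# Illustrative sketch (not in the original file): only names in a store
# context (assignment and loop targets) are returned; purely read names
# such as "range" are not.
#   >>> get_vars(ast.parse("x = 1\nfor i in range(3):\n    y = i"))
#   ['i', 'x', 'y']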
def get_vars_heuristics(code):
    """Extract variable names from code that cannot be fully parsed."""
    varnames = set()
    code_lines = [_ for _ in code.split("\n") if len(_.strip())]

    # best-effort parsing: shrink the window until a prefix of the code parses
    start = 0
    end = len(code_lines) - 1
    bool_success = False
    while not bool_success:
        try:
            root = ast.parse("\n".join(code_lines[start:end]))
        except:
            end -= 1
        else:
            bool_success = True
    varnames = varnames.union(set(get_vars(root)))

    # process the remaining lines one by one
    for line in code_lines[end:]:
        line = line.strip()
        try:
            root = ast.parse(line)
        except:
            # matching PATTERN_VAR_EQUAL
            pattern_var_equal_matched = re.match(PATTERN_VAR_EQUAL, line)
            if pattern_var_equal_matched:
                match = pattern_var_equal_matched.group()[:-1]  # remove "="
                varnames = varnames.union(set([_.strip() for _ in match.split(",")]))
            # matching PATTERN_VAR_FOR
            pattern_var_for_matched = re.search(PATTERN_VAR_FOR, line)
            if pattern_var_for_matched:
                match = pattern_var_for_matched.group()[3:-2]  # remove "for" and "in"
                varnames = varnames.union(set([_.strip() for _ in match.split(",")]))
        else:
            varnames = varnames.union(get_vars(root))

    return varnames
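# Illustrative sketch (not in the original file): the second line is not
# valid Python, so the assignment heuristic recovers "a" and "b".
#   >>> sorted(get_vars_heuristics("x = 1\na, b = broken((("))
#   ['a', 'b', 'x']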
def tokenize_python_code(code):
    """Tokenize a Python snippet.

    Returns (tokenized_code, bool_failed_var, bool_failed_token), where the
    two flags record whether variable extraction or tokenization had to fall
    back to heuristics."""
    bool_failed_var = False
    bool_failed_token = False

    # collect variable names, repairing IO prompts or falling back to
    # heuristics when the snippet does not parse
    try:
        root = ast.parse(code)
        varnames = set(get_vars(root))
    except:
        repaired_code, _ = repair_program_io(code)
        try:
            root = ast.parse(repaired_code)
            varnames = set(get_vars(root))
        except:
            bool_failed_var = True
            varnames = get_vars_heuristics(code)

    tokenized_code = []

    def first_trial(_code):
        # check whether the tokenizer can produce at least one token
        if len(_code) == 0:
            return True
        try:
            g = tokenize.generate_tokens(StringIO(_code).readline)
            next(g)
        except:
            return False
        else:
            return True

    # drop leading characters until the tokenizer can start
    bool_first_success = first_trial(code)
    while not bool_first_success:
        code = code[1:]
        bool_first_success = first_trial(code)
    g = tokenize.generate_tokens(StringIO(code).readline)
    term = next(g)

    bool_finished = False
    while not bool_finished:
        term_type = term[0]
        lineno = term[2][0] - 1
        posno = term[3][1] - 1
        if token.tok_name[term_type] in {"NUMBER", "STRING", "NEWLINE"}:
            tokenized_code.append(token.tok_name[term_type])
        elif token.tok_name[term_type] not in {"COMMENT", "ENDMARKER"} and len(term[1].strip()):
            candidate = term[1].strip()
            if candidate not in varnames:
                tokenized_code.append(candidate)
            else:
                tokenized_code.append("VAR")

        # fetch the next term
        bool_success_next = False
        while not bool_success_next:
            try:
                term = next(g)
            except StopIteration:
                bool_finished = True
                break
            except:
                bool_failed_token = True
                print("Failed line: ")
                # tokenize the error line with wordpunct_tokenize instead
                code_lines = code.split("\n")
                if lineno > len(code_lines) - 1:
                    print(sys.exc_info())
                else:
                    failed_code_line = code_lines[lineno]  # error line
                    print("Failed code line: %s" % failed_code_line)
                    if posno < len(failed_code_line) - 1:
                        print("Failed position: %d" % posno)
                        failed_code_line = failed_code_line[posno:]
                        # tokenize the failed line segment
                        tokenized_failed_code_line = wordpunct_tokenize(failed_code_line)
                        print("wordpunct_tokenize tokenization: ")
                        print(tokenized_failed_code_line)
                        # append to the previous tokenization outputs
                        tokenized_code += tokenized_failed_code_line
                    # resume tokenizing from the next line, if any
                    if lineno < len(code_lines) - 1:
                        code = "\n".join(code_lines[lineno + 1:])
                        g = tokenize.generate_tokens(StringIO(code).readline)
                    else:
                        bool_finished = True
                        break
            else:
                bool_success_next = True

    return tokenized_code, bool_failed_var, bool_failed_token
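# Illustrative sketch (not in the original file), under Python 2: literals
# become NUMBER placeholders and bound names become VAR.
#   >>> toks, failed_var, failed_token = tokenize_python_code("x = 1\n")
#   >>> toks
#   ['VAR', '=', 'NUMBER', 'NEWLINE']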
def tokenize_sql_code(code, bool_remove_comment=True):
    """Best-effort parsing for SQL code snippets.
    Credit to the UW codenn project.

    Args:
        code: a string, a SQL code snippet.
    Returns:
        tokens: a list of tokens, where columns and tables are replaced with
            a special token plus an id.
    """
    query = SqlTemplate(code, regex=True)
    typedCode = query.parseSql()
    tokens = [re.sub(r'\s+', ' ', x.strip()) for x in typedCode]
    if bool_remove_comment:
        # drop "--" line comments (and avoid shadowing the imported token module)
        tokens = [tok for tok in tokens if tok[:2] != "--"]
    return tokens, 0, 0
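# Hedged sketch (not in the original file): the output depends entirely on
# the external codenn SqlTemplate parser, so no exact tokens are shown here.
# Column and table names should come back as typed placeholder tokens, and
# "--" comments are dropped when bool_remove_comment is True.
#   >>> toks, _, _ = tokenize_sql_code("SELECT name FROM users -- all users")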
def tokenize_code_corpus(qid_to_code, pl):
    """Tokenize each code snippet in the corpus into a list of tokens.
    Numbers/strings are replaced with NUMBER/STRING.
    Comments are removed.
    (Modified: variable names are replaced with VAR.)"""
    failed_token_qids = set()  # not tokenizable
    failed_var_qids = set()  # not parsable enough to extract vars
    qid_to_tokenized_code = dict()

    count = 0
    for qid, code in qid_to_code.items():
        count += 1
        if count % 1000 == 0:
            print(count)

        # unicode --> ascii
        code = code.encode("ascii", "ignore").strip()

        if len(code) == 0:
            tokenized_code = [""]
        else:
            if pl == "python":
                tokenized_code, bool_failed_var, bool_failed_token = tokenize_python_code(code)
            elif pl == "sql":
                tokenized_code, bool_failed_var, bool_failed_token = tokenize_sql_code(code)
            else:
                raise Exception("Invalid programming language! (Only python and sql are supported.)")

            if bool_failed_token:
                failed_token_qids.add(qid)
                print("failed tokenization qid: %s" % str(qid))
            if bool_failed_var:
                failed_var_qids.add(qid)

        sys.stdout.flush()  # print info
        # save
        qid_to_tokenized_code[qid] = tokenized_code

    print("Total size: %d. Fails: %d." % (len(qid_to_tokenized_code), len(failed_token_qids)))
    return qid_to_tokenized_code, failed_var_qids, failed_token_qids
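# Illustrative sketch (not in the original file), under Python 2: the
# comment is removed, and only the bound name ("total") is mapped to VAR
# while the merely-read name ("price") is kept verbatim.
#   >>> corpus = {"q1": u"total = price * 2  # cost\n"}
#   >>> tokenized, _, _ = tokenize_code_corpus(corpus, "python")
#   >>> tokenized["q1"]
#   ['VAR', '=', 'price', '*', 'NUMBER', 'NEWLINE']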
clean_code = {}


def main():

    def test_tokenize_code_corpus():
        f = pickle.load(open("sql_code_s.txt", "rb"))
        sf = pickle.load(open("s.txt", "rb"))
        questionf = pickle.load(open("sql_title_s.txt", "rb"))
        wf = codecs.open("train_s.txt", "w", "utf-8")
        wfprocess = codecs.open("processPro_s.txt", "w", "utf-8")
        num = 0

        # tokenize every SQL snippet and record (tokens, qid) per question
        for x in f:
            hasnl = True
            try:
                print(questionf[x])
                print(f[x])
            except:
                num += 1
                print(x)
                hasnl = False
            dic = {"1": f[x]}
            ans, _, _ = tokenize_code_corpus(dic, "sql")
            wfprocess.write(str(x) + "\n")
            if hasnl:
                wfprocess.write(questionf[x])
            else:
                wfprocess.write("None")
            wfprocess.write("\n")
            wfprocess.write(" ".join(ans["1"]))
            wfprocess.write("\n")
            clean_code[x] = (" ".join(ans["1"]), x)

        # write the (question title, tokenized code) training pairs
        for x in sf['train']:
            try:
                print(x, clean_code[x[2]][1])
                wf.write(questionf[x[0]])
                wf.write("\n")
                wf.write(clean_code[x[2]][0])
                wf.write("\n")
            except:
                print(x[2])
        print(num)

    test_tokenize_code_corpus()


if __name__ == "__main__":
    main()