datasync.py
# -*- coding: utf-8 -*-
import csv  # stdlib csv is unicode-aware on Python 3 (was: unicodecsv)
import io
import os
import shutil
import sys
import tempfile
import zipfile

import requests

# Make the project importable and bootstrap Django before touching the ORM.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
os.environ['DJANGO_SETTINGS_MODULE'] = 'api.settings'
import django
django.setup()
from django.db import connections as CONNECTIONS
from conf.settings import DBTABLES, DBSCHEMA

ZIPURL = 'http://www.peatus.ee/gtfs/gtfs.zip'
# FIXME: ZIPURL should be in conf.settings as well!

class FilteredCSVFile(csv.DictReader):
    """Local helper for reading only specified columns from a csv file.

    It's assumed that row number 1 is the header row. Rows are returned
    as tab-separated strings ready for PostgreSQL's COPY text format.
    """

    def __init__(self, csvfile, fieldnames=None, restkey=None, restval=None,
                 dialect='excel', *args, **kwargs):
        # Consume the header row first so the reader maps values by name;
        # self._fieldnames holds the subset of columns the caller wants.
        self._header = self.get_csv_header(csvfile)
        super(FilteredCSVFile, self).__init__(
            csvfile, self._header, restkey, restval, dialect, *args, **kwargs)
        self._fieldnames = fieldnames

    def get_csv_header(self, fp):
        return fp.readline().strip('\n').split(',')

    def cleanup(self, obj):
        """Normalise a single value for COPY's text format."""
        if obj is None or obj == "":
            obj = '\\N'  # PostgreSQL's NULL marker in COPY text format
        if isinstance(obj, str):
            # Tab is the field separator, so tabs inside values must go.
            # Commas need no quoting here: COPY's text format does not
            # interpret quote characters and would store them literally.
            obj = obj.replace('\t', ' ')
        return obj

    def next(self):
        row = dict(zip(self._header, next(self.reader)))
        return '\t'.join(['%s' % self.cleanup(row[k]) for k in self._fieldnames])

    __next__ = next  # keep iteration working on Python 3 as well

    def readline(self):
        return self.next()

    def read(self):
        """Drain the csv file, returning all filtered rows as one string."""
        o = io.StringIO()
        try:
            while True:
                o.write(self.next())
                o.write(u'\n')
        except StopIteration:
            pass
        return o.getvalue()

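# Illustrative usage (a sketch; stops.txt and its columns are typical GTFS
# names used for the example, not read from conf.settings):
#
#     with open('stops.txt', encoding='utf-8') as f:
#         fcsv = FilteredCSVFile(f, fieldnames=['stop_id', 'stop_name'])
#         print(fcsv.read())
#     # one tab-separated row per line, e.g.:
#     # 101\tCentral Station
#     # 102\tAirport
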
def download_zip(url, to_path):
    """Download a zipfile from url and extract it to to_path.

    Returns the path of extraction.
    """
    r = requests.get(url)
    r.raise_for_status()
    content = io.BytesIO(r.content)
    with zipfile.ZipFile(content) as z:
        z.extractall(to_path)
    return to_path

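# Example call (a sketch, mirroring what run() below actually does):
#
#     path = download_zip(ZIPURL, tempfile.mkdtemp(prefix='eoy_'))
#     # path now holds the extracted GTFS csv files, e.g. stops.txt, routes.txt
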
def get_csv_header(filepath):
    """Returns the csv file's header row as a list of column names."""
    with open(filepath) as n:
        return n.readline().strip('\n').split(',')

def _db_check_table(cursor, dbschema, dbtable):
    """Checks the input table's existence in the database.

    Returns a list of the table's column names.
    """
    tab = '%s.%s' % (dbschema, dbtable)
    sql = "select array_agg(attname) from pg_attribute " \
          "where attrelid=%s::regclass and not attisdropped and attnum > 0"
    params = (tab,)
    cursor.execute(sql, params)
    return cursor.fetchone()[0]

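# For a hypothetical table gtfs.stops(stop_id, stop_name, stop_lat) the
# pg_attribute query above returns something like:
#
#     _db_check_table(cursor, 'gtfs', 'stops')
#     # -> ['stop_id', 'stop_name', 'stop_lat']
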
def _fs_check_csv(path, filename, ext='txt'):
    """Checks that the input csv file really exists.

    Returns a tuple of the csv's absolute filepath and its header.
    """
    filename = '%s.%s' % (filename, ext)
    fp = os.path.join(path, filename)
    assert os.path.exists(fp)
    return fp, get_csv_header(fp)

def _get_insert_cols(db_cols, fp_cols, dbschema, tablename):
    """Returns the intersection of the input column names.

    Use this to figure out which columns need to be read from the csv file.
    """
    cols = list(set(db_cols).intersection(fp_cols))
    assert len(cols) > 0, "%s.%s and %s.csv do not share any columns" % (
        dbschema, tablename, tablename)
    return cols

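# Example with made-up column lists (set semantics, so order may vary):
#
#     _get_insert_cols(['stop_id', 'stop_name', 'geom'],
#                      ['stop_id', 'stop_name', 'stop_url'],
#                      'gtfs', 'stops')
#     # -> ['stop_id', 'stop_name']
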
def _db_prepare_truncate(tableschema, tablename):
    """Prepare a truncate statement for a table in the database.

    @FIXME: as this is prone to injection, check whether the tablename
    mentioned in the args really exists.
    """
    sql = """truncate table %(sch)s.%(tab)s cascade"""
    params = dict(sch=tableschema, tab=tablename)
    return sql % params

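# One way to address the FIXME above (a sketch, assuming psycopg2 >= 2.7,
# which ships the psycopg2.sql composition module, and a raw psycopg2
# cursor): let the driver quote the identifiers instead of interpolating
# them into the statement.
#
#     from psycopg2 import sql
#     stmt = sql.SQL('truncate table {}.{} cascade').format(
#         sql.Identifier(tableschema), sql.Identifier(tablename))
#     cursor.execute(stmt)
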
def run():
    """Run data download and database sync operations."""
    try:
        # Go get all csv files, extracted at to_path.
        # For local testing use an already extracted directory:
        # to_path = 'tmp'
        # The real thing:
        to_path = download_zip(ZIPURL, tempfile.mkdtemp(prefix='eoy_'))
        print(to_path)
        # Loop through the required files and look for a matching table
        # in the database:
        # - if found, truncate it and insert new rows from the csv file;
        # - if the table is not found, raise an exception;
        # - on exception, roll back and stop whatever was going on.
        # All database commands run in a single transaction.
        c = CONNECTIONS['sync']
        with c.cursor() as cursor:
            # Loop through the list of tables specified in
            # conf.settings.DBTABLES.
            for dbtable in DBTABLES:
                # Check that the table exists in the db and get its columns.
                db_cols = _db_check_table(cursor, DBSCHEMA, dbtable)
                # Check that the csv file is present and get its header.
                fp, fp_cols = _fs_check_csv(to_path, dbtable)
                print('%s.%s' % (DBSCHEMA, dbtable))
                # Get the intersection of db_cols and fp_cols (i.e. the
                # columns present in both).
                cols = _get_insert_cols(db_cols, fp_cols, DBSCHEMA, dbtable)
                # Truncate the old data ...
                st_trunc = _db_prepare_truncate(DBSCHEMA, dbtable)
                cursor.execute(st_trunc)
                # ... and fill anew.
                with open(fp, encoding='utf-8') as f:
                    fcsv = FilteredCSVFile(f, fieldnames=cols, quotechar='"')
                    tab = '%s.%s' % (DBSCHEMA, dbtable)
                    cursor.copy_from(io.StringIO(fcsv.read()), tab,
                                     sep='\t', columns=cols)
                    print(cursor.rowcount)
                print('done %s' % fp)
    except:
        raise
    # FIXME: This is the place for calling data prep functions in the database.
    # Keep the extracted files for now ...
    # shutil.rmtree(to_path)

def postprocess():
    """Run the SQL data prep statements after a sync."""
    with open('preprocess-all.sql') as f:
        statements = f.read()
    c = CONNECTIONS['sync']
    with c.cursor() as cursor:
        # Naive split: assumes the sql file contains no semicolons inside
        # string literals or function bodies.
        for statement in statements.split(';'):
            if statement.strip():
                cursor.execute(statement)

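# Note: postprocess() is not wired into the __main__ guard below; to run
# the SQL prep steps after a sync, call it explicitly, e.g.:
#
#     from datasync import run, postprocess
#     run()
#     postprocess()
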
if __name__ == '__main__':
    run()