This repository has been archived by the owner on Oct 24, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 14
/
Rakefile
151 lines (125 loc) · 3.97 KB
/
Rakefile
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
require "bundler/gem_tasks"
require "rake/testtask"
# Running `rake` with no task name runs the test suite.
task default: :test

# Defines the `test` task (the name Rake::TestTask uses when none is given):
# adds the "test" directory to the load path, picks up every *_test.rb file
# under test/, and runs them with Ruby warnings switched off.
Rake::TestTask.new(:test) do |test_task|
  test_task.libs.push("test")
  test_task.pattern = "test/**/*_test.rb"
  test_task.warning = false
end
# -- TODO: put everything below somewhere better --
# Converts a CamelCase identifier to snake_case, e.g. "MatMul" -> "mat_mul".
# Based on ActiveSupport's underscore.
#
# Note that a digit followed by a capital is also split, so "Conv2D" becomes
# "conv2_d".
def underscore(str)
  # Separate a run of capitals from a following capitalised word:
  # "HTTPServer" -> "HTTP_Server".
  acronyms_split = str.gsub(/([A-Z]+)([A-Z][a-z])/, '\1_\2')
  # Separate a lowercase letter or digit from the capital that follows it:
  # "MatMul" -> "Mat_Mul".
  words_split = acronyms_split.gsub(/([a-z\d])([A-Z])/, '\1_\2')
  words_split.downcase
end
# Maps an op argument name to the name used in generated method signatures.
#
# "begin" becomes "start" and "end" becomes "stop"; start and stop were chosen
# because they are already used for some operations. Every other name is
# returned unchanged.
def arg_name(name)
  renames = { "begin" => "start", "end" => "stop" }
  renames.fetch(name, name)
end
# Returns every op definition reported by TensorFlow::FFI.TF_GetAllOpList,
# decoded through Tensorflow::OpList and sorted by op name.
#
# Requires the tensorflow library and expects a TensorFlow checkout at
# $HOME/forks/tensorflow, which is appended to the load path so that
# "tensorflow/core/framework/op_def_pb" can be required from it.
def read_op_def
  require "tensorflow"

  # TODO pull these into project?
  $LOAD_PATH << "#{ENV["HOME"]}/forks/tensorflow"
  require "tensorflow/core/framework/op_def_pb"

  # The returned buffer exposes a :data pointer and a :length in bytes.
  op_list_buffer = TensorFlow::FFI.TF_GetAllOpList
  serialized = op_list_buffer[:data].read_bytes(op_list_buffer[:length])

  op_list = Tensorflow::OpList.decode(serialized)
  op_list.op.sort_by { |op| op.name }
end
# Writes lib/tensorflow/raw_ops.rb: a TensorFlow::RawOps module with one
# singleton method per op returned by read_op_def. Each generated method takes
# the op's inputs and attrs as keyword arguments (all defaulting to nil) and
# forwards them to Utils.execute.
task :generate_ops do
defs = []
read_op_def.each do |op|
# arg_name renames "begin"/"end" to "start"/"stop".
input_names = op.input_arg.map { |v| arg_name(v.name) }
# Drop attrs whose first character is unchanged by upcase, i.e. names that
# start with a capital letter (such as "T") or with a non-letter.
options = op.attr.map { |v| arg_name(v.name) }.reject { |v| v[0] == v[0].upcase }
# Ops whose name starts with an underscore get no generated method.
if op.name[0] != "_"
# underscore turns "...2D"/"...3D" into "...2_d"/"...3_d"; collapse those to "2d"/"3d".
def_name = underscore(op.name).gsub(/2_d/, "2d").gsub(/3_d/, "3d")
# Every entry is prefixed with ", ", so the parameter list starts with "(, ";
# that prefix is removed by the gsub on the final contents below.
def_options_str = (input_names + options).map { |v| ", #{v}: nil" }.join
# Inputs are passed to Utils.execute as an array; attrs follow as keywords.
execute_options_str = options.map { |v| ", #{v}: #{v}" }.join
defs << %! def #{def_name}(#{def_options_str})
Utils.execute("#{op.name}", [#{input_names.join(", ")}]#{execute_options_str})
end!
end
end
contents = %!# Generated by `rake generate_ops`
module TensorFlow
module RawOps
class << self
#{defs.join("\n\n")}
end
end
end
!
# Remove empty parameter lists ("()") and the leading ", " left in "(, ...".
contents = contents.gsub("()", "").gsub("(, ", "(")
# puts contents
File.write("lib/tensorflow/raw_ops.rb", contents)
end
# Seeds per-module wrapper files (lib/tensorflow/<module>.rb) from the names
# listed on the TensorFlow r2.0 Python API docs page. For each selected module
# it writes a TensorFlow::<Module> module whose singleton methods delegate to
# RawOps, and prints a suggested def_delegators line for names that also
# exist at the top level of the tf namespace.
task :seed_ops do
require "nokogiri"
require "open-uri"
# The docs page is downloaded once and reused from /tmp on later runs.
cached_path = "/tmp/ops.html"
unless File.exist?(cached_path)
url = "https://www.tensorflow.org/versions/r2.0/api_docs/python"
puts "Downloading #{url}"
File.write(cached_path, URI.parse(url).read)
end
ops = []
doc = Nokogiri::HTML(File.read(cached_path))
# Keep link texts that start with "tf.", are entirely lowercase, and are not
# under a ".compat." namespace.
doc.css("a").each do |node|
text = node.text.strip
if text.start_with?("tf.") && text == text.downcase && !text.include?(".compat.")
ops << text
end
end
# op defs
# Keyed by the same method name that generate_ops derives for each op.
op_def = read_op_def.map { |op| [underscore(op.name).gsub(/2_d/, "2d").gsub(/3_d/, "3d"), op] }.to_h
# top level ops
# Names with exactly one dot ("tf.name"), with the "tf." prefix removed.
tf_ops = ops.select { |op| op.count(".") == 1 }.map { |v| v.sub("tf.", "") }
# determine modules
# The middle segment of names with exactly two dots ("tf.module.name").
modules = ops.select { |op| op.count(".") == 2 }.map { |v| v.split(".")[1] }.uniq
modules.each do |mod|
# Only this fixed list of modules is generated.
next unless ["audio", "bitwise", "image", "io", "linalg", "strings"].include?(mod)
mod_ops = ops.select { |op| op.start_with?("tf.#{mod}.") && op.count(".") == 2 }.map { |v| v.sub("tf.#{mod}.", "") }
# NOTE(review): capitalize gives "Io" for "io" — confirm that is the intended module name.
mod_class = mod.capitalize
# Skip the whole module if it lists an entry named "experimental".
next if mod_ops.include?("experimental")
# puts mod
# p mod_ops
defs = []
mod_ops.each do |def_name|
op = op_def[def_name]
# No raw op with this name: emit a commented-out stub to be filled in by hand.
if !op
defs << %! # def #{def_name}
# end!
else
# Same input/attr naming and filtering as in generate_ops.
input_names = op.input_arg.map { |v| arg_name(v.name) }
options = op.attr.map { |v| arg_name(v.name) }.reject { |v| v[0] == v[0].upcase }
# The wrapper takes inputs positionally and attrs as nil-defaulted keywords,
# then passes everything to RawOps as keyword arguments.
input_names_str = input_names.join(", ")
def_options_str = options.map { |v| ", #{v}: nil" }.join
raw_options_str = (input_names + options).map { |v| "#{v}: #{v}" }.join(", ")
defs << %! def #{def_name}(#{input_names_str}#{def_options_str})
RawOps.#{def_name}(#{raw_options_str})
end!
end
end
contents = %!module TensorFlow
module #{mod_class}
class << self
#{defs.join("\n\n")}
end
end
end
!
# Remove empty parameter lists ("()") and a leading ", " left in "(, ...".
contents = contents.gsub("()", "").gsub("(, ", "(")
# puts contents
File.write("lib/tensorflow/#{mod}.rb", contents)
# Names present both at the top level and in this module: print a
# def_delegators line for them.
delegate_mod_ops = tf_ops & mod_ops
if delegate_mod_ops.any?
puts "def_delegators #{mod_class}, #{delegate_mod_ops.map { |v| v.to_sym.inspect }.join(", ")}"
end
end
end