forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
graph_utils.cpp
93 lines (88 loc) · 2.87 KB
/
graph_utils.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#include <torch/csrc/jit/ir/graph_utils.h>
namespace torch {
namespace jit {
// Builds a TensorType describing the concrete tensor `t`.
// When `complete` is false the result is coarsened to carry only the
// dimension count (no sizes/strides); when true the full type is kept.
TypePtr getTensorType(const at::Tensor& t, bool complete) {
  auto inferred = TensorType::create(t);
  return complete ? inferred : inferred->dimensionedOnly();
}
// Recursively rebuilds `input_type`, replacing each Tensor leaf with the
// concrete (shape-carrying) type observed in the runtime stack.
// `s_iter` is advanced once for every leaf value consumed from the stack;
// `s_iter_end` bounds the walk for tuple elements.
TypePtr inferShapeAndTypeForInput(
    TypePtr input_type,
    Stack::const_iterator& s_iter,
    const Stack::const_iterator& s_iter_end,
    bool complete) {
  if (auto as_tuple = input_type->cast<TupleType>()) {
    // One stack value per tuple element; recurse on each contained type.
    std::vector<TypePtr> element_types;
    for (const auto& contained : as_tuple->containedTypes()) {
      TORCH_INTERNAL_ASSERT(s_iter != s_iter_end);
      element_types.emplace_back(
          inferShapeAndTypeForInput(contained, s_iter, s_iter_end, complete));
    }
    return TupleType::create(element_types);
  }
  if (auto as_list = input_type->cast<ListType>()) {
    // Lists are homogeneous: refine the single element type.
    return ListType::create(inferShapeAndTypeForInput(
        as_list->getElementType(), s_iter, s_iter_end, complete));
  }
  if (input_type->cast<TensorType>()) {
    // Tensor leaf: take the concrete type from the stack value.
    auto refined = getTensorType(s_iter->toTensor(), complete);
    ++s_iter;
    return refined;
  }
  if (auto as_optional = input_type->cast<OptionalType>()) {
    return OptionalType::create(inferShapeAndTypeForInput(
        as_optional->getElementType(), s_iter, s_iter_end, complete));
  }
  // Primitive type: consume the stack slot and keep the declared type.
  ++s_iter;
  return input_type;
}
// Refines the declared types of `g`'s inputs from the concrete values in
// `stack`, via inferShapeAndTypeForInput. Custom-class (packed-param) inputs
// are skipped and keep their declared type. `param_count_list`, when
// non-empty, gives for each graph input how many stack values it consumes
// (used to advance past skipped custom-class inputs); it must then have one
// entry per graph input.
void setInputTensorTypes(
    Graph& g,
    const Stack& stack,
    bool complete,
    const std::vector<int>& param_count_list) {
  at::ArrayRef<Value*> input_values = g.inputs();
  auto s_iter = stack.begin();
  // Index into param_count_list; incremented for every input, including the
  // custom-class ones that are skipped via `continue`.
  size_t list_idx = 0;
  if (!param_count_list.empty()) {
    TORCH_INTERNAL_ASSERT(
        input_values.size() == param_count_list.size(),
        " input_values:",
        input_values.size(),
        " vs param_count_list:",
        param_count_list.size());
  }
  for (auto v : input_values) {
    // Leave packed param types alone. This is needed for downstream passes
    // (like alias analysis) to work properly. This will be unpacked later
    // in unpackQuantizedWeights.
    if (auto named_type = v->type()->cast<c10::NamedType>()) {
      if (auto qualname = named_type->name()) {
        if (getCustomClass(qualname->qualifiedName())) {
          if (param_count_list.empty()) {
            // No per-input counts: a custom class consumes one stack value.
            AT_ASSERT(s_iter != stack.end());
            s_iter++;
          } else {
            // Skip however many stack values this input accounts for.
            if (param_count_list[list_idx] > 0) {
              AT_ASSERT(s_iter != stack.end());
            }
            s_iter += param_count_list[list_idx];
          }
          list_idx++;
          continue;
        }
      }
    }
    // Ordinary input: replace its type with one refined from the stack.
    auto type =
        inferShapeAndTypeForInput(v->type(), s_iter, stack.end(), complete);
    v->setType(type);
    list_idx++;
  }
}
} // namespace jit
} // namespace torch