-
Notifications
You must be signed in to change notification settings - Fork 160
/
Shortest distance between two nodes.cpp
158 lines (107 loc) · 3.4 KB
/
Shortest distance between two nodes.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
/*
Given two nodes of a Binary Tree, calculate the shortest distance between the two nodes.
*/
/*
Solution:
1. Find the lowest common ancestor (LCA) of the two given nodes.
2. Find the distance from the LCA to each of the given nodes separately.
3. Add the two distances obtained above.
O(n) time, O(n) space, where n is the number of nodes in the tree.
*/
#include<iostream>
#include<list>
using namespace std;
/* A binary-tree node: an integer payload plus links to its two children. */
struct Node {
    int val;     // value stored in this node
    Node *plft;  // left child, NULL when there is none
    Node *prgt;  // right child, NULL when there is none

    // Creates a childless node that stores `value`.
    Node(int value)
        : val(value),
          plft(NULL),
          prgt(NULL) {}
};
bool GetNodePath(Node *root, Node *node, list<Node*>& path) {
if (root == node) return true;
path.push_back(root);
bool found = false;
if (root->plft != NULL)
found = getNodePath(root->plft, node, path);
if (!found && root->prgt)
found = getNodePath(root->prgt, node, path);
if (!found)
path.pop_back();
return found;
}
/*
Walks two node paths in lockstep and returns the last node that appears at
the same position in both.

path1, path2: the paths to compare; neither is modified.

Returns the element of the final position at which both paths hold the same
pointer, or NULL when no position matches (for example when either path is
empty). The scan does not stop at the first mismatch; it runs until the
shorter path is exhausted.
*/
Node* LastCommonNode(const std::list<Node*>& path1, const std::list<Node*>& path2) {
    Node *deepestMatch = NULL;
    std::list<Node*>::const_iterator lhs = path1.begin();
    std::list<Node*>::const_iterator rhs = path2.begin();
    for (; lhs != path1.end() && rhs != path2.end(); ++lhs, ++rhs) {
        if (*lhs == *rhs) {
            deepestMatch = *lhs;
        }
    }
    return deepestMatch;
}
/*
Finds the deepest node shared by the root-to-node paths of `node1` and
`node2`.

root:  root of the tree that is searched.
node1: first node of interest.
node2: second node of interest.

Returns NULL when any of the three pointers is NULL; otherwise returns
whatever LastCommonNode reports for the two paths produced by GetNodePath
(NULL if those paths have no matching position).
*/
Node *LastCommonAncestor(Node* root, Node* node1, Node* node2) {
    const bool allPresent = (root != NULL) && (node1 != NULL) && (node2 != NULL);
    if (!allPresent) {
        return NULL;
    }

    std::list<Node*> path1;
    std::list<Node*> path2;
    GetNodePath(root, node1, path1);
    GetNodePath(root, node2, path2);

    return LastCommonNode(path1, path2);
}
/*
Counts the number of edges from `lca` down to `node`.

lca:   root of the subtree that is searched; may be NULL.
node:  the node to look for, compared by address.
found: in/out flag. Callers pass false; it is set to true once `node` has
       been located. If it is already true on entry the call does nothing
       and returns 0.

Returns the edge count from `lca` to `node` when `node` is inside the subtree
(0 when `lca` is `node` itself). Returns 0 with `found` still false when the
node is not in the subtree.
*/
int Height(Node *lca, Node *node, bool &found) {
    // Nothing to do for an empty subtree or once the target has been located.
    if (lca == NULL || found) {
        return 0;
    }

    if (lca == node) {
        found = true;
        return 0;  // a node is at distance 0 from itself
    }

    // Try the left subtree first and fall back to the right one.
    int depthBelow = Height(lca->plft, node, found);
    if (!found) {
        depthBelow = Height(lca->prgt, node, found);
    }

    // One extra edge for the step from this node to the child that holds the
    // target; a miss in both subtrees contributes nothing.
    return found ? depthBelow + 1 : 0;
}
/*
Computes the number of edges on the path between `node1` and `node2`, given
their lowest common ancestor.

node1: first node.
node2: second node.
lca:   lowest common ancestor of the two nodes.

Returns the distance from `lca` to `node1` plus the distance from `lca` to
`node2`, each measured with Height. Returns 0 when any of the three pointers
is NULL. As a side effect, the two partial distances are printed to standard
output.
*/
int ShortestDistance(Node *node1, Node* node2, Node *lca) {
    // Without all three nodes there is nothing to measure, and reading
    // node1->val / node2->val below would dereference a null pointer.
    if (lca == NULL || node1 == NULL || node2 == NULL) {
        return 0;
    }

    bool found = false;
    const int dist1 = Height(lca, node1, found);
    std::cout<<"Distance of "<<node1->val<<": "<<dist1<<std::endl;

    found = false;  // Height expects a cleared flag for every new search
    const int dist2 = Height(lca, node2, found);
    std::cout<<"Distance of "<<node2->val<<": "<<dist2<<std::endl;

    return dist1 + dist2;
}
int main() {
Node *root = new Node(1);
root->plft = new Node(2);
root->prgt = new Node(3);
root->plft->plft = new Node(4);
root->plft->prgt = new Node(5);
root->prgt->plft = new Node(6);
root->prgt->prgt = new Node(7);
root->prgt->prgt->plft = new Node(8);
root->prgt->prgt->prgt = new Node(9);
Node *node1 = root->prgt->prgt->plft;
Node *node2 = root->prgt->plft;
Node *lca = LastCommonAncestor(root, node1, node2);
if (lca) {
cout<<"Least Common Ancestor: "<<lca->val<<endl;
}
cout<<"Total distance: "<<ShortestDistance(node1, node2, lca)<<endl;
return 0;
}