From f722e17d417d36dc53726cc02eb392d9952ddb2b Mon Sep 17 00:00:00 2001 From: Adphi Date: Thu, 9 Sep 2021 12:23:19 +0200 Subject: [PATCH] add SearchAsync Signed-off-by: Adphi --- .gitignore | 1 + client.go | 1 + ldap_test.go | 84 ++++++++++++++++++++++++++++++++++ search.go | 119 ++++++++++++++++++++++++++++++++++++++++++++++-- v3/client.go | 1 + v3/ldap_test.go | 84 ++++++++++++++++++++++++++++++++++ v3/search.go | 119 ++++++++++++++++++++++++++++++++++++++++++++++-- 7 files changed, 401 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index e69de29b..485dee64 100644 --- a/.gitignore +++ b/.gitignore @@ -0,0 +1 @@ +.idea diff --git a/client.go b/client.go index f0312aff..fa071c7e 100644 --- a/client.go +++ b/client.go @@ -31,5 +31,6 @@ type Client interface { PasswordModify(*PasswordModifyRequest) (*PasswordModifyResult, error) Search(*SearchRequest) (*SearchResult, error) + SearchAsync(searchRequest *SearchRequest, done chan struct{}) (<-chan *SearchAsyncResponse, error) SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error) } diff --git a/ldap_test.go b/ldap_test.go index 61417fd5..137e9851 100644 --- a/ldap_test.go +++ b/ldap_test.go @@ -94,6 +94,90 @@ func TestSearch(t *testing.T) { t.Logf("TestSearch: %s -> num of entries = %d", searchRequest.Filter, len(sr.Entries)) } +func TestSearchAsync(t *testing.T) { + l, err := DialURL(ldapServer) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + searchRequest := NewSearchRequest( + baseDN, + ScopeWholeSubtree, DerefAlways, 0, 0, false, + filter[0], + attributes, + nil) + + var entries []*Entry + responses, err := l.SearchAsync(searchRequest, nil) + if err != nil { + t.Fatal(err) + } + for res := range responses { + if err := res.Err(); err != nil { + t.Error(err) + break + } + if res.Closed() { + break + } + switch res.Type { + case SearchAsyncResponseTypeEntry: + entries = append(entries, res.Entry) + case SearchAsyncResponseTypeReferral: + t.Logf("Received Referral: %s", res.Referral) + case SearchAsyncResponseTypeControl: + t.Logf("Received Control: %s", res.Control) + } + } + t.Logf("TestSearch: %s -> num of entries = %d", searchRequest.Filter, len(entries)) +} + +func TestSearchAsyncStop(t *testing.T) { + l, err := DialURL(ldapServer) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + searchRequest := NewSearchRequest( + baseDN, + ScopeWholeSubtree, DerefAlways, 0, 0, false, + filter[0], + attributes, + nil) + + var entries []*Entry + done := make(chan struct{}) + responses, err := l.SearchAsync(searchRequest, done) + if err != nil { + t.Fatal(err) + } + close(done) + for res := range responses { + if err := res.Err(); err != nil { + t.Error(err) + break + } + + if res.Closed() { + break + } + switch res.Type { + case SearchAsyncResponseTypeEntry: + entries = append(entries, res.Entry) + case SearchAsyncResponseTypeReferral: + t.Logf("Received Referral: %s", res.Referral) + case SearchAsyncResponseTypeControl: + t.Logf("Received Control: %s", res.Control) + } + } + if len(entries) > 1 { + t.Errorf("Expected 1 entry, got %d", len(entries)) + } + t.Logf("TestSearch: %s -> num of entries = %d", searchRequest.Filter, len(entries)) +} + func TestSearchStartTLS(t *testing.T) { l, err := DialURL(ldapServer) if err != nil { diff --git a/search.go b/search.go index c174f197..61172f19 100644 --- a/search.go +++ b/search.go @@ -338,6 +338,42 @@ func (s *SearchResult) PrettyPrint(indent int) { } } +// SearchAsyncResponseType describes the SearchAsyncResponse content type +type SearchAsyncResponseType uint8 + +const ( + SearchAsyncResponseTypeNone SearchAsyncResponseType = iota + SearchAsyncResponseTypeEntry + SearchAsyncResponseTypeReferral + SearchAsyncResponseTypeControl +) + +// SearchAsyncResponse holds the server's response message to an async search request +type SearchAsyncResponse struct { + // Type indicates the SearchAsyncResponse type + Type SearchAsyncResponseType + // Entry is the received entry, only set if Type is SearchAsyncResponseTypeEntry + Entry *Entry + // Referral is the received referral, only set if Type is SearchAsyncResponseTypeReferral + Referral string + // Control is the received control, only set if Type is SearchAsyncResponseTypeControl + Control Control + // closed indicates that the request is finished + closed bool + // err holds the encountered error while processing server's response, if any + err error +} + +// Closed returns true if the request is finished +func (r *SearchAsyncResponse) Closed() bool { + return r.closed +} + +// Err returns the encountered error while processing server's response, if any +func (r *SearchAsyncResponse) Err() error { + return r.err +} + // SearchRequest represents a search request to send to the server type SearchRequest struct { BaseDN string @@ -405,10 +441,11 @@ func NewSearchRequest( // SearchWithPaging accepts a search request and desired page size in order to execute LDAP queries to fulfill the // search request. All paged LDAP query responses will be buffered and the final result will be returned atomically. // The following four cases are possible given the arguments: -// - given SearchRequest missing a control of type ControlTypePaging: we will add one with the desired paging size -// - given SearchRequest contains a control of type ControlTypePaging that isn't actually a ControlPaging: fail without issuing any queries -// - given SearchRequest contains a control of type ControlTypePaging with pagingSize equal to the size requested: no change to the search request -// - given SearchRequest contains a control of type ControlTypePaging with pagingSize not equal to the size requested: fail without issuing any queries +// - given SearchRequest missing a control of type ControlTypePaging: we will add one with the desired paging size +// - given SearchRequest contains a control of type ControlTypePaging that isn't actually a ControlPaging: fail without issuing any queries +// - given SearchRequest contains a control of type ControlTypePaging with pagingSize equal to the size requested: no change to the search request +// - given SearchRequest contains a control of type ControlTypePaging with pagingSize not equal to the size requested: fail without issuing any queries +// // A requested pagingSize of 0 is interpreted as no limit by LDAP servers. func (l *Conn) SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error) { var pagingControl *ControlPaging @@ -519,6 +556,80 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { } } +// SearchAsync performs the given search request asynchronously, it takes an optional done channel to stop the request. It returns a SearchAsyncResponse channel which will be +// closed when the request finished and an error, not nil if the request to the server failed +func (l *Conn) SearchAsync(searchRequest *SearchRequest, done chan struct{}) (<-chan *SearchAsyncResponse, error) { + if done == nil { + done = make(chan struct{}) + } + msgCtx, err := l.doRequest(searchRequest) + if err != nil { + return nil, err + } + responses := make(chan *SearchAsyncResponse) + ch := make(chan *SearchAsyncResponse) + rcv := func() { + for { + packet, err := l.readPacket(msgCtx) + if err != nil { + ch <- &SearchAsyncResponse{closed: true, err: err} + return + } + + switch packet.Children[1].Tag { + case 4: + entry := &Entry{ + DN: packet.Children[1].Children[0].Value.(string), + Attributes: unpackAttributes(packet.Children[1].Children[1].Children), + } + ch <- &SearchAsyncResponse{Type: SearchAsyncResponseTypeEntry, Entry: entry} + case 5: + err := GetLDAPError(packet) + if err != nil { + ch <- &SearchAsyncResponse{closed: true, err: err} + return + } + var response SearchAsyncResponse + if len(packet.Children) == 3 { + for _, child := range packet.Children[2].Children { + decodedChild, err := DecodeControl(child) + if err != nil { + responses <- &SearchAsyncResponse{closed: true, err: fmt.Errorf("failed to decode child control: %s", err)} + return + } + response = SearchAsyncResponse{Type: SearchAsyncResponseTypeControl, Control: decodedChild} + } + } + response.closed = true + ch <- &response + return + case 19: + ch <- &SearchAsyncResponse{Type: SearchAsyncResponseTypeReferral, Referral: packet.Children[1].Children[0].Value.(string)} + } + } + } + go func() { + defer l.finishMessage(msgCtx) + defer close(responses) + go rcv() + for { + select { + case <-done: + responses <- &SearchAsyncResponse{ + closed: true, + } + return + case res := <-ch: + responses <- res + if res.Closed() { + return + } + } + } + }() + return responses, nil +} + // unpackAttributes will extract all given LDAP attributes and it's values // from the ber.Packet func unpackAttributes(children []*ber.Packet) []*EntryAttribute { diff --git a/v3/client.go b/v3/client.go index f0312aff..fa071c7e 100644 --- a/v3/client.go +++ b/v3/client.go @@ -31,5 +31,6 @@ type Client interface { PasswordModify(*PasswordModifyRequest) (*PasswordModifyResult, error) Search(*SearchRequest) (*SearchResult, error) + SearchAsync(searchRequest *SearchRequest, done chan struct{}) (<-chan *SearchAsyncResponse, error) SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error) } diff --git a/v3/ldap_test.go b/v3/ldap_test.go index 61417fd5..137e9851 100644 --- a/v3/ldap_test.go +++ b/v3/ldap_test.go @@ -94,6 +94,90 @@ func TestSearch(t *testing.T) { t.Logf("TestSearch: %s -> num of entries = %d", searchRequest.Filter, len(sr.Entries)) } +func TestSearchAsync(t *testing.T) { + l, err := DialURL(ldapServer) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + searchRequest := NewSearchRequest( + baseDN, + ScopeWholeSubtree, DerefAlways, 0, 0, false, + filter[0], + attributes, + nil) + + var entries []*Entry + responses, err := l.SearchAsync(searchRequest, nil) + if err != nil { + t.Fatal(err) + } + for res := range responses { + if err := res.Err(); err != nil { + t.Error(err) + break + } + if res.Closed() { + break + } + switch res.Type { + case SearchAsyncResponseTypeEntry: + entries = append(entries, res.Entry) + case SearchAsyncResponseTypeReferral: + t.Logf("Received Referral: %s", res.Referral) + case SearchAsyncResponseTypeControl: + t.Logf("Received Control: %s", res.Control) + } + } + t.Logf("TestSearch: %s -> num of entries = %d", searchRequest.Filter, len(entries)) +} + +func TestSearchAsyncStop(t *testing.T) { + l, err := DialURL(ldapServer) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + searchRequest := NewSearchRequest( + baseDN, + ScopeWholeSubtree, DerefAlways, 0, 0, false, + filter[0], + attributes, + nil) + + var entries []*Entry + done := make(chan struct{}) + responses, err := l.SearchAsync(searchRequest, done) + if err != nil { + t.Fatal(err) + } + close(done) + for res := range responses { + if err := res.Err(); err != nil { + t.Error(err) + break + } + + if res.Closed() { + break + } + switch res.Type { + case SearchAsyncResponseTypeEntry: + entries = append(entries, res.Entry) + case SearchAsyncResponseTypeReferral: + t.Logf("Received Referral: %s", res.Referral) + case SearchAsyncResponseTypeControl: + t.Logf("Received Control: %s", res.Control) + } + } + if len(entries) > 1 { + t.Errorf("Expected 1 entry, got %d", len(entries)) + } + t.Logf("TestSearch: %s -> num of entries = %d", searchRequest.Filter, len(entries)) +} + func TestSearchStartTLS(t *testing.T) { l, err := DialURL(ldapServer) if err != nil { diff --git a/v3/search.go b/v3/search.go index c174f197..61172f19 100644 --- a/v3/search.go +++ b/v3/search.go @@ -338,6 +338,42 @@ func (s *SearchResult) PrettyPrint(indent int) { } } +// SearchAsyncResponseType describes the SearchAsyncResponse content type +type SearchAsyncResponseType uint8 + +const ( + SearchAsyncResponseTypeNone SearchAsyncResponseType = iota + SearchAsyncResponseTypeEntry + SearchAsyncResponseTypeReferral + SearchAsyncResponseTypeControl +) + +// SearchAsyncResponse holds the server's response message to an async search request +type SearchAsyncResponse struct { + // Type indicates the SearchAsyncResponse type + Type SearchAsyncResponseType + // Entry is the received entry, only set if Type is SearchAsyncResponseTypeEntry + Entry *Entry + // Referral is the received referral, only set if Type is SearchAsyncResponseTypeReferral + Referral string + // Control is the received control, only set if Type is SearchAsyncResponseTypeControl + Control Control + // closed indicates that the request is finished + closed bool + // err holds the encountered error while processing server's response, if any + err error +} + +// Closed returns true if the request is finished +func (r *SearchAsyncResponse) Closed() bool { + return r.closed +} + +// Err returns the encountered error while processing server's response, if any +func (r *SearchAsyncResponse) Err() error { + return r.err +} + // SearchRequest represents a search request to send to the server type SearchRequest struct { BaseDN string @@ -405,10 +441,11 @@ func NewSearchRequest( // SearchWithPaging accepts a search request and desired page size in order to execute LDAP queries to fulfill the // search request. All paged LDAP query responses will be buffered and the final result will be returned atomically. // The following four cases are possible given the arguments: -// - given SearchRequest missing a control of type ControlTypePaging: we will add one with the desired paging size -// - given SearchRequest contains a control of type ControlTypePaging that isn't actually a ControlPaging: fail without issuing any queries -// - given SearchRequest contains a control of type ControlTypePaging with pagingSize equal to the size requested: no change to the search request -// - given SearchRequest contains a control of type ControlTypePaging with pagingSize not equal to the size requested: fail without issuing any queries +// - given SearchRequest missing a control of type ControlTypePaging: we will add one with the desired paging size +// - given SearchRequest contains a control of type ControlTypePaging that isn't actually a ControlPaging: fail without issuing any queries +// - given SearchRequest contains a control of type ControlTypePaging with pagingSize equal to the size requested: no change to the search request +// - given SearchRequest contains a control of type ControlTypePaging with pagingSize not equal to the size requested: fail without issuing any queries +// // A requested pagingSize of 0 is interpreted as no limit by LDAP servers. func (l *Conn) SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error) { var pagingControl *ControlPaging @@ -519,6 +556,80 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { } } +// SearchAsync performs the given search request asynchronously, it takes an optional done channel to stop the request. It returns a SearchAsyncResponse channel which will be +// closed when the request finished and an error, not nil if the request to the server failed +func (l *Conn) SearchAsync(searchRequest *SearchRequest, done chan struct{}) (<-chan *SearchAsyncResponse, error) { + if done == nil { + done = make(chan struct{}) + } + msgCtx, err := l.doRequest(searchRequest) + if err != nil { + return nil, err + } + responses := make(chan *SearchAsyncResponse) + ch := make(chan *SearchAsyncResponse) + rcv := func() { + for { + packet, err := l.readPacket(msgCtx) + if err != nil { + ch <- &SearchAsyncResponse{closed: true, err: err} + return + } + + switch packet.Children[1].Tag { + case 4: + entry := &Entry{ + DN: packet.Children[1].Children[0].Value.(string), + Attributes: unpackAttributes(packet.Children[1].Children[1].Children), + } + ch <- &SearchAsyncResponse{Type: SearchAsyncResponseTypeEntry, Entry: entry} + case 5: + err := GetLDAPError(packet) + if err != nil { + ch <- &SearchAsyncResponse{closed: true, err: err} + return + } + var response SearchAsyncResponse + if len(packet.Children) == 3 { + for _, child := range packet.Children[2].Children { + decodedChild, err := DecodeControl(child) + if err != nil { + responses <- &SearchAsyncResponse{closed: true, err: fmt.Errorf("failed to decode child control: %s", err)} + return + } + response = SearchAsyncResponse{Type: SearchAsyncResponseTypeControl, Control: decodedChild} + } + } + response.closed = true + ch <- &response + return + case 19: + ch <- &SearchAsyncResponse{Type: SearchAsyncResponseTypeReferral, Referral: packet.Children[1].Children[0].Value.(string)} + } + } + } + go func() { + defer l.finishMessage(msgCtx) + defer close(responses) + go rcv() + for { + select { + case <-done: + responses <- &SearchAsyncResponse{ + closed: true, + } + return + case res := <-ch: + responses <- res + if res.Closed() { + return + } + } + } + }() + return responses, nil +} + // unpackAttributes will extract all given LDAP attributes and it's values // from the ber.Packet func unpackAttributes(children []*ber.Packet) []*EntryAttribute {