Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: remove parent task key #112

Merged
merged 14 commits into from
Jan 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]()

## Removed

- field `parent_task_keys` in `ComputeTask` ([#112](https://github.com/Substra/orchestrator/pull/112))
- view `expanded_compute_tasks` ([#112](https://github.com/Substra/orchestrator/pull/112))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can just put #112, github does the link to the PR automatically :)


## [0.31.0] - 2022-12-19

Expand Down
1 change: 1 addition & 0 deletions chaincode/ledger/dbal.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ type storedAsset struct {
// dbal indexes
const computePlanTaskStatusIndex = "computePlan~computePlanKey~status~task"
const computeTaskParentIndex = "computeTask~parentTask~key"
const computeTaskChildIndex = "computeTask~childTask~key"
const modelTaskKeyIndex = "model~taskKey~modelKey"
const performanceIndex = "performance~taskKey~metricKey"
const allOrganizationsIndex = "organizations~id"
Expand Down
28 changes: 27 additions & 1 deletion chaincode/ledger/dbal_computetask.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/substra/orchestrator/lib/common"
orcerrors "github.com/substra/orchestrator/lib/errors"
"github.com/substra/orchestrator/lib/persistence"
"github.com/substra/orchestrator/lib/service"
"github.com/substra/orchestrator/utils"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
Expand Down Expand Up @@ -38,11 +39,15 @@ func (db *DB) addComputeTask(t *asset.ComputeTask) error {
if err != nil {
return err
}
for _, parentTask := range t.ParentTaskKeys {
for _, parentTask := range service.GetParentTaskKeys(t.Inputs) {
err = db.createIndex(computeTaskParentIndex, []string{asset.ComputeTaskKind, parentTask, t.Key})
if err != nil {
return err
}
err = db.createIndex(computeTaskChildIndex, []string{asset.ComputeTaskKind, t.Key, parentTask})
if err != nil {
return err
}
}
return nil
}
Expand Down Expand Up @@ -171,6 +176,27 @@ func (db *DB) GetComputeTaskChildren(key string) ([]*asset.ComputeTask, error) {
return tasks, nil
}

// GetComputeTaskParents returns the children of the task identified by the given key
func (db *DB) GetComputeTaskParents(key string) ([]*asset.ComputeTask, error) {
elementKeys, err := db.getIndexKeys(computeTaskChildIndex, []string{asset.ComputeTaskKind, key})
if err != nil {
return nil, err
}

db.logger.Debug().Int("numParents", len(elementKeys)).Msg("GetComputeTaskParents")

tasks := []*asset.ComputeTask{}
for _, parentKey := range elementKeys {
task, err := db.GetComputeTask(parentKey)
if err != nil {
return nil, err
}
tasks = append(tasks, task)
}

return tasks, nil
}

// GetComputePlanTasksKeys returns the list of task keys from the provided compute plan
func (db *DB) GetComputePlanTasksKeys(key string) ([]string, error) {
keys, err := db.getIndexKeys(computePlanTaskStatusIndex, []string{asset.ComputePlanKind, key})
Expand Down
6 changes: 2 additions & 4 deletions lib/asset/computetask.proto
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,13 @@ message NewComputeTaskOutput {
// ComputeTask is a computation step in a ComputePlan.
// It was previously called XXXtuple: Traintuple, CompositeTraintuple, etc
message ComputeTask {
reserved 3, 12, 13, 14, 15, 18;
reserved "algo", "data", "test", "train", "composite", "aggregate", "predict";
reserved 3, 6, 12, 13, 14, 15, 18;
reserved "algo", "data", "test", "train", "composite", "aggregate", "parent_task_keys", "predict";

string key = 1;
ComputeTaskCategory category = 2;
string owner = 4;
string compute_plan_key = 5;
// Keys of parent ComputeTasks
repeated string parent_task_keys = 6;
int32 rank = 7;
ComputeTaskStatus status = 8; // mutable
string worker = 9;
Expand Down
1 change: 1 addition & 0 deletions lib/persistence/computetask_dbal.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ type ComputeTaskDBAL interface {
UpdateComputeTaskStatus(taskKey string, taskStatus asset.ComputeTaskStatus) error
QueryComputeTasks(p *common.Pagination, filter *asset.TaskQueryFilter) ([]*asset.ComputeTask, common.PaginationToken, error)
GetComputeTaskChildren(key string) ([]*asset.ComputeTask, error)
GetComputeTaskParents(key string) ([]*asset.ComputeTask, error)
// GetComputePlanTasks returns the tasks of the compute plan identified by the given key
GetComputePlanTasks(key string) ([]*asset.ComputeTask, error)
GetComputePlanTasksKeys(key string) ([]string, error)
Expand Down
13 changes: 6 additions & 7 deletions lib/service/computetask.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ func (s *ComputeTaskService) sortTasks(newTasks []*asset.NewComputeTask, existin
unsortedParentsCount[unsortedTasks[i].Key] = 0
// We count the number of parents that are not already registered in the persistence layer

for _, parent := range getParentTaskKeys(unsortedTasks[i].Inputs) {
for _, parent := range GetParentTaskKeys(unsortedTasks[i].Inputs) {
if !utils.SliceContains(existingTasks, parent) {
unsortedParentsCount[unsortedTasks[i].Key]++
}
Expand All @@ -326,7 +326,7 @@ func (s *ComputeTaskService) sortTasks(newTasks []*asset.NewComputeTask, existin
sortedTasksCount++

for i := 0; i < len(unsortedTasks); i++ {
for _, key := range getParentTaskKeys(unsortedTasks[i].Inputs) {
for _, key := range GetParentTaskKeys(unsortedTasks[i].Inputs) {
if key == currentTask.Key {
unsortedParentsCount[unsortedTasks[i].Key]--
if unsortedParentsCount[unsortedTasks[i].Key] == 0 {
Expand Down Expand Up @@ -365,7 +365,7 @@ func (s *ComputeTaskService) createTask(input *asset.NewComputeTask, owner strin
return nil, orcerrors.NewPermissionDenied("Cannot register tasks to a compute plan you don't own")
}

parentTasks, err := s.getRegisteredTasks(getParentTaskKeys(input.Inputs)...)
parentTasks, err := s.getRegisteredTasks(GetParentTaskKeys(input.Inputs)...)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -420,7 +420,6 @@ func (s *ComputeTaskService) createTask(input *asset.NewComputeTask, owner strin
Metadata: input.Metadata,
Status: status,
Rank: getRank(parentTasks),
ParentTaskKeys: getParentTaskKeys(input.Inputs),
CreationDate: timestamppb.New(s.GetTimeService().GetTransactionTime()),
Inputs: input.Inputs,
Outputs: outputs,
Expand Down Expand Up @@ -712,7 +711,7 @@ func (s *ComputeTaskService) getExistingParentKeys(tasks []*asset.NewComputeTask
parents := []string{}

for _, task := range tasks {
parents = append(parents, getParentTaskKeys(task.Inputs)...)
parents = append(parents, GetParentTaskKeys(task.Inputs)...)
}

return s.GetComputeTaskDBAL().GetExistingComputeTaskKeys(parents)
Expand Down Expand Up @@ -910,8 +909,8 @@ func (s *ComputeTaskService) getTaskOutputCounter(taskKey string) (persistence.C
return s.GetComputeTaskDBAL().CountComputeTaskRegisteredOutputs(taskKey)
}

// getParentTaskKeys returns the parent task keys based on task inputs
func getParentTaskKeys(inputs []*asset.ComputeTaskInput) []string {
// GetParentTaskKeys returns the parent task keys based on task inputs
func GetParentTaskKeys(inputs []*asset.ComputeTaskInput) []string {
seen := make(map[string]struct{})
parentKeys := []string{}
for _, input := range inputs {
Expand Down
4 changes: 1 addition & 3 deletions lib/service/computetask_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,6 @@ func TestRegisterTrainTask(t *testing.T) {
ComputePlanKey: newTrainTask.ComputePlanKey,
Metadata: newTrainTask.Metadata,
Status: asset.ComputeTaskStatus_STATUS_TODO,
ParentTaskKeys: []string{},
Worker: dataManager.Owner,
Inputs: newTrainTask.Inputs,
CreationDate: timestamppb.New(time.Unix(1337, 0)),
Expand Down Expand Up @@ -440,7 +439,6 @@ func TestRegisterCompositeTaskWithCompositeParents(t *testing.T) {
ComputePlanKey: newTask.ComputePlanKey,
Metadata: newTask.Metadata,
Status: asset.ComputeTaskStatus_STATUS_WAITING,
ParentTaskKeys: []string{parent1.Key, parent2.Key},
Worker: dataManager.Owner,
Rank: 1,
CreationDate: timestamppb.New(time.Unix(1337, 0)),
Expand Down Expand Up @@ -1654,7 +1652,7 @@ func TestGetParentTaskKeys(t *testing.T) {
t.Run(
fmt.Sprintf("parent task keys from inputs case %d", i),
func(t *testing.T) {
assert.Equal(t, c.keys, getParentTaskKeys(c.inputs))
assert.Equal(t, c.keys, GetParentTaskKeys(c.inputs))
},
)
}
Expand Down
17 changes: 6 additions & 11 deletions lib/service/computetaskstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,17 +183,12 @@ func (s *ComputeTaskService) propagateDone(triggeringParent, child *asset.Comput
return nil
}

// loop over parent, only change status if all parents are DONE
for _, parentKey := range child.ParentTaskKeys {
if parentKey == triggeringParent.Key {
// We already know this one is DONE
continue
}
parent, err := s.GetComputeTaskDBAL().GetComputeTask(parentKey)
if err != nil {
return err
}
parents, err := s.GetComputeTaskDBAL().GetComputeTaskParents(child.Key)
if err != nil {
return err
}

for _, parent := range parents {
if parent.Status != asset.ComputeTaskStatus_STATUS_DONE {
logger.Debug().
Str("parent", parent.Key).
Expand All @@ -204,7 +199,7 @@ func (s *ComputeTaskService) propagateDone(triggeringParent, child *asset.Comput
return nil
}
}
err := s.applyTaskAction(child, transitionTodo, fmt.Sprintf("Last parent task %s done", triggeringParent.Key))
err = s.applyTaskAction(child, transitionTodo, fmt.Sprintf("Last parent task %s done", triggeringParent.Key))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my info: why remove the ":" here ?

if err != nil {
return err
}
Expand Down
3 changes: 3 additions & 0 deletions lib/service/computetaskstate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,9 @@ func TestCascadeStatusDone(t *testing.T) {
Worker: "worker",
}
// Check for children to be updated
dbal.On("GetComputeTaskParents", "child").Return([]*asset.ComputeTask{
{Key: "uuid", Status: asset.ComputeTaskStatus_STATUS_DONE},
}, nil)
dbal.On("GetComputeTaskChildren", "uuid").Return([]*asset.ComputeTask{
{Key: "child", Status: asset.ComputeTaskStatus_STATUS_WAITING},
}, nil)
Expand Down
66 changes: 53 additions & 13 deletions server/standalone/dbal/computetask.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/substra/orchestrator/lib/common"
orcerrors "github.com/substra/orchestrator/lib/errors"
"github.com/substra/orchestrator/lib/persistence"
"github.com/substra/orchestrator/lib/service"
"github.com/substra/orchestrator/utils"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/timestamppb"
Expand All @@ -26,7 +27,6 @@ type sqlComputeTask struct {
AlgoKey string
Owner string
ComputePlanKey string
ParentTaskKeys []string
Rank int32
Status asset.ComputeTaskStatus
Worker string
Expand All @@ -43,7 +43,6 @@ func (t *sqlComputeTask) toComputeTask() (*asset.ComputeTask, error) {
task.AlgoKey = t.AlgoKey
task.Owner = t.Owner
task.ComputePlanKey = t.ComputePlanKey
task.ParentTaskKeys = t.ParentTaskKeys
task.Rank = t.Rank
task.Status = t.Status
task.Worker = t.Worker
Expand Down Expand Up @@ -131,13 +130,14 @@ func getCopyableComputeTaskValues(channel string, task *asset.ComputeTask) ([]in
func (d *DBAL) insertParentTasks(tasks ...*asset.ComputeTask) error {
parentRows := make([][]interface{}, 0)
for _, t := range tasks {
if t.ParentTaskKeys != nil {
parentTasks := service.GetParentTaskKeys(t.Inputs)
if parentTasks != nil {
childTask, err := uuid.Parse(t.GetKey())
if err != nil {
return err
}

for idx, parentTaskKey := range t.ParentTaskKeys {
for idx, parentTaskKey := range parentTasks {
parentTask, err := uuid.Parse(parentTaskKey)
if err != nil {
return err
Expand Down Expand Up @@ -207,8 +207,8 @@ func (d *DBAL) GetExistingComputeTaskKeys(keys []string) ([]string, error) {
func (d *DBAL) GetComputeTask(key string) (*asset.ComputeTask, error) {
stmt := getStatementBuilder().
Select("key", "compute_plan_key", "status", "category", "worker", "owner", "rank", "creation_date",
"logs_permission", "metadata", "algo_key", "parent_task_keys").
From("expanded_compute_tasks").
"logs_permission", "metadata", "algo_key").
From("compute_tasks").
Where(sq.Eq{"channel": d.channel, "key": key})

row, err := d.queryRow(stmt)
Expand All @@ -218,7 +218,7 @@ func (d *DBAL) GetComputeTask(key string) (*asset.ComputeTask, error) {

ct := new(sqlComputeTask)
err = row.Scan(&ct.Key, &ct.ComputePlanKey, &ct.Status, &ct.Category, &ct.Worker, &ct.Owner, &ct.Rank, &ct.CreationDate,
&ct.LogsPermission, &ct.Metadata, &ct.AlgoKey, &ct.ParentTaskKeys)
&ct.LogsPermission, &ct.Metadata, &ct.AlgoKey)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, orcerrors.NewNotFound("computetask", key)
Expand All @@ -244,8 +244,8 @@ func (d *DBAL) GetComputeTask(key string) (*asset.ComputeTask, error) {
func (d *DBAL) GetComputeTaskChildren(key string) ([]*asset.ComputeTask, error) {
stmt := getStatementBuilder().
Select("key", "compute_plan_key", "status", "category", "worker", "owner", "rank", "creation_date",
"logs_permission", "metadata", "algo_key", "parent_task_keys").
From("expanded_compute_tasks t").
"logs_permission", "metadata", "algo_key").
From("compute_tasks t").
Join("compute_task_parents p ON t.key = p.child_task_key").
Where(sq.Eq{"t.channel": d.channel, "p.parent_task_key": key}).
OrderByClause("p.position ASC")
Expand All @@ -262,7 +262,47 @@ func (d *DBAL) GetComputeTaskChildren(key string) ([]*asset.ComputeTask, error)

err = rows.Scan(
&ct.Key, &ct.ComputePlanKey, &ct.Status, &ct.Category, &ct.Worker, &ct.Owner, &ct.Rank, &ct.CreationDate,
&ct.LogsPermission, &ct.Metadata, &ct.AlgoKey, &ct.ParentTaskKeys)
&ct.LogsPermission, &ct.Metadata, &ct.AlgoKey)
if err != nil {
return nil, err
}

task, err := ct.toComputeTask()
if err != nil {
return nil, err
}

tasks = append(tasks, task)
}
if err := rows.Err(); err != nil {
return nil, err
}

return tasks, nil
}

func (d *DBAL) GetComputeTaskParents(key string) ([]*asset.ComputeTask, error) {
stmt := getStatementBuilder().
Select("key", "compute_plan_key", "status", "category", "worker", "owner", "rank", "creation_date",
"logs_permission", "metadata", "algo_key").
From("compute_tasks t").
Join("compute_task_parents p ON t.key = p.parent_task_key").
Where(sq.Eq{"t.channel": d.channel, "p.child_task_key": key}).
OrderByClause("p.position ASC")

rows, err := d.query(stmt)
if err != nil {
return nil, err
}
defer rows.Close()

tasks := []*asset.ComputeTask{}
for rows.Next() {
ct := new(sqlComputeTask)

err = rows.Scan(
&ct.Key, &ct.ComputePlanKey, &ct.Status, &ct.Category, &ct.Worker, &ct.Owner, &ct.Rank, &ct.CreationDate,
&ct.LogsPermission, &ct.Metadata, &ct.AlgoKey)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -345,8 +385,8 @@ func (d *DBAL) CountComputeTaskRegisteredOutputs(key string) (persistence.Comput
func (d *DBAL) queryBaseComputeTasks(pagination *common.Pagination, filterer func(sq.SelectBuilder) sq.SelectBuilder) ([]*asset.ComputeTask, common.PaginationToken, error) {
stmt := getStatementBuilder().
Select("key", "compute_plan_key", "status", "category", "worker", "owner", "rank", "creation_date",
"logs_permission", "metadata", "algo_key", "parent_task_keys").
From("expanded_compute_tasks").
"logs_permission", "metadata", "algo_key").
From("compute_tasks").
Where(sq.Eq{"channel": d.channel}).
OrderByClause("creation_date ASC, key")

Expand Down Expand Up @@ -386,7 +426,7 @@ func (d *DBAL) queryBaseComputeTasks(pagination *common.Pagination, filterer fun

err = rows.Scan(
&ct.Key, &ct.ComputePlanKey, &ct.Status, &ct.Category, &ct.Worker, &ct.Owner, &ct.Rank, &ct.CreationDate,
&ct.LogsPermission, &ct.Metadata, &ct.AlgoKey, &ct.ParentTaskKeys)
&ct.LogsPermission, &ct.Metadata, &ct.AlgoKey)
if err != nil {
return nil, "", err
}
Expand Down
Loading