Skip to content

Commit

Permalink
fix(oscal): single model write operations support (#502)
Browse files Browse the repository at this point in the history
* fix(oscal): remove mutli-model write operations

* fix(oscall): testing for GetOscalModel()

* fix(oscal): remove assessment test file

---------

Co-authored-by: Cole (Mike) Winberry <86802655+mike-winberry@users.noreply.github.com>
Co-authored-by: Andy Mills <61879371+CloudBeard@users.noreply.github.com>
  • Loading branch information
3 people authored Jun 28, 2024
1 parent 0d69a45 commit 3646650
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 10 deletions.
76 changes: 66 additions & 10 deletions src/pkg/common/oscal/complete-schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package oscal
import (
"bytes"
"encoding/json"
"fmt"
"os"
"path/filepath"

Expand Down Expand Up @@ -32,13 +33,18 @@ func NewOscalModel(data []byte) (*oscalTypes_1_1_2.OscalModels, error) {
// supports both json and yaml
func WriteOscalModel(filePath string, model *oscalTypes_1_1_2.OscalModels) error {

// if no path or directory add default filename
if filepath.Ext(filePath) == "" {
filePath = filepath.Join(filePath, "oscal.yaml")
modelType, err := GetOscalModel(model)
if err != nil {
return err
}

if err := files.IsJsonOrYaml(filePath); err != nil {
return err
// if no path or directory add default filename
if filepath.Ext(filePath) == "" {
filePath = filepath.Join(filePath, fmt.Sprintf("%s.yaml", modelType))
} else {
if err := files.IsJsonOrYaml(filePath); err != nil {
return err
}
}

if _, err := os.Stat(filePath); err == nil {
Expand All @@ -51,9 +57,18 @@ func WriteOscalModel(filePath string, model *oscalTypes_1_1_2.OscalModels) error
if err != nil {
return err
}

existingModelType, err := GetOscalModel(existingModel)
if err != nil {
return nil
}

if existingModelType != modelType {
return fmt.Errorf("cannot merge model %s with existing model %s", modelType, existingModelType)
}
// Merge the existing model with the new model
// re-assign to perform common operations below
model, err = MergeOscalModels(existingModel, model)
model, err = MergeOscalModels(existingModel, model, modelType)
if err != nil {
return err
}
Expand All @@ -71,7 +86,7 @@ func WriteOscalModel(filePath string, model *oscalTypes_1_1_2.OscalModels) error
yamlEncoder.Encode(model)
}

err := files.WriteOutput(b.Bytes(), filePath)
err = files.WriteOutput(b.Bytes(), filePath)
if err != nil {
return err
}
Expand All @@ -82,12 +97,12 @@ func WriteOscalModel(filePath string, model *oscalTypes_1_1_2.OscalModels) error

}

func MergeOscalModels(existingModel *oscalTypes_1_1_2.OscalModels, newModel *oscalTypes_1_1_2.OscalModels) (*oscalTypes_1_1_2.OscalModels, error) {
func MergeOscalModels(existingModel *oscalTypes_1_1_2.OscalModels, newModel *oscalTypes_1_1_2.OscalModels, modelType string) (*oscalTypes_1_1_2.OscalModels, error) {
var err error
// Now to check each model type - currently only component definition and assessment-results apply

// Component definition
if existingModel.ComponentDefinition != nil && newModel.ComponentDefinition != nil {
if modelType == "component" {
merged, err := MergeComponentDefinitions(existingModel.ComponentDefinition, newModel.ComponentDefinition)
if err != nil {
return nil, err
Expand All @@ -99,7 +114,7 @@ func MergeOscalModels(existingModel *oscalTypes_1_1_2.OscalModels, newModel *osc
}

// Assessment Results
if existingModel.AssessmentResults != nil && newModel.AssessmentResults != nil {
if modelType == "assessment-results" {
merged, err := MergeAssessmentResults(existingModel.AssessmentResults, newModel.AssessmentResults)
if err != nil {
return existingModel, err
Expand All @@ -112,3 +127,44 @@ func MergeOscalModels(existingModel *oscalTypes_1_1_2.OscalModels, newModel *osc

return existingModel, err
}

func GetOscalModel(model *oscalTypes_1_1_2.OscalModels) (modelType string, err error) {

// Check if one model present and all other nil - is there a better way to do this?
models := make([]string, 0)

if model.Catalog != nil {
models = append(models, "catalog")
}

if model.Profile != nil {
models = append(models, "profile")
}

if model.ComponentDefinition != nil {
models = append(models, "component")
}

if model.SystemSecurityPlan != nil {
models = append(models, "system-security-plan")
}

if model.AssessmentPlan != nil {
models = append(models, "assessment-plan")
}

if model.AssessmentResults != nil {
models = append(models, "assessment-results")
}

if model.PlanOfActionAndMilestones != nil {
models = append(models, "poam")
}

if len(models) > 1 {
return "", fmt.Errorf("%v models identified when only oneOf is permitted", len(models))
} else {
return models[0], nil
}

}
72 changes: 72 additions & 0 deletions src/pkg/common/oscal/complete-schema_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package oscal_test

import (
oscalTypes_1_1_2 "github.com/defenseunicorns/go-oscal/src/types/oscal-1-1-2"
"github.com/defenseunicorns/lula/src/pkg/common/oscal"
"testing"
)

func TestGetOscalModel(t *testing.T) {
t.Parallel()

type TestCase struct {
Model oscalTypes_1_1_2.OscalModels
ModelType string
}

testCases := []TestCase{
{
Model: oscalTypes_1_1_2.OscalModels{
Catalog: &oscalTypes_1_1_2.Catalog{},
},
ModelType: "catalog",
},
{
Model: oscalTypes_1_1_2.OscalModels{
Profile: &oscalTypes_1_1_2.Profile{},
},
ModelType: "profile",
},
{
Model: oscalTypes_1_1_2.OscalModels{
ComponentDefinition: &oscalTypes_1_1_2.ComponentDefinition{},
},
ModelType: "component",
},
{
Model: oscalTypes_1_1_2.OscalModels{
SystemSecurityPlan: &oscalTypes_1_1_2.SystemSecurityPlan{},
},
ModelType: "system-security-plan",
},
{
Model: oscalTypes_1_1_2.OscalModels{
AssessmentPlan: &oscalTypes_1_1_2.AssessmentPlan{},
},
ModelType: "assessment-plan",
},
{
Model: oscalTypes_1_1_2.OscalModels{
AssessmentResults: &oscalTypes_1_1_2.AssessmentResults{},
},
ModelType: "assessment-results",
},
{
Model: oscalTypes_1_1_2.OscalModels{
PlanOfActionAndMilestones: &oscalTypes_1_1_2.PlanOfActionAndMilestones{},
},
ModelType: "poam",
},
}
for _, testCase := range testCases {
actual, err := oscal.GetOscalModel(&testCase.Model)
if err != nil {
t.Fatalf("unexpected error for model %s", testCase.ModelType)
}
expected := testCase.ModelType
if expected != actual {
t.Fatalf("error GetOscalModel: expected: %s | got: %s", expected, actual)
}
}

}

0 comments on commit 3646650

Please sign in to comment.