Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed type_for_source logic #152

Merged
merged 7 commits into from
Aug 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 44 additions & 34 deletions cwl_utils/cwl_expression_refactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
Union,
)

from ruamel import yaml
from ruamel.yaml.main import YAML
from ruamel.yaml.scalarstring import walk_tree

from cwl_utils.errors import WorkflowException
from cwl_utils.loghandler import _logger as _cwlutilslogger

if TYPE_CHECKING:
Expand Down Expand Up @@ -103,10 +105,13 @@ def main(args: Optional[List[str]] = None) -> int:

def run(args: argparse.Namespace) -> int:
"""Primary processing loop."""
return_code = 0
yaml = YAML(typ="rt")
yaml.preserve_quotes = True # type: ignore[assignment]
for document in args.inputs:
_logger.info("Processing %s.", document)
with open(document) as doc_handle:
result = yaml.main.round_trip_load(doc_handle, preserve_quotes=True)
result = yaml.load(doc_handle)
version = result["cwlVersion"]
uri = Path(document).resolve().as_uri()
if version == "v1.0":
Expand All @@ -128,39 +133,44 @@ def run(args: argparse.Namespace) -> int:
"Sorry, %s is not a supported CWL version by this tool.", version
)
return -1
result, modified = traverse(
top, not args.etools, False, args.skip_some1, args.skip_some2
)
output = Path(args.dir) / Path(document).name
if not modified:
if len(args.inputs) > 1:
shutil.copyfile(document, output)
continue
else:
return 7
if not isinstance(result, MutableSequence):
result_json = save(
result,
base_url=result.loadingOptions.fileuri
if result.loadingOptions.fileuri
else "",
try:
result, modified = traverse(
top, not args.etools, False, args.skip_some1, args.skip_some2
)
# ^^ Setting the base_url and keeping the default value
# for relative_uris=True means that the IDs in the generated
# JSON/YAML are kept clean of the path to the input document
else:
result_json = [
save(result_item, base_url=result_item.loadingOptions.fileuri)
for result_item in result
]
yaml.scalarstring.walk_tree(result_json)
# ^ converts multiline strings to nice multiline YAML
with open(output, "w", encoding="utf-8") as output_filehandle:
output_filehandle.write(
"#!/usr/bin/env cwl-runner\n"
) # TODO: teach the codegen to do this?
yaml.main.round_trip_dump(result_json, output_filehandle)
return 0
output = Path(args.dir) / Path(document).name
if not modified:
if len(args.inputs) > 1:
shutil.copyfile(document, output)
continue
else:
return 7
if not isinstance(result, MutableSequence):
result_json = save(
result,
base_url=result.loadingOptions.fileuri
if result.loadingOptions.fileuri
else "",
)
# ^^ Setting the base_url and keeping the default value
# for relative_uris=True means that the IDs in the generated
# JSON/YAML are kept clean of the path to the input document
else:
result_json = [
save(result_item, base_url=result_item.loadingOptions.fileuri)
for result_item in result
]
walk_tree(result_json)
# ^ converts multiline strings to nice multiline YAML
with open(output, "w", encoding="utf-8") as output_filehandle:
output_filehandle.write(
"#!/usr/bin/env cwl-runner\n"
) # TODO: teach the codegen to do this?
yaml.dump(result_json, output_filehandle)
except WorkflowException as exc:
return_code = 1
_logger.exception("Skipping %s due to error.", document, exc_info=exc)

return return_code


if __name__ == "__main__":
Expand Down
90 changes: 17 additions & 73 deletions cwl_utils/cwl_v1_0_expression_refactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from schema_salad.utils import json_dumps

import cwl_utils.parser.cwl_v1_0 as cwl
import cwl_utils.parser.cwl_v1_0_utils as utils
from cwl_utils.errors import JavascriptException, WorkflowException
from cwl_utils.expression import do_eval, interpolate
from cwl_utils.types import CWLObjectType, CWLOutputType
Expand Down Expand Up @@ -537,7 +538,7 @@ def empty_inputs(
else:
try:
result[param_id] = example_input(
type_for_source(process_or_step.run, param.source, parent)
utils.type_for_source(process_or_step.run, param.source, parent)
)
except WorkflowException:
pass
Expand Down Expand Up @@ -582,71 +583,6 @@ def example_input(some_type: Any) -> Any:
return None


def type_for_source(
process: Union[cwl.CommandLineTool, cwl.Workflow, cwl.ExpressionTool],
sourcenames: Union[str, List[str]],
parent: Optional[cwl.Workflow] = None,
) -> Union[List[Any], Any]:
"""Determine the type for the given sourcenames."""
params = param_for_source_id(process, sourcenames, parent)
if not isinstance(params, list):
return params.type
new_type: List[Any] = []
for p in params:
if isinstance(p, str) and p not in new_type:
new_type.append(p)
elif hasattr(p, "type") and p.type not in new_type:
new_type.append(p.type)
return new_type


def param_for_source_id(
process: Union[cwl.CommandLineTool, cwl.Workflow, cwl.ExpressionTool],
sourcenames: Union[str, List[str]],
parent: Optional[cwl.Workflow] = None,
) -> Union[List[cwl.InputParameter], cwl.InputParameter]:
"""Find the process input parameter that matches one of the given sourcenames."""
if isinstance(sourcenames, str):
sourcenames = [sourcenames]
params: List[cwl.InputParameter] = []
for sourcename in sourcenames:
if not isinstance(process, cwl.Workflow):
for param in process.inputs:
if param.id.split("#")[-1] == sourcename.split("#")[-1]:
params.append(param)
targets = [process]
if parent:
targets.append(parent)
for target in targets:
if isinstance(target, cwl.Workflow):
for inp in target.inputs:
if inp.id.split("#")[-1] == sourcename.split("#")[-1]:
params.append(inp)
for step in target.steps:
if sourcename.split("/")[0] == step.id.split("#")[-1] and step.out:
for outp in step.out:
outp_id = outp if isinstance(outp, str) else outp.id
if outp_id.split("/")[-1] == sourcename.split("/", 1)[1]:
if step.run and step.run.outputs:
for output in step.run.outputs:
if (
output.id.split("#")[-1]
== sourcename.split("/", 1)[1]
):
params.append(output)
if len(params) == 1:
return params[0]
elif len(params) > 1:
return params
raise WorkflowException(
"param {} not found in {}\n or\n {}.".format(
sourcename,
yaml.main.round_trip_dump(cwl.save(process)),
yaml.main.round_trip_dump(cwl.save(parent)),
)
)


EMPTY_FILE: CWLOutputType = {
"class": "File",
"basename": "em.pty",
Expand Down Expand Up @@ -1841,11 +1777,13 @@ def traverse_step(
if not step.scatter:
self.append(
example_input(
type_for_source(parent, source.split("#")[-1])
utils.type_for_source(parent, source.split("#")[-1])
)
)
else:
scattered_source_type = type_for_source(parent, source)
scattered_source_type = utils.type_for_source(
parent, source
)
if isinstance(scattered_source_type, list):
for stype in scattered_source_type:
self.append(example_input(stype.type))
Expand All @@ -1854,10 +1792,12 @@ def traverse_step(
else:
if not step.scatter:
self = example_input(
type_for_source(parent, inp.source.split("#")[-1])
utils.type_for_source(parent, inp.source.split("#")[-1])
)
else:
scattered_source_type2 = type_for_source(parent, inp.source)
scattered_source_type2 = utils.type_for_source(
parent, inp.source
)
if isinstance(scattered_source_type2, list):
self = example_input(scattered_source_type2[0].type)
else:
Expand All @@ -1880,7 +1820,9 @@ def traverse_step(
for source in inp.source:
source_id = source.split("#")[-1]
input_source_id.append(source_id)
temp_type = type_for_source(step.run, source_id, parent)
temp_type = utils.type_for_source(
step.run, source_id, parent
)
if isinstance(temp_type, list):
for ttype in temp_type:
if ttype not in source_types:
Expand All @@ -1894,7 +1836,7 @@ def traverse_step(
)
else:
input_source_id = inp.source.split("#")[-1]
source_type = param_for_source_id(
source_type = utils.param_for_source_id(
step.run, input_source_id, parent
)
# target.id = target.id.split('#')[-1]
Expand Down Expand Up @@ -1965,7 +1907,9 @@ def workflow_step_to_InputParameters(
continue
inp_id = inp.id.split("#")[-1].split("/")[-1]
if inp.source and inp_id != except_in_id:
param = copy.deepcopy(param_for_source_id(parent, sourcenames=inp.source))
param = copy.deepcopy(
utils.param_for_source_id(parent, sourcenames=inp.source)
)
if isinstance(param, list):
for p in param:
if not p.type:
Expand Down
Loading