From c518895d8974662309c1b2334b19f3413732594f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joan=20H=C3=A9risson?= Date: Fri, 26 Jul 2024 16:58:24 +0200 Subject: [PATCH] feat(sampler): add --ratios --- icfree/sampler.py | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/icfree/sampler.py b/icfree/sampler.py index 8b685ed..bb42d02 100644 --- a/icfree/sampler.py +++ b/icfree/sampler.py @@ -5,7 +5,7 @@ from pyDOE2 import lhs import ast -def generate_lhs_samples(input_file, num_samples, step, fixed_values=None, seed=None): +def generate_lhs_samples(input_file, num_samples, step=None, ratios=None, fixed_values=None, seed=None): """ Generates Latin Hypercube Samples for components based on discrete ranges. @@ -13,6 +13,7 @@ def generate_lhs_samples(input_file, num_samples, step, fixed_values=None, seed= - input_file: Path to the input file containing components and their max values. - num_samples: Number of samples to generate. - step: Step size for creating discrete ranges. + - ratios: List of ratios for creating discrete ranges. - fixed_values: Dictionary of components with fixed values (optional). - seed: Random seed for reproducibility. @@ -33,11 +34,17 @@ def generate_lhs_samples(input_file, num_samples, step, fixed_values=None, seed= # Generate discrete ranges for each component for index, row in components_df.iterrows(): component_name = row['Component'] + max_value = row['maxValue'] if fixed_values and component_name in fixed_values: # If the component has a fixed value, use a single-element array component_range = np.array([fixed_values[component_name]]) else: - component_range = np.arange(0, row['maxValue'] + step, step) + if ratios is not None: + # Use ratios to create the discrete range + component_range = np.array([r * max_value for r in ratios]) + else: + # Use step to create the discrete range + component_range = np.arange(0, max_value + step, step) discrete_ranges.append(component_range) # Determine the number of components @@ -55,7 +62,7 @@ def generate_lhs_samples(input_file, num_samples, step, fixed_values=None, seed= samples_df = pd.DataFrame(samples, columns=components_df['Component']) return samples_df -def main(input_file, output_file, num_samples, step=2.5, fixed_values=None, seed=None): +def main(input_file, output_file, num_samples, step=None, ratios=None, fixed_values=None, seed=None): """ Main function to generate LHS samples and save them to a CSV file. @@ -63,7 +70,8 @@ def main(input_file, output_file, num_samples, step=2.5, fixed_values=None, seed - input_file: Path to the input file containing components and their max values. - output_file: Path to the output CSV file where samples will be written. - num_samples: Number of samples to generate. - - step: Step size for creating discrete ranges (default: 2.5). + - step: Step size for creating discrete ranges (optional). + - ratios: List of ratios for creating discrete ranges (optional). - fixed_values: Dictionary of components with fixed values (optional). - seed: Random seed for reproducibility (optional). """ @@ -80,7 +88,7 @@ def main(input_file, output_file, num_samples, step=2.5, fixed_values=None, seed print(f"Warning: Component '{component}' not found in the input file.") # Generate LHS samples - samples_df = generate_lhs_samples(input_file, num_samples, step, fixed_values, seed) + samples_df = generate_lhs_samples(input_file, num_samples, step, ratios, fixed_values, seed) # Write the samples to a CSV file samples_df.to_csv(output_file, index=False) @@ -92,7 +100,12 @@ def main(input_file, output_file, num_samples, step=2.5, fixed_values=None, seed parser.add_argument('input_file', type=str, help='Input file path with components and their max values.') parser.add_argument('output_file', type=str, help='Output CSV file path for the samples.') parser.add_argument('num_samples', type=int, help='Number of samples to generate.') - parser.add_argument('--step', type=float, default=2.5, help='Step size for creating discrete ranges (default: 2.5).') + + # Create a mutually exclusive group for step and ratios + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument('--step', type=float, help='Step size for creating discrete ranges.') + group.add_argument('--ratios', type=str, help='Comma-separated list of ratios for creating discrete ranges (e.g., "0,0.2,0.4,0.6,0.8,1").') + parser.add_argument('--fixed_values', type=str, default=None, help='Fixed values for components as a dictionary (e.g., \'{"Component1": 10, "Component2": 20}\')') parser.add_argument('--seed', type=int, default=None, help='Seed for random number generation for reproducibility (optional).') @@ -102,5 +115,8 @@ def main(input_file, output_file, num_samples, step=2.5, fixed_values=None, seed # Convert fixed_values argument from string to dictionary if provided fixed_values = ast.literal_eval(args.fixed_values) if args.fixed_values else None + # Convert ratios argument from comma-separated string to list of floats if provided + ratios = [float(r) for r in args.ratios.split(',')] if args.ratios else None + # Run the main function with the parsed arguments - main(args.input_file, args.output_file, args.num_samples, args.step, fixed_values, args.seed) + main(args.input_file, args.output_file, args.num_samples, args.step, ratios, fixed_values, args.seed)