Skip to content

Commit

Permalink
feat(sampler): add --ratios
Browse files Browse the repository at this point in the history
  • Loading branch information
breakthewall committed Jul 26, 2024
1 parent 5844464 commit c518895
Showing 1 changed file with 23 additions and 7 deletions.
30 changes: 23 additions & 7 deletions icfree/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
from pyDOE2 import lhs
import ast

def generate_lhs_samples(input_file, num_samples, step, fixed_values=None, seed=None):
def generate_lhs_samples(input_file, num_samples, step=None, ratios=None, fixed_values=None, seed=None):
"""
Generates Latin Hypercube Samples for components based on discrete ranges.
Parameters:
- input_file: Path to the input file containing components and their max values.
- num_samples: Number of samples to generate.
- step: Step size for creating discrete ranges.
- ratios: List of ratios for creating discrete ranges.
- fixed_values: Dictionary of components with fixed values (optional).
- seed: Random seed for reproducibility.
Expand All @@ -33,11 +34,17 @@ def generate_lhs_samples(input_file, num_samples, step, fixed_values=None, seed=
# Generate discrete ranges for each component
for index, row in components_df.iterrows():
component_name = row['Component']
max_value = row['maxValue']
if fixed_values and component_name in fixed_values:
# If the component has a fixed value, use a single-element array
component_range = np.array([fixed_values[component_name]])
else:
component_range = np.arange(0, row['maxValue'] + step, step)
if ratios is not None:
# Use ratios to create the discrete range
component_range = np.array([r * max_value for r in ratios])
else:
# Use step to create the discrete range
component_range = np.arange(0, max_value + step, step)
discrete_ranges.append(component_range)

# Determine the number of components
Expand All @@ -55,15 +62,16 @@ def generate_lhs_samples(input_file, num_samples, step, fixed_values=None, seed=
samples_df = pd.DataFrame(samples, columns=components_df['Component'])
return samples_df

def main(input_file, output_file, num_samples, step=2.5, fixed_values=None, seed=None):
def main(input_file, output_file, num_samples, step=None, ratios=None, fixed_values=None, seed=None):
"""
Main function to generate LHS samples and save them to a CSV file.
Parameters:
- input_file: Path to the input file containing components and their max values.
- output_file: Path to the output CSV file where samples will be written.
- num_samples: Number of samples to generate.
- step: Step size for creating discrete ranges (default: 2.5).
- step: Step size for creating discrete ranges (optional).
- ratios: List of ratios for creating discrete ranges (optional).
- fixed_values: Dictionary of components with fixed values (optional).
- seed: Random seed for reproducibility (optional).
"""
Expand All @@ -80,7 +88,7 @@ def main(input_file, output_file, num_samples, step=2.5, fixed_values=None, seed
print(f"Warning: Component '{component}' not found in the input file.")

# Generate LHS samples
samples_df = generate_lhs_samples(input_file, num_samples, step, fixed_values, seed)
samples_df = generate_lhs_samples(input_file, num_samples, step, ratios, fixed_values, seed)

# Write the samples to a CSV file
samples_df.to_csv(output_file, index=False)
Expand All @@ -92,7 +100,12 @@ def main(input_file, output_file, num_samples, step=2.5, fixed_values=None, seed
parser.add_argument('input_file', type=str, help='Input file path with components and their max values.')
parser.add_argument('output_file', type=str, help='Output CSV file path for the samples.')
parser.add_argument('num_samples', type=int, help='Number of samples to generate.')
parser.add_argument('--step', type=float, default=2.5, help='Step size for creating discrete ranges (default: 2.5).')

# Create a mutually exclusive group for step and ratios
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument('--step', type=float, help='Step size for creating discrete ranges.')
group.add_argument('--ratios', type=str, help='Comma-separated list of ratios for creating discrete ranges (e.g., "0,0.2,0.4,0.6,0.8,1").')

parser.add_argument('--fixed_values', type=str, default=None, help='Fixed values for components as a dictionary (e.g., \'{"Component1": 10, "Component2": 20}\')')
parser.add_argument('--seed', type=int, default=None, help='Seed for random number generation for reproducibility (optional).')

Expand All @@ -102,5 +115,8 @@ def main(input_file, output_file, num_samples, step=2.5, fixed_values=None, seed
# Convert fixed_values argument from string to dictionary if provided
fixed_values = ast.literal_eval(args.fixed_values) if args.fixed_values else None

# Convert ratios argument from comma-separated string to list of floats if provided
ratios = [float(r) for r in args.ratios.split(',')] if args.ratios else None

# Run the main function with the parsed arguments
main(args.input_file, args.output_file, args.num_samples, args.step, fixed_values, args.seed)
main(args.input_file, args.output_file, args.num_samples, args.step, ratios, fixed_values, args.seed)

0 comments on commit c518895

Please sign in to comment.