Skip to content

Commit

Permalink
Merge pull request #1073 from synthetichealth/rng
Browse files Browse the repository at this point in the history
Fix random divergence between runs with the same seeds.
  • Loading branch information
eedrummer committed May 27, 2022
2 parents 58225e0 + aa49445 commit d90b4e3
Show file tree
Hide file tree
Showing 31 changed files with 295 additions and 251 deletions.
101 changes: 32 additions & 69 deletions src/main/java/org/mitre/synthea/engine/Generator.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.UUID;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
Expand All @@ -36,6 +35,7 @@
import org.mitre.synthea.export.CDWExporter;
import org.mitre.synthea.export.Exporter;
import org.mitre.synthea.helpers.Config;
import org.mitre.synthea.helpers.DefaultRandomNumberGenerator;
import org.mitre.synthea.helpers.RandomNumberGenerator;
import org.mitre.synthea.helpers.TransitionMetrics;
import org.mitre.synthea.helpers.Utilities;
Expand All @@ -56,15 +56,16 @@
/**
* Generator creates a population by running the generic modules each timestep per Person.
*/
public class Generator implements RandomNumberGenerator {
public class Generator {

/**
* Unique ID for this instance of the Generator.
* Even if the same settings are used multiple times, this ID should be unique.
*/
public final UUID id = UUID.randomUUID();
public GeneratorOptions options;
private Random random;
private DefaultRandomNumberGenerator populationRandom;
private DefaultRandomNumberGenerator clinicianRandom;
public long timestep;
public long stop;
public long referenceTime;
Expand Down Expand Up @@ -218,7 +219,8 @@ private void init() {
CDWExporter.getInstance().setKeyStart((stateIndex * 1_000_000) + 1);
}

this.random = new Random(options.seed);
this.populationRandom = new DefaultRandomNumberGenerator(options.seed);
this.clinicianRandom = new DefaultRandomNumberGenerator(options.clinicianSeed);
this.timestep = Long.parseLong(Config.get("generate.timestep"));
this.stop = options.endTime;
this.referenceTime = options.referenceTime;
Expand Down Expand Up @@ -262,7 +264,7 @@ private void init() {
}

// initialize hospitals
Provider.loadProviders(location, options.clinicianSeed);
Provider.loadProviders(location, this.clinicianRandom);
// Initialize Payers
Payer.loadPayers(location);
// ensure modules load early
Expand Down Expand Up @@ -366,7 +368,7 @@ public void run() {
// Generate patients up to the specified population size.
for (int i = 0; i < this.options.population; i++) {
final int index = i;
final long seed = this.random.nextLong();
final long seed = this.populationRandom.randLong();
threadPool.submit(() -> generatePerson(index, seed));
}
}
Expand Down Expand Up @@ -398,6 +400,8 @@ public void run() {

System.out.printf("Records: total=%d, alive=%d, dead=%d\n", totalGeneratedPopulation.get(),
stats.get("alive").get(), stats.get("dead").get());
System.out.printf("RNG=%d\n", this.populationRandom.getCount());
System.out.printf("Clinician RNG=%d\n", this.clinicianRandom.getCount());

if (this.metrics != null) {
metrics.printStats(totalGeneratedPopulation.get(), Module.getModules(getModulePredicate()));
Expand Down Expand Up @@ -440,6 +444,7 @@ public List<FixedRecordGroup> importFixedPatientDemographicsFile() {
* @param index Target index in the whole set of people to generate
* @return generated Person
*/
@Deprecated
public Person generatePerson(int index) {
// System.currentTimeMillis is not unique enough
long personSeed = UUID.randomUUID().getMostSignificantBits() & Long.MAX_VALUE;
Expand All @@ -461,16 +466,15 @@ public Person generatePerson(int index) {
*/
public Person generatePerson(int index, long personSeed) {

Person person = null;
Person person = new Person(personSeed);

try {
int tryNumber = 0; // Number of tries to create these demographics
Random randomForDemographics = new Random(personSeed);

Map<String, Object> demoAttributes = randomDemographics(randomForDemographics);
Map<String, Object> demoAttributes = randomDemographics(person);
if (this.recordGroups != null) {
// Pick fixed demographics if a fixed demographics record file is used.
demoAttributes = pickFixedDemographics(index, random);
demoAttributes = pickFixedDemographics(index, person);
}

boolean patientMeetsCriteria;
Expand Down Expand Up @@ -509,7 +513,7 @@ public Person generatePerson(int index, long personSeed) {
// when we want to export this patient, but keep trying to produce one meeting criteria
if (!check.exportAnyway()) {
// rotate the seed so the next attempt gets a consistent but different one
personSeed = randomForDemographics.nextLong();
personSeed = person.randLong();
continue;
// skip the other stuff if the patient doesn't meet our goals
// note that this skips ahead to the while check
Expand All @@ -521,19 +525,19 @@ public Person generatePerson(int index, long personSeed) {

if (!isAlive) {
// rotate the seed so the next attempt gets a consistent but different one
personSeed = randomForDemographics.nextLong();
personSeed = person.randLong();

// if we've tried and failed > 10 times to generate someone over age 90
// and the options allow for ages as low as 85
// reduce the age to increase the likelihood of success
if (tryNumber > 10 && (int)person.attributes.get(TARGET_AGE) > 90
&& (!options.ageSpecified || options.minAge <= 85)) {
// pick a new target age between 85 and 90
int newTargetAge = randomForDemographics.nextInt(5) + 85;
int newTargetAge = person.randInt(5) + 85;
// the final age bracket is 85-110, but our patients rarely break 100
// so reducing a target age to 85-90 shouldn't affect numbers too much
demoAttributes.put(TARGET_AGE, newTargetAge);
long birthdate = birthdateFromTargetAge(newTargetAge, randomForDemographics);
long birthdate = birthdateFromTargetAge(newTargetAge, person);
demoAttributes.put(Person.BIRTHDATE, birthdate);
}
}
Expand Down Expand Up @@ -705,7 +709,7 @@ public void updatePerson(Person person) {
* @param random The random number generator to use.
* @return demographics
*/
public Map<String, Object> randomDemographics(Random random) {
public Map<String, Object> randomDemographics(RandomNumberGenerator random) {
Demographics city = location.randomCity(random);
Map<String, Object> demoAttributes = pickDemographics(random, city);
return demoAttributes;
Expand All @@ -722,11 +726,12 @@ private synchronized void writeToConsole(Person person, int index, long time, bo
// this is synchronized to ensure all lines for a single person are always printed
// consecutively
String deceased = isAlive ? "" : "DECEASED";
System.out.format("%d -- %s (%d y/o %s) %s, %s %s\n", index + 1,
System.out.format("%d -- %s (%d y/o %s) %s, %s %s (%d)\n", index + 1,
person.attributes.get(Person.NAME), person.ageInYears(time),
person.attributes.get(Person.GENDER),
person.attributes.get(Person.CITY), person.attributes.get(Person.STATE),
deceased);
deceased,
person.getCount());

if (this.logLevel.equals("detailed")) {
System.out.println("ATTRIBUTES");
Expand All @@ -750,7 +755,7 @@ private synchronized void writeToConsole(Person person, int index, long time, bo
* @param city The city to base the demographics off of.
* @return the person's picked demographics.
*/
private Map<String, Object> pickDemographics(Random random, Demographics city) {
private Map<String, Object> pickDemographics(RandomNumberGenerator random, Demographics city) {
// Output map of the generated demographc data.
Map<String, Object> demographicsOutput = new HashMap<>();

Expand Down Expand Up @@ -794,7 +799,7 @@ private Map<String, Object> pickDemographics(Random random, Demographics city) {
double povertyRatio = city.povertyRatio(income);
demographicsOutput.put(Person.POVERTY_RATIO, povertyRatio);

double occupation = random.nextDouble();
double occupation = random.rand();
demographicsOutput.put(Person.OCCUPATION_LEVEL, occupation);

double sesScore = city.socioeconomicScore(incomeLevel, educationLevel, occupation);
Expand All @@ -809,7 +814,7 @@ private Map<String, Object> pickDemographics(Random random, Demographics city) {
int targetAge;
if (options.ageSpecified) {
targetAge =
(int) (options.minAge + ((options.maxAge - options.minAge) * random.nextDouble()));
(int) (options.minAge + ((options.maxAge - options.minAge) * random.rand()));
} else {
targetAge = city.pickAge(random);
}
Expand All @@ -827,7 +832,7 @@ private Map<String, Object> pickDemographics(Random random, Demographics city) {
* @param index The index to use.
* @param random Random object.
*/
private Map<String, Object> pickFixedDemographics(int index, Random random) {
private Map<String, Object> pickFixedDemographics(int index, RandomNumberGenerator random) {

// Get the first FixedRecord from the current RecordGroup
FixedRecordGroup recordGroup = this.recordGroups.get(index);
Expand Down Expand Up @@ -861,11 +866,11 @@ private Map<String, Object> pickFixedDemographics(int index, Random random) {
* @param random A random object.
* @return
*/
private long birthdateFromTargetAge(long targetAge, Random random) {
private long birthdateFromTargetAge(long targetAge, RandomNumberGenerator random) {
long earliestBirthdate = referenceTime - TimeUnit.DAYS.toMillis((targetAge + 1) * 365L + 1);
long latestBirthdate = referenceTime - TimeUnit.DAYS.toMillis(targetAge * 365L);
return
(long) (earliestBirthdate + ((latestBirthdate - earliestBirthdate) * random.nextDouble()));
(long) (earliestBirthdate + ((latestBirthdate - earliestBirthdate) * random.rand()));
}

/**
Expand Down Expand Up @@ -908,52 +913,10 @@ private Predicate<String> getModulePredicate() {
}

/**
* Returns a random double.
*/
public double rand() {
return random.nextDouble();
}

/**
* Returns a random boolean.
*/
public boolean randBoolean() {
return random.nextBoolean();
}

/**
* Returns a random integer.
*/
public int randInt() {
return random.nextInt();
}

/**
* Returns a random integer in the given bound.
* Get the seeded random number generator used by this Generator.
* @return the random number generator.
*/
public int randInt(int bound) {
return random.nextInt(bound);
public RandomNumberGenerator getRandomizer() {
return this.populationRandom;
}

/**
* Returns a double from a normal distribution.
*/
public double randGaussian() {
return random.nextGaussian();
}

/**
* Return a random long.
*/
public long randLong() {
return random.nextLong();
}

/**
* Return a random UUID.
*/
public UUID randUUID() {
return new UUID(randLong(), randLong());
}

}
6 changes: 3 additions & 3 deletions src/main/java/org/mitre/synthea/export/Exporter.java
Original file line number Diff line number Diff line change
Expand Up @@ -434,20 +434,20 @@ public static void runPostCompletionExports(Generator generator, ExporterRuntime

// Before we force bulk data to be off...
try {
FhirGroupExporterR4.exportAndSave(generator, generator.stop);
FhirGroupExporterR4.exportAndSave(generator.getRandomizer(), generator.stop);
} catch (Exception e) {
e.printStackTrace();
}

Config.set("exporter.fhir.bulk_data", "false");
try {
HospitalExporterR4.export(generator, generator.stop);
HospitalExporterR4.export(generator.getRandomizer(), generator.stop);
} catch (Exception e) {
e.printStackTrace();
}

try {
FhirPractitionerExporterR4.export(generator, generator.stop);
FhirPractitionerExporterR4.export(generator.getRandomizer(), generator.stop);
} catch (Exception e) {
e.printStackTrace();
}
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/mitre/synthea/export/FhirDstu2.java
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ private static Entry basicInfo(Person person, Bundle bundle, long stopTime) {

String generatedBySynthea = "Generated by <a href=\"https://github.com/synthetichealth/synthea\">Synthea</a>."
+ "Version identifier: " + Utilities.SYNTHEA_VERSION + " . "
+ " Person seed: " + person.seed
+ " Person seed: " + person.getSeed()
+ " Population seed: " + person.populationSeed;

patientResource.setText(new NarrativeDt(
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/mitre/synthea/export/FhirR4.java
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@ private static BundleEntryComponent basicInfo(Person person, Bundle bundle, long
String generatedBySynthea =
"Generated by <a href=\"https://github.com/synthetichealth/synthea\">Synthea</a>."
+ "Version identifier: " + Utilities.SYNTHEA_VERSION + " . "
+ " Person seed: " + person.seed
+ " Person seed: " + person.getSeed()
+ " Population seed: " + person.populationSeed;

patientResource.setText(new Narrative().setStatus(NarrativeStatus.GENERATED)
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/mitre/synthea/export/FhirStu3.java
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ private static BundleEntryComponent basicInfo(Person person, Bundle bundle, long

String generatedBySynthea = "Generated by <a href=\"https://github.com/synthetichealth/synthea\">Synthea</a>."
+ "Version identifier: " + Utilities.SYNTHEA_VERSION + " . "
+ " Person seed: " + person.seed
+ " Person seed: " + person.getSeed()
+ " Population seed: " + person.populationSeed;

patientResource.setText(new Narrative().setStatus(NarrativeStatus.GENERATED)
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/mitre/synthea/export/JSONExporter.java
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public PersonSerializer(boolean includeModuleHistory) {
@Override
public JsonElement serialize(Person src, Type typeOfSrc, JsonSerializationContext context) {
JsonObject personOut = new JsonObject();
personOut.add("seed", new JsonPrimitive(src.seed));
personOut.add("seed", new JsonPrimitive(src.getSeed()));
personOut.add("lastUpdated", new JsonPrimitive(src.lastUpdated));
personOut.add("coverage", context.serialize(src.coverage));
JsonObject attributes = new JsonObject();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ private Code resolveCode(@Nullable Code code) {
return null;
}
return code.valueSet != null
? RandomCodeGenerator.getCode(code.valueSet, person.seed, code)
? RandomCodeGenerator.getCode(code.valueSet, person.getSeed(), code)
: code;
}

Expand Down
Loading

0 comments on commit d90b4e3

Please sign in to comment.