Failing after multiple loops
Closed this issue · 1 comments
- BTB version: latest
- Python version: latest
- Operating System: Windows
Description
I'm using a tuned CopulaGAN to generate synthetic data. In each loop iteration I set up a new CopulaGAN with the newly proposed hyperparameters, evaluate it, and use the aggregated score as the maximisation goal.
tuner = GCPTuner(Tunable({
'epochs': hp.IntHyperParam(min = 24, max = 300),
'batch_size' : hp.IntHyperParam(min = 1, max = 30, include_min = True, include_max = True),
}))
for _ in range(10):
proposal = tuner.propose(1)
print(proposal)
model = CopulaGAN(batch_size = proposal['batch_size'] * 10,
epochs = proposal['epochs'],
primary_key = primary_key_here,
field_distributions = dictionary_here)
model.fit(real_data)
synth_data = model.sample(10, max_retries = 100)
score = evaluate(synthetic_data = synth_data, real_data = real_data)
if score > best_score:
best_params = proposal
best_score = score
tuner.record(proposal, score)
print('Best score obtained: ', best_score)
print('Best parameters: ', best_params)
Certain bits have been excluded as I believe they're not necessary for this.
I'm running the loop and printing after almost every line; the following output is returned:
{'epochs': 216, 'batch_size': 15} # print hps
<sdv.tabular.copulagan.CopulaGAN object at 0x7f4dd8dc09b0> # print model
<sdv.tabular.copulagan.CopulaGAN object at 0x7f4dd8dc09b0> # print fitted model
{'cstest': 0.9034252238048042, 'kstest': 0.3842909885043684, 'logistic_detection': 0.9978902953586498, 'svc_detection': 0.981364275668073} # print unaggregated metrics
_**REPEAT FOR EACH LOOP**_
{'epochs': 204, 'batch_size': 13}
<sdv.tabular.copulagan.CopulaGAN object at 0x7f4dd0424630>
<sdv.tabular.copulagan.CopulaGAN object at 0x7f4dd0424630>
So the loop runs and then seems to break when evaluating in the second loop. I've tested this before and it works for varying numbers of loops, which made me think the problem was the hyperparameter inputs. My understanding is that the batch_size must be a multiple of ten, which it is, as I'm multiplying by 10 inside the loop rather than when initialising the tuner.
Traceback
AssertionError Traceback (most recent call last)
<ipython-input-10-ad6a0fad764a> in <module>()
47 print(model)
48 synth_data = model.sample(10, max_retries = 100)
---> 49 print(evaluate(synthetic_data = synth_data, real_data = real, aggregate = False))
50 score = evaluate(synthetic_data = synth_data, real_data = real).astype(float)
51 if score > best_score:
7 frames
/usr/local/lib/python3.6/dist-packages/sdv/evaluation.py in evaluate(synthetic_data, real_data, metadata, root_path, table_name, metrics, get_report, aggregate)
152 computed = {}
153 for metric in metrics:
--> 154 computed[metric] = METRICS[metric](synth, real, metadata, details=get_report)
155
156 if get_report:
/usr/local/lib/python3.6/dist-packages/sdv/evaluation.py in _kstest(synthetic, real, metadata, details)
94
95 def _kstest(synthetic, real, metadata=None, details=False):
---> 96 return _tabular_metric(KSTest(), synthetic, real, metadata, details)
97
98
/usr/local/lib/python3.6/dist-packages/sdv/evaluation.py in _tabular_metric(sdmetric, synthetic, real, metadata, details)
86 return list(metrics)
87
---> 88 return np.mean([metric.value for metric in metrics])
89
90
/usr/local/lib/python3.6/dist-packages/sdv/evaluation.py in <listcomp>(.0)
86 return list(metrics)
87
---> 88 return np.mean([metric.value for metric in metrics])
89
90
/usr/local/lib/python3.6/dist-packages/sdmetrics/statistical/univariate/base.py in metrics(self, metadata, real_tables, synthetic_tables)
47 real = real_tables[table_name]
48 synthetic = synthetic_tables[table_name]
---> 49 yield from self._compute(table_name, dtypes, real, synthetic)
50
51 def _compute(self, name, dtypes, real, synthetic):
/usr/local/lib/python3.6/dist-packages/sdmetrics/statistical/univariate/base.py in _compute(self, name, dtypes, real, synthetic)
66 goal=goal,
67 unit=unit,
---> 68 domain=domain
69 )
/usr/local/lib/python3.6/dist-packages/sdmetrics/report.py in __init__(self, name, value, tags, goal, unit, domain, description)
47 self.domain = domain
48 self.description = description
---> 49 self._validate()
50
51 def _validate(self):
/usr/local/lib/python3.6/dist-packages/sdmetrics/report.py in _validate(self)
51 def _validate(self):
52 assert isinstance(self.name, str)
---> 53 assert isinstance(self.value, float)
54 assert isinstance(self.tags, set)
55 assert isinstance(self.goal, Goal)
AssertionError:
Closing as moved to SDEV.