This method doesn't work properly after migrating to TensorFlow 2 #3
Open
Labels
bug (Something isn't working), enhancement (New feature or request), help wanted (Extra attention is needed)
Description
This method depends on the `_calibrate` method working properly.
Emgraph/emgraph/models/EmbeddingModel.py
Lines 2137 to 2171 in 3926ad7
```python
def _predict_proba(self, X):
    """Predicts probabilities using the Platt scaling model (after calibration).

    Model must be calibrated beforehand with the ``calibrate`` method.

    :param X: Numpy array of triples to be evaluated.
    :type X: ndarray, shape [n, 3]
    :return: Probability of each triple to be true according to the Platt scaling calibration.
    :rtype: ndarray, shape [n, 3]
    """
    if not self.is_calibrated:
        msg = "Model has not been calibrated. Please call `model.calibrate(...)` before predicting probabilities."
        logger.error(msg)
        raise RuntimeError(msg)
    # tf.reset_default_graph()
    self._load_model_from_trained_params()
    w = tf.Variable(self.calibration_parameters[0], dtype=tf.float32, trainable=False)
    b = tf.Variable(self.calibration_parameters[1], dtype=tf.float32, trainable=False)
    x_idx = to_idx(X, ent_to_idx=self.ent_to_idx, rel_to_idx=self.rel_to_idx)
    x_tf = tf.Variable(x_idx, dtype=tf.int32, trainable=False)
    e_s, e_p, e_o = self._lookup_embeddings(x_tf)
    scores = self._fn(e_s, e_p, e_o)
    logits = -(w * scores + b)
    probas = tf.sigmoid(logits)
    # with tf.Session(config=self.tf_config) as sess:
    #     sess.run(tf.global_variables_initializer())
    #     return sess.run(probas)
    return probas
```
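For reference, the Platt-scaling arithmetic this method performs is simple enough to check outside TensorFlow. Below is a minimal NumPy sketch of that computation only (the `platt_scale` helper name and the parameter values are hypothetical, not part of Emgraph); in TF2 eager mode the commented-out `tf.Session` code above is no longer needed, and the resulting tensor can be materialized with `.numpy()` instead:

```python
import numpy as np

def platt_scale(scores, w, b):
    """Hypothetical standalone helper mirroring the computation in `_predict_proba`:
    probas = sigmoid(-(w * scores + b)).
    """
    logits = -(w * scores + b)
    return 1.0 / (1.0 + np.exp(-logits))  # numerically plain sigmoid

# Illustrative calibration parameters and raw triple scores (not real model output).
scores = np.array([-3.0, 0.0, 2.5])
probas = platt_scale(scores, w=-1.2, b=0.1)
print(probas)  # each value lies strictly in (0, 1)
```

With a negative `w`, higher raw scores map to higher probabilities, which is the usual orientation of the calibrated output.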