Skip to content

Commit 7ea64c4

Browse files
committed
Prevent deleting certificate that is use
1 parent 88b46ef commit 7ea64c4

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

backend/internal/api/handler/certificates.go

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package handler
22

33
import (
4+
"database/sql"
45
"encoding/json"
56
"fmt"
67
"net/http"
@@ -10,6 +11,7 @@ import (
1011
"npm/internal/api/middleware"
1112
"npm/internal/api/schema"
1213
"npm/internal/entity/certificate"
14+
"npm/internal/entity/host"
1315
"npm/internal/jobqueue"
1416
"npm/internal/logger"
1517
)
@@ -141,11 +143,20 @@ func DeleteCertificate() func(http.ResponseWriter, *http.Request) {
141143
return
142144
}
143145

144-
cert, err := certificate.GetByID(certificateID)
145-
if err != nil {
146+
item, err := certificate.GetByID(certificateID)
147+
switch err {
148+
case sql.ErrNoRows:
149+
h.ResultErrorJSON(w, r, http.StatusNotFound, "Not found", nil)
150+
case nil:
151+
// Ensure that this upstream isn't in use by a host
152+
cnt := host.GetCertificateUseCount(certificateID)
153+
if cnt > 0 {
154+
h.ResultErrorJSON(w, r, http.StatusBadRequest, "Cannot delete certificate that is in use by at least 1 host", nil)
155+
return
156+
}
157+
h.ResultResponseJSON(w, r, http.StatusOK, item.Delete())
158+
default:
146159
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
147-
} else {
148-
h.ResultResponseJSON(w, r, http.StatusOK, cert.Delete())
149160
}
150161
}
151162
}

backend/internal/entity/host/methods.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,21 @@ func GetUpstreamUseCount(upstreamID int) int {
213213
return totalRows
214214
}
215215

216+
// GetCertificateUseCount returns the number of hosts that are using
217+
// a certificate, and have not been deleted.
218+
func GetCertificateUseCount(certificateID int) int {
219+
db := database.GetInstance()
220+
query := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE certificate_id = ? AND is_deleted = ?", tableName)
221+
countRow := db.QueryRowx(query, certificateID, 0)
222+
var totalRows int
223+
queryErr := countRow.Scan(&totalRows)
224+
if queryErr != nil && queryErr != sql.ErrNoRows {
225+
logger.Debug("%s", query)
226+
return 0
227+
}
228+
return totalRows
229+
}
230+
216231
// AddPendingJobs is intended to be used at startup to add
217232
// anything pending to the JobQueue just once, based on
218233
// the database row status

0 commit comments

Comments
 (0)