62 changes: 62 additions & 0 deletions src/cluster/cluster.cc
@@ -29,6 +29,7 @@
#include <vector>

#include "cluster/cluster_defs.h"
#include "cluster/cluster_failover.h"
#include "commands/commander.h"
#include "common/io_util.h"
#include "fmt/format.h"
@@ -221,6 +222,9 @@ Status Cluster::SetClusterNodes(const std::string &nodes_str, int64_t version, b
// Clear migrated and imported slot info
migrated_slots_.clear();
imported_slots_.clear();
if (srv_->cluster_failover) {
srv_->cluster_failover->ResetFailoverState();
}

return Status::OK();
}
@@ -447,6 +451,12 @@ Status Cluster::GetClusterInfo(std::string *cluster_infos) {
std::string import_infos;
srv_->slot_import->GetImportInfo(&import_infos);
*cluster_infos += import_infos;

if (srv_->cluster_failover) {
std::string failover_info;
srv_->cluster_failover->GetFailoverInfo(&failover_info);
*cluster_infos += failover_info;
}
}

return Status::OK();
@@ -898,6 +908,10 @@ Status Cluster::CanExecByMySelf(const redis::CommandAttributes *attributes, cons

uint64_t flags = attributes->GenerateFlags(cmd_tokens, *srv_->GetConfig());

if (srv_->cluster_failover && srv_->cluster_failover->IsWriteForbidden() && (flags & redis::kCmdWrite)) {
return {Status::RedisTryAgain, "Failover in progress"};
}

if (myself_ && myself_ == slots_nodes_[slot]) {
// We use central controller to manage the topology of the cluster.
// Server can't change the topology directly, so we record the migrated slots
@@ -976,3 +990,51 @@ Status Cluster::Reset() {
unlink(srv_->GetConfig()->NodesFilePath().data());
return Status::OK();
}

StatusOr<std::pair<std::string, int>> Cluster::GetNodeIPPort(const std::string &node_id) {
auto it = nodes_.find(node_id);
if (it == nodes_.end()) {
return {Status::NotOK, "Node not found"};
}
return std::make_pair(it->second->host, it->second->port);
}

Status Cluster::OnTakeOver() {
info("[Failover] OnTakeOver received myself_: {}", myself_ ? myself_->id : "null");
if (!myself_) {
return {Status::NotOK, "Cluster is not initialized"};
}
if (myself_->role == kClusterMaster) {
info("[Failover] OnTakeOver myself_ is master, return");
return Status::OK();
}

std::string old_master_id = myself_->master_id;
if (old_master_id.empty()) {
info("[Failover] OnTakeOver no master to takeover, return");
return {Status::NotOK, "No master to takeover"};
}

for (int i = 0; i < kClusterSlots; i++) {
if (slots_nodes_[i] && slots_nodes_[i]->id == old_master_id) {
imported_slots_.insert(i);
}
}
info("[Failover] OnTakeOver Success ");
return Status::OK();
}

void Cluster::SetMySlotsMigrated(const std::string &dst_ip_port) {
// It is called by the failover thread.
auto exclusivity = srv_->WorkExclusivityGuard();

for (int i = 0; i < kClusterSlots; i++) {
if (slots_nodes_[i] == myself_) {
migrated_slots_[i] = dst_ip_port;
}
}
}

bool Cluster::IsSlotImported(int slot) const {
return imported_slots_.count(slot) > 0;
}
5 changes: 5 additions & 0 deletions src/cluster/cluster.h
@@ -93,6 +93,11 @@ class Cluster {
Status DumpClusterNodes(const std::string &file);
Status LoadClusterNodes(const std::string &file_path);
Status Reset();
Status OnTakeOver();

StatusOr<std::pair<std::string, int>> GetNodeIPPort(const std::string &node_id);
void SetMySlotsMigrated(const std::string &dst_ip_port);
bool IsSlotImported(int slot) const;

static bool SubCommandIsExecExclusive(const std::string &subcommand);

293 changes: 293 additions & 0 deletions src/cluster/cluster_failover.cc
@@ -0,0 +1,293 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*/

#include "cluster_failover.h"

#include <unistd.h>

#include "cluster/cluster.h"
#include "common/io_util.h"
#include "common/time_util.h"
#include "logging.h"
#include "server/redis_reply.h"
#include "server/server.h"

ClusterFailover::ClusterFailover(Server *srv) : srv_(srv) {
t_ = std::thread([this]() { loop(); });
}

ClusterFailover::~ClusterFailover() {
{
std::lock_guard<std::mutex> lock(mutex_);
stop_thread_ = true;
cv_.notify_all();
}
if (t_.joinable()) t_.join();
}

Status ClusterFailover::Run(std::string slave_node_id, int timeout_ms) {
std::lock_guard<std::mutex> lock(mutex_);
if (state_ != FailoverState::kNone && state_ != FailoverState::kFailed) {
return {Status::NotOK, "Failover is already in progress"};
}

if (srv_->IsSlave()) {
return {Status::NotOK, "Current node is a slave, can't failover"};
}

slave_node_id_ = std::move(slave_node_id);
timeout_ms_ = timeout_ms;
state_ = FailoverState::kStarted;
failover_job_triggered_ = true;
cv_.notify_all();
return Status::OK();
}

void ClusterFailover::loop() {
while (true) {
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this]() { return stop_thread_ || failover_job_triggered_; });

if (stop_thread_) return;

if (failover_job_triggered_) {
failover_job_triggered_ = false;
lock.unlock();
runFailoverProcess();
}
}
}

void ClusterFailover::runFailoverProcess() {
auto ip_port = srv_->cluster->GetNodeIPPort(slave_node_id_);
if (!ip_port.IsOK()) {
error("[Failover] slave node not found in cluster {}", slave_node_id_);
abortFailover("Slave node not found in cluster");
return;
}
node_ip_port_ = ip_port.GetValue().first + ":" + std::to_string(ip_port.GetValue().second);
node_ip_ = ip_port.GetValue().first;
node_port_ = ip_port.GetValue().second;
info("[Failover] slave node {} {} failover state: {}", slave_node_id_, node_ip_port_, static_cast<int>(state_.load()));
state_ = FailoverState::kCheck;

auto s = checkSlaveStatus();
if (!s.IsOK()) {
abortFailover(s.Msg());
return;
}

s = checkSlaveLag();
if (!s.IsOK()) {
abortFailover("Slave lag check failed: " + s.Msg());
return;
}

info("[Failover] slave node {} {} check slave status success, enter pause state", slave_node_id_, node_ip_port_);
start_time_ms_ = util::GetTimeStampMS();
// Enter Pause state (Stop writing)
state_ = FailoverState::kPause;
// Get current sequence
target_seq_ = srv_->storage->LatestSeqNumber();
info("[Failover] slave node {} {} target sequence {}", slave_node_id_, node_ip_port_, target_seq_);

state_ = FailoverState::kSyncWait;
s = waitReplicationSync();
if (!s.IsOK()) {
abortFailover(s.Msg());
return;
}
info("[Failover] slave node {} {} wait replication sync success, enter switch state, cost {} ms", slave_node_id_,
node_ip_port_, util::GetTimeStampMS() - start_time_ms_);

state_ = FailoverState::kSwitch;
s = sendTakeoverCmd();
if (!s.IsOK()) {
abortFailover(s.Msg());
return;
}

// Redirect slots
srv_->cluster->SetMySlotsMigrated(node_ip_port_);

state_ = FailoverState::kSuccess;
info("[Failover] success {} {}", slave_node_id_, node_ip_port_);
}

Status ClusterFailover::checkSlaveLag() {
auto start_offset_status = srv_->GetSlaveReplicationOffset(node_ip_port_);
if (!start_offset_status.IsOK()) {
return {Status::NotOK, "Failed to get slave offset: " + start_offset_status.Msg()};
}
uint64_t start_offset = *start_offset_status;
int64_t start_sampling_ms = util::GetTimeStampMS();

// Wait 3s or half of timeout, but at least a bit to measure speed
int64_t wait_time = std::max(100, std::min(3000, timeout_ms_ / 2));
std::this_thread::sleep_for(std::chrono::milliseconds(wait_time));

auto end_offset_status = srv_->GetSlaveReplicationOffset(node_ip_port_);
if (!end_offset_status.IsOK()) {
return {Status::NotOK, "Failed to get slave offset: " + end_offset_status.Msg()};
}
uint64_t end_offset = *end_offset_status;
int64_t end_sampling_ms = util::GetTimeStampMS();

double elapsed_sec = (end_sampling_ms - start_sampling_ms) / 1000.0;
if (elapsed_sec <= 0) elapsed_sec = 0.001;

uint64_t bytes = 0;
if (end_offset > start_offset) bytes = end_offset - start_offset;
double speed = bytes / elapsed_sec;

uint64_t master_seq = srv_->storage->LatestSeqNumber();
uint64_t lag = 0;
if (master_seq > end_offset) lag = master_seq - end_offset;

if (lag == 0) return Status::OK();

if (speed <= 0.1) { // Basically 0
return {Status::NotOK, fmt::format("Slave is not replicating (lag: {})", lag)};
}

double required_sec = lag / speed;
int64_t required_ms = static_cast<int64_t>(required_sec * 1000);

int64_t elapsed_total = end_sampling_ms - start_sampling_ms;
int64_t remaining = timeout_ms_ - elapsed_total;

if (required_ms > remaining) {
return {Status::NotOK, fmt::format("Estimated catchup time {}ms > remaining time {}ms (lag: {}, speed: {:.2f}/s)",
required_ms, remaining, lag, speed)};
}

info("[Failover] check: lag={}, speed={:.2f}/s, estimated_time={}ms, remaining={}ms", lag, speed, required_ms,
remaining);
return Status::OK();
}

Status ClusterFailover::checkSlaveStatus() {
// We could try to connect, but GetSlaveReplicationOffset checks connection.
auto offset = srv_->GetSlaveReplicationOffset(node_ip_port_);
if (!offset.IsOK()) {
error("[Failover] slave node {} {} not connected or not syncing", slave_node_id_, node_ip_port_);
return {Status::NotOK, "Slave not connected or not syncing"};
}
info("[Failover] slave node {} {} is connected and syncing offset {}", slave_node_id_, node_ip_port_, offset.Msg());
return Status::OK();
}

Status ClusterFailover::waitReplicationSync() {
while (true) {
if (util::GetTimeStampMS() - start_time_ms_ > static_cast<uint64_t>(timeout_ms_)) {
return {Status::NotOK, "Timeout waiting for replication sync"};
}

auto offset_status = srv_->GetSlaveReplicationOffset(node_ip_port_);
if (!offset_status.IsOK()) {
return {Status::NotOK, "Failed to get slave offset: " + offset_status.Msg()};
}

if (*offset_status >= target_seq_) {
return Status::OK();
}

std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
}

Status ClusterFailover::sendTakeoverCmd() {
auto s = util::SockConnect(node_ip_, node_port_);
if (!s.IsOK()) {
return {Status::NotOK, "Failed to connect to slave: " + s.Msg()};
}
int fd = *s;

std::string pass = srv_->GetConfig()->requirepass;
if (!pass.empty()) {
std::string auth_cmd = redis::ArrayOfBulkStrings({"AUTH", pass});
auto s_auth = util::SockSend(fd, auth_cmd);
if (!s_auth.IsOK()) {
close(fd);
return {Status::NotOK, "Failed to send AUTH: " + s_auth.Msg()};
}
auto s_line = util::SockReadLine(fd);
if (!s_line.IsOK() || s_line.GetValue().substr(0, 3) != "+OK") {
close(fd);
return {Status::NotOK, "AUTH failed"};
}
Comment on lines +231 to +235
Copilot AI (Dec 16, 2025):

Potential resource leak: if SockReadLine fails or returns an unexpected response, the socket fd is closed. However, if GetValue().substr() throws an exception (e.g., if the response is less than 3 characters), the socket will not be closed. Consider using RAII or ensuring close(fd) is called in all paths, including exception paths.
Author:

There is no resource leak risk, because all paths call close(fd):
if (!pass.empty()) {
std::string auth_cmd = redis::ArrayOfBulkStrings({"AUTH", pass});
auto s_auth = util::SockSend(fd, auth_cmd);
if (!s_auth.IsOK()) {
close(fd);
return {Status::NotOK, "Failed to send AUTH: " + s_auth.Msg()};
}
auto s_line = util::SockReadLine(fd);
if (!s_line.IsOK() || s_line.GetValue().substr(0, 3) != "+OK") {
close(fd);
return {Status::NotOK, "AUTH failed"};
}
}

std::string cmd = redis::ArrayOfBulkStrings({"CLUSTERX", "TAKEOVER"});
auto s_send = util::SockSend(fd, cmd);
if (!s_send.IsOK()) {
close(fd);
return {Status::NotOK, "Failed to send TAKEOVER: " + s_send.Msg()};
}

auto s_resp = util::SockReadLine(fd);
close(fd);
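
A side note on the RAII suggestion above: a small scope guard would make the close unconditional on every path out of sendTakeoverCmd(), including any future exception path. This is only a sketch, assuming no such helper already exists in the codebase; the FdGuard name is hypothetical and not part of this PR.

#include <unistd.h>  // close()

// Minimal RAII wrapper for a socket descriptor: whichever way the enclosing
// scope is left (early return or exception), the destructor closes the fd.
class FdGuard {
 public:
  explicit FdGuard(int fd) : fd_(fd) {}
  ~FdGuard() {
    if (fd_ >= 0) close(fd_);
  }
  FdGuard(const FdGuard &) = delete;
  FdGuard &operator=(const FdGuard &) = delete;
  int Get() const { return fd_; }

 private:
  int fd_;
};

// Usage sketch inside sendTakeoverCmd(): construct the guard right after the
// connect succeeds and drop the explicit close(fd) calls on the error paths.
//   FdGuard guard(*s);
//   auto s_auth = util::SockSend(guard.Get(), auth_cmd);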

}

std::string cmd = redis::ArrayOfBulkStrings({"CLUSTERX", "TAKEOVER"});
auto s_send = util::SockSend(fd, cmd);
if (!s_send.IsOK()) {
close(fd);
return {Status::NotOK, "Failed to send TAKEOVER: " + s_send.Msg()};
}

auto s_resp = util::SockReadLine(fd);
close(fd);

if (!s_resp.IsOK()) {
return {Status::NotOK, "Failed to read TAKEOVER response: " + s_resp.Msg()};
}

if (s_resp.GetValue().substr(0, 3) != "+OK") {
return {Status::NotOK, "TAKEOVER failed: " + s_resp.GetValue()};
}
Comment on lines +252 to +254
Copilot AI (Dec 16, 2025):

Potential resource leak: if SockReadLine fails or returns an unexpected response, the socket fd is closed. However, if GetValue().substr() throws an exception (e.g., if the response is less than 3 characters), the socket will not be closed. Consider using RAII or ensuring close(fd) is called in all paths, including exception paths.
Author:
auto s_resp = util::SockReadLine(fd);
close(fd);

if (!s_resp.IsOK()) {
return {Status::NotOK, "Failed to read TAKEOVER response: " + s_resp.Msg()};
}

if (s_resp.GetValue().substr(0, 3) != "+OK") {
return {Status::NotOK, "TAKEOVER failed: " + s_resp.GetValue()};
}

The fd is always closed before any if condition.
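
One more data point on the exception concern: std::string::substr(pos, count) throws std::out_of_range only when pos > size(), so substr(0, 3) cannot throw even on an empty reply. If the prefix check should still avoid the temporary string, a small helper would do. This is just a sketch; HasRespPrefix is a hypothetical name, not an existing utility in this PR.

#include <string>

// Prefix check without allocating a substring. compare(0, n, prefix) clamps n
// to the reply length, so a short reply simply fails the check rather than
// throwing.
static bool HasRespPrefix(const std::string &reply, const std::string &prefix) {
  return reply.compare(0, prefix.size(), prefix) == 0;
}

// e.g. if (!s_resp.IsOK() || !HasRespPrefix(s_resp.GetValue(), "+OK")) { ... }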


return Status::OK();
}

void ClusterFailover::abortFailover(const std::string &reason) {
error("[Failover] node {} {} failover failed: {}", slave_node_id_, node_ip_port_, reason);
state_ = FailoverState::kFailed;
}

void ClusterFailover::GetFailoverInfo(std::string *info) {
*info = "cluster_failover_state:";
switch (state_.load()) {
case FailoverState::kNone:
*info += "none";
break;
case FailoverState::kStarted:
*info += "started";
break;
case FailoverState::kCheck:
*info += "check_slave";
break;
case FailoverState::kPause:
*info += "pause_write";
break;
case FailoverState::kSyncWait:
*info += "wait_sync";
break;
case FailoverState::kSwitch:
*info += "switching";
break;
case FailoverState::kSuccess:
*info += "success";
break;
case FailoverState::kFailed:
*info += "failed";
break;
}
*info += "\r\n";
}