diff --git a/cluster/client.go b/cluster/client.go index c7be108f..1c7b58d3 100644 --- a/cluster/client.go +++ b/cluster/client.go @@ -239,7 +239,7 @@ func (c *Client) Query(qr *command.QueryRequest, nodeAddr string, creds *proto.C // Request performs an ExecuteQuery on a remote node. If creds is nil, then // no credential information will be included in the ExecuteQuery request to the // remote node. -func (c *Client) Request(r *command.ExecuteQueryRequest, nodeAddr string, creds *proto.Credentials, timeout time.Duration, retries int) ([]*command.ExecuteQueryResponse, uint64, error) { +func (c *Client) Request(r *command.ExecuteQueryRequest, nodeAddr string, creds *proto.Credentials, timeout time.Duration, retries int) ([]*command.ExecuteQueryResponse, uint64, uint64, error) { command := &proto.Command{ Type: proto.Command_COMMAND_TYPE_REQUEST, Request: &proto.Command_ExecuteQueryRequest{ @@ -250,19 +250,19 @@ func (c *Client) Request(r *command.ExecuteQueryRequest, nodeAddr string, creds p, nr, err := c.retry(command, nodeAddr, timeout, retries) stats.Add(numClientRequestRetries, int64(nr)) if err != nil { - return nil, 0, err + return nil, 0, 0, err } a := &proto.CommandRequestResponse{} err = pb.Unmarshal(p, a) if err != nil { - return nil, 0, err + return nil, 0, 0, err } if a.Error != "" { - return nil, 0, errors.New(a.Error) + return nil, 0, 0, errors.New(a.Error) } - return a.Response, a.RaftIndex, nil + return a.Response, a.NumRW, a.RaftIndex, nil } // Backup retrieves a backup from a remote node and writes to the io.Writer. diff --git a/cluster/client_test.go b/cluster/client_test.go index f41a7016..1b793282 100644 --- a/cluster/client_test.go +++ b/cluster/client_test.go @@ -210,7 +210,7 @@ func Test_ClientRequest(t *testing.T) { defer srv.Close() c := NewClient(&simpleDialer{}, 0) - _, idx, err := c.Request(executeQueryRequestFromString("SELECT * FROM foo"), + _, _, idx, err := c.Request(executeQueryRequestFromString("SELECT * FROM foo"), srv.Addr(), nil, time.Second, defaultMaxRetries) if err != nil { t.Fatal(err) diff --git a/cluster/proto/message.pb.go b/cluster/proto/message.pb.go index 51d8e363..32513ea6 100644 --- a/cluster/proto/message.pb.go +++ b/cluster/proto/message.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.9 -// protoc v3.6.1 +// protoc-gen-go v1.36.10 +// protoc v3.21.12 // source: message.proto package proto @@ -583,6 +583,7 @@ type CommandRequestResponse struct { Error string `protobuf:"bytes,1,opt,name=error,proto3" json:"error,omitempty"` Response []*proto.ExecuteQueryResponse `protobuf:"bytes,2,rep,name=response,proto3" json:"response,omitempty"` RaftIndex uint64 `protobuf:"varint,3,opt,name=raftIndex,proto3" json:"raftIndex,omitempty"` + NumRW uint64 `protobuf:"varint,4,opt,name=numRW,proto3" json:"numRW,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -638,6 +639,13 @@ func (x *CommandRequestResponse) GetRaftIndex() uint64 { return 0 } +func (x *CommandRequestResponse) GetNumRW() uint64 { + if x != nil { + return x.NumRW + } + return 0 +} + type CommandBackupResponse struct { state protoimpl.MessageState `protogen:"open.v1"` Error string `protobuf:"bytes,1,opt,name=error,proto3" json:"error,omitempty"` @@ -1110,11 +1118,12 @@ const file_message_proto_rawDesc = "" + "\x14CommandQueryResponse\x12\x14\n" + "\x05error\x18\x01 \x01(\tR\x05error\x12&\n" + "\x04rows\x18\x02 \x03(\v2\x12.command.QueryRowsR\x04rows\x12\x1c\n" + - "\traftIndex\x18\x03 \x01(\x04R\traftIndex\"\x87\x01\n" + + "\traftIndex\x18\x03 \x01(\x04R\traftIndex\"\x9d\x01\n" + "\x16CommandRequestResponse\x12\x14\n" + "\x05error\x18\x01 \x01(\tR\x05error\x129\n" + "\bresponse\x18\x02 \x03(\v2\x1d.command.ExecuteQueryResponseR\bresponse\x12\x1c\n" + - "\traftIndex\x18\x03 \x01(\x04R\traftIndex\"A\n" + + "\traftIndex\x18\x03 \x01(\x04R\traftIndex\x12\x14\n" + + "\x05numRW\x18\x04 \x01(\x04R\x05numRW\"A\n" + "\x15CommandBackupResponse\x12\x14\n" + "\x05error\x18\x01 \x01(\tR\x05error\x12\x12\n" + "\x04data\x18\x02 \x01(\fR\x04data\"+\n" + diff --git a/cluster/proto/message.proto b/cluster/proto/message.proto index d67a81e7..f4ccd967 100644 --- a/cluster/proto/message.proto +++ b/cluster/proto/message.proto @@ -68,6 +68,7 @@ message CommandRequestResponse { string error = 1; repeated command.ExecuteQueryResponse response = 2; uint64 raftIndex = 3; + uint64 numRW = 4; } message CommandBackupResponse { diff --git a/cluster/service.go b/cluster/service.go index f054fae2..b7b562db 100644 --- a/cluster/service.go +++ b/cluster/service.go @@ -101,7 +101,7 @@ type Database interface { Query(qr *command.QueryRequest) ([]*command.QueryRows, command.ConsistencyLevel, uint64, error) // Request processes a request that can both executes and queries. - Request(rr *command.ExecuteQueryRequest) ([]*command.ExecuteQueryResponse, uint64, error) + Request(rr *command.ExecuteQueryRequest) ([]*command.ExecuteQueryResponse, uint64, uint64, error) // Backup writes a backup of the database to the writer. Backup(br *command.BackupRequest, dst io.Writer) error @@ -387,12 +387,13 @@ func (s *Service) handleConn(conn net.Conn) { } else if !s.checkCommandPermAll(c, auth.PermQuery, auth.PermExecute) { resp.Error = "unauthorized" } else { - res, idx, err := s.db.Request(rr) + res, numRW, idx, err := s.db.Request(rr) if err != nil { resp.Error = err.Error() } else { resp.Response = make([]*command.ExecuteQueryResponse, len(res)) copy(resp.Response, res) + resp.NumRW = numRW resp.RaftIndex = idx } } diff --git a/cluster/service_test.go b/cluster/service_test.go index 470ad126..0f620529 100644 --- a/cluster/service_test.go +++ b/cluster/service_test.go @@ -564,7 +564,7 @@ func mustNewMockTLSTransport() *mockTransport { type mockDatabase struct { executeFn func(er *command.ExecuteRequest) ([]*command.ExecuteQueryResponse, uint64, error) queryFn func(qr *command.QueryRequest) ([]*command.QueryRows, uint64, error) - requestFn func(rr *command.ExecuteQueryRequest) ([]*command.ExecuteQueryResponse, uint64, error) + requestFn func(rr *command.ExecuteQueryRequest) ([]*command.ExecuteQueryResponse, uint64, uint64, error) backupFn func(br *command.BackupRequest, dst io.Writer) error loadFn func(lr *command.LoadRequest) error } @@ -578,9 +578,9 @@ func (m *mockDatabase) Query(qr *command.QueryRequest) ([]*command.QueryRows, co return rows, command.ConsistencyLevel_NONE, idx, err } -func (m *mockDatabase) Request(rr *command.ExecuteQueryRequest) ([]*command.ExecuteQueryResponse, uint64, error) { +func (m *mockDatabase) Request(rr *command.ExecuteQueryRequest) ([]*command.ExecuteQueryResponse, uint64, uint64, error) { if m.requestFn == nil { - return []*command.ExecuteQueryResponse{}, 0, nil + return []*command.ExecuteQueryResponse{}, 0, 0, nil } return m.requestFn(rr) } diff --git a/http/service.go b/http/service.go index 1207a11f..6454b903 100644 --- a/http/service.go +++ b/http/service.go @@ -60,7 +60,7 @@ type Database interface { // Request processes a slice of requests, each of which can be either // an Execute or Query request. - Request(eqr *proto.ExecuteQueryRequest) ([]*proto.ExecuteQueryResponse, uint64, error) + Request(eqr *proto.ExecuteQueryRequest) ([]*proto.ExecuteQueryResponse, uint64, uint64, error) // Load loads a SQLite file into the system via Raft consensus. Load(lr *proto.LoadRequest) error @@ -123,7 +123,7 @@ type Cluster interface { Query(qr *proto.QueryRequest, nodeAddr string, creds *clstrPB.Credentials, timeout time.Duration) ([]*proto.QueryRows, uint64, error) // Request performs an ExecuteQuery Request on a remote node. - Request(eqr *proto.ExecuteQueryRequest, nodeAddr string, creds *clstrPB.Credentials, timeout time.Duration, retries int) ([]*proto.ExecuteQueryResponse, uint64, error) + Request(eqr *proto.ExecuteQueryRequest, nodeAddr string, creds *clstrPB.Credentials, timeout time.Duration, retries int) ([]*proto.ExecuteQueryResponse, uint64, uint64, error) // Backup retrieves a backup from a remote node and writes to the io.Writer. Backup(br *proto.BackupRequest, nodeAddr string, creds *clstrPB.Credentials, timeout time.Duration, w io.Writer) error @@ -1534,7 +1534,7 @@ func (s *Service) handleRequest(w http.ResponseWriter, r *http.Request, qp Query FreshnessStrict: qp.FreshnessStrict(), } - results, raftIndex, resultsErr := s.store.Request(eqr) + results, _, raftIndex, resultsErr := s.store.Request(eqr) if resultsErr != nil && resultsErr == store.ErrNotLeader { if s.DoRedirect(w, r, qp) { return @@ -1552,7 +1552,7 @@ func (s *Service) handleRequest(w http.ResponseWriter, r *http.Request, qp Query } w.Header().Add(ServedByHTTPHeader, addr) - results, raftIndex, resultsErr = s.cluster.Request(eqr, addr, makeCredentials(r), + results, _, raftIndex, resultsErr = s.cluster.Request(eqr, addr, makeCredentials(r), qp.Timeout(defaultTimeout), qp.Retries(0)) if resultsErr != nil { stats.Add(numRemoteRequestsFailed, 1) diff --git a/http/service_test.go b/http/service_test.go index 76e1981e..393a32ff 100644 --- a/http/service_test.go +++ b/http/service_test.go @@ -1402,14 +1402,14 @@ func Test_ForwardingRedirectExecuteQuery(t *testing.T) { m := &MockStore{ leaderAddr: "foo:1234", } - m.requestFn = func(er *command.ExecuteQueryRequest) ([]*command.ExecuteQueryResponse, uint64, error) { - return nil, 0, store.ErrNotLeader + m.requestFn = func(er *command.ExecuteQueryRequest) ([]*command.ExecuteQueryResponse, uint64, uint64, error) { + return nil, 0, 0, store.ErrNotLeader } c := &mockClusterService{ apiAddr: "https://bar:5678", } - c.requestFn = func(er *command.ExecuteQueryRequest, addr string, timeout time.Duration) ([]*command.ExecuteQueryResponse, uint64, error) { + c.requestFn = func(er *command.ExecuteQueryRequest, addr string, timeout time.Duration) ([]*command.ExecuteQueryResponse, uint64, uint64, error) { resp := &command.ExecuteQueryResponse{ Result: &command.ExecuteQueryResponse_E{ E: &command.ExecuteResult{ @@ -1418,7 +1418,7 @@ func Test_ForwardingRedirectExecuteQuery(t *testing.T) { }, }, } - return []*command.ExecuteQueryResponse{resp}, 0, nil + return []*command.ExecuteQueryResponse{resp}, 0, 0, nil } s := New("127.0.0.1:0", m, c, nil) @@ -1592,7 +1592,7 @@ func Test_DBTimeoutQueryParam(t *testing.T) { type MockStore struct { executeFn func(er *command.ExecuteRequest) ([]*command.ExecuteQueryResponse, uint64, error) queryFn func(qr *command.QueryRequest) ([]*command.QueryRows, uint64, error) - requestFn func(eqr *command.ExecuteQueryRequest) ([]*command.ExecuteQueryResponse, uint64, error) + requestFn func(eqr *command.ExecuteQueryRequest) ([]*command.ExecuteQueryResponse, uint64, uint64, error) backupFn func(br *command.BackupRequest, dst io.Writer) error loadFn func(lr *command.LoadRequest) error snapshotFn func(n uint64) error @@ -1618,11 +1618,11 @@ func (m *MockStore) Query(qr *command.QueryRequest) ([]*command.QueryRows, comma return nil, command.ConsistencyLevel_NONE, 0, nil } -func (m *MockStore) Request(eqr *command.ExecuteQueryRequest) ([]*command.ExecuteQueryResponse, uint64, error) { +func (m *MockStore) Request(eqr *command.ExecuteQueryRequest) ([]*command.ExecuteQueryResponse, uint64, uint64, error) { if m.requestFn != nil { return m.requestFn(eqr) } - return nil, 0, nil + return nil, 0, 0, nil } func (m *MockStore) Join(jr *command.JoinRequest) error { @@ -1701,7 +1701,7 @@ type mockClusterService struct { apiAddr string executeFn func(er *command.ExecuteRequest, addr string, t time.Duration) ([]*command.ExecuteQueryResponse, uint64, error) queryFn func(qr *command.QueryRequest, addr string, t time.Duration) ([]*command.QueryRows, uint64, error) - requestFn func(eqr *command.ExecuteQueryRequest, nodeAddr string, timeout time.Duration) ([]*command.ExecuteQueryResponse, uint64, error) + requestFn func(eqr *command.ExecuteQueryRequest, nodeAddr string, timeout time.Duration) ([]*command.ExecuteQueryResponse, uint64, uint64, error) backupFn func(br *command.BackupRequest, addr string, t time.Duration, w io.Writer) error loadFn func(lr *command.LoadRequest, addr string, t time.Duration) error removeNodeFn func(rn *command.RemoveNodeRequest, nodeAddr string, t time.Duration) error @@ -1728,11 +1728,11 @@ func (m *mockClusterService) Query(qr *command.QueryRequest, addr string, creds return nil, 0, nil } -func (m *mockClusterService) Request(eqr *command.ExecuteQueryRequest, nodeAddr string, creds *cluster.Credentials, timeout time.Duration, r int) ([]*command.ExecuteQueryResponse, uint64, error) { +func (m *mockClusterService) Request(eqr *command.ExecuteQueryRequest, nodeAddr string, creds *cluster.Credentials, timeout time.Duration, r int) ([]*command.ExecuteQueryResponse, uint64, uint64, error) { if m.requestFn != nil { return m.requestFn(eqr, nodeAddr, timeout) } - return nil, 0, nil + return nil, 0, 0, nil } func (m *mockClusterService) Backup(br *command.BackupRequest, addr string, creds *cluster.Credentials, t time.Duration, w io.Writer) error { diff --git a/store/store.go b/store/store.go index e795059b..47c9dc15 100644 --- a/store/store.go +++ b/store/store.go @@ -1433,14 +1433,14 @@ func (s *Store) VerifyLeader() (retErr error) { } // Request processes a request that may contain both Executes and Queries. -func (s *Store) Request(eqr *proto.ExecuteQueryRequest) ([]*proto.ExecuteQueryResponse, uint64, error) { +func (s *Store) Request(eqr *proto.ExecuteQueryRequest) ([]*proto.ExecuteQueryResponse, uint64, uint64, error) { p := (*PragmaCheckRequest)(eqr.Request) if err := p.Check(); err != nil { - return nil, 0, err + return nil, 0, 0, err } if !s.open.Is() { - return nil, 0, ErrNotOpen + return nil, 0, 0, ErrNotOpen } nRW, nRO := s.RORWCount(eqr) isLeader := s.raft.State() == raft.Leader @@ -1454,7 +1454,7 @@ func (s *Store) Request(eqr *proto.ExecuteQueryRequest) ([]*proto.ExecuteQueryRe eqr.Level = proto.ConsistencyLevel_STRONG s.numLRUpgraded.Add(1) } else { - return nil, 0, err + return nil, 0, 0, err } } } @@ -1473,29 +1473,29 @@ func (s *Store) Request(eqr *proto.ExecuteQueryRequest) ([]*proto.ExecuteQueryRe } if eqr.Level == proto.ConsistencyLevel_NONE && s.isStaleRead(eqr.Freshness, eqr.FreshnessStrict) { - return nil, 0, ErrStaleRead + return nil, 0, 0, ErrStaleRead } else if eqr.Level == proto.ConsistencyLevel_WEAK { if !isLeader { - return nil, 0, ErrNotLeader + return nil, 0, 0, ErrNotLeader } } qr, err := s.db.Query(eqr.Request, eqr.Timings) - return convertFn(qr), 0, err + return convertFn(qr), uint64(nRW), 0, err } // At least one write in the request, or STRONG consistency requested, so // we need to go through consensus. Check that we can do that. if !isLeader { - return nil, 0, ErrNotLeader + return nil, 0, 0, ErrNotLeader } if !s.Ready() { - return nil, 0, ErrNotReady + return nil, 0, 0, ErrNotReady } // Send the request through consensus. b, compressed, err := s.tryCompress(eqr) if err != nil { - return nil, 0, err + return nil, 0, 0, err } c := &proto.Command{ Type: proto.Command_COMMAND_TYPE_EXECUTE_QUERY, @@ -1504,15 +1504,15 @@ func (s *Store) Request(eqr *proto.ExecuteQueryRequest) ([]*proto.ExecuteQueryRe } b, err = command.Marshal(c) if err != nil { - return nil, 0, err + return nil, 0, 0, err } af := s.raft.Apply(b, s.ApplyTimeout) if af.Error() != nil { if af.Error() == raft.ErrNotLeader { - return nil, 0, ErrNotLeader + return nil, 0, 0, ErrNotLeader } - return nil, 0, af.Error() + return nil, 0, 0, af.Error() } r := af.Response().(*fsmExecuteQueryResponse) @@ -1522,7 +1522,7 @@ func (s *Store) Request(eqr *proto.ExecuteQueryRequest) ([]*proto.ExecuteQueryRe s.strongReadTerm.Store(readTerm) } - return r.results, af.Index(), r.error + return r.results, uint64(nRW), af.Index(), r.error } // Backup writes a consistent snapshot of the underlying database to dst. This diff --git a/store/store_multi_test.go b/store/store_multi_test.go index b301cd94..d52b788f 100644 --- a/store/store_multi_test.go +++ b/store/store_multi_test.go @@ -314,7 +314,7 @@ func Test_MultiNodeSimple(t *testing.T) { // Write another row using Request rr := executeQueryRequestFromString("INSERT INTO foo(id, name) VALUES(2, 'fiona')", proto.ConsistencyLevel_STRONG, false, false) - _, _, err = s0.Request(rr) + _, _, _, err = s0.Request(rr) if err != nil { t.Fatalf("failed to execute on single node: %s", err.Error()) } @@ -1418,7 +1418,7 @@ func Test_MultiNodeExecuteQueryFreshness(t *testing.T) { eqr := executeQueryRequestFromString("SELECT * FROM foo", proto.ConsistencyLevel_NONE, false, false) eqr.Freshness = mustParseDuration("1ns").Nanoseconds() - _, _, err = s1.Request(eqr) + _, _, _, err = s1.Request(eqr) if err == nil { t.Fatalf("freshness violating request didn't return an error") } @@ -1426,7 +1426,7 @@ func Test_MultiNodeExecuteQueryFreshness(t *testing.T) { t.Fatalf("freshness violating request returned wrong error: %s", err.Error()) } eqr.Freshness = 0 - eqresp, _, err := s1.Request(eqr) + eqresp, _, _, err := s1.Request(eqr) if err != nil { t.Fatalf("inactive freshness violating request returned an error") } diff --git a/store/store_test.go b/store/store_test.go index 68f6a74b..5c2dd82e 100644 --- a/store/store_test.go +++ b/store/store_test.go @@ -1044,7 +1044,7 @@ func Test_SingleNodeRequest_Linearizable(t *testing.T) { // Perform the first linearizable request, which should be upgraded to a strong query. eqr := executeQueryRequestFromString("SELECT * FROM foo", proto.ConsistencyLevel_LINEARIZABLE, false, false) - r, _, err := s.Request(eqr) + r, _, _, err := s.Request(eqr) if err != nil { t.Fatalf("failed to perform linearizable request on single node: %s", err.Error()) } @@ -1062,7 +1062,7 @@ func Test_SingleNodeRequest_Linearizable(t *testing.T) { } // Perform the second linearizable request, which should not be upgraded to a strong query. - r, _, err = s.Request(eqr) + r, _, _, err = s.Request(eqr) if err != nil { t.Fatalf("failed to perform linearizable request on single node: %s", err.Error()) } @@ -1211,7 +1211,7 @@ func Test_SingleNodeRequest(t *testing.T) { for _, tt := range tests { eqr := executeQueryRequestFromStrings(tt.stmts, proto.ConsistencyLevel_WEAK, false, false) - r, _, err := s.Request(eqr) + r, _, _, err := s.Request(eqr) if err != nil { t.Fatalf("failed to execute request on single node: %s", err.Error()) } @@ -1291,7 +1291,7 @@ func Test_SingleNodeRequestTx(t *testing.T) { for _, tt := range tests { eqr := executeQueryRequestFromStrings(tt.stmts, proto.ConsistencyLevel_WEAK, false, tt.tx) - r, _, err := s.Request(eqr) + r, _, _, err := s.Request(eqr) if err != nil { t.Fatalf("failed to execute request on single node: %s", err.Error()) } @@ -1423,7 +1423,7 @@ func Test_SingleNodeRequestParameters(t *testing.T) { } for _, tt := range tests { - r, _, err := s.Request(tt.request) + r, _, _, err := s.Request(tt.request) if err != nil { t.Fatalf("failed to execute request on single node: %s", err.Error()) } @@ -1590,7 +1590,7 @@ func Test_SingleNodeExecuteQueryFreshness(t *testing.T) { rr := executeQueryRequestFromString("SELECT * FROM foo", proto.ConsistencyLevel_NONE, false, false) rr.Freshness = mustParseDuration("1ns").Nanoseconds() - eqr, _, err := s0.Request(rr) + eqr, _, _, err := s0.Request(rr) if err != nil { t.Fatalf("failed to query leader node: %s", err.Error()) } @@ -3247,7 +3247,7 @@ func Test_StoreRequestRaftIndex(t *testing.T) { `CREATE TABLE foo (id INTEGER NOT NULL PRIMARY KEY, name TEXT)`, }, proto.ConsistencyLevel_STRONG, false, false) - _, raftIndex, err := s.Request(writeReq) + _, _, raftIndex, err := s.Request(writeReq) if err != nil { t.Fatalf("failed to execute write request: %s", err.Error()) } @@ -3260,7 +3260,7 @@ func Test_StoreRequestRaftIndex(t *testing.T) { `SELECT * FROM foo`, }, proto.ConsistencyLevel_NONE, false, false) - _, readIndex, err := s.Request(readReq) + _, _, readIndex, err := s.Request(readReq) if err != nil { t.Fatalf("failed to execute read request: %s", err.Error()) } @@ -3273,7 +3273,7 @@ func Test_StoreRequestRaftIndex(t *testing.T) { `INSERT INTO foo(id, name) VALUES(1, "test")`, }, proto.ConsistencyLevel_STRONG, false, false) - _, raftIndex2, err := s.Request(writeReq2) + _, _, raftIndex2, err := s.Request(writeReq2) if err != nil { t.Fatalf("failed to execute second write request: %s", err.Error()) } @@ -3286,7 +3286,7 @@ func Test_StoreRequestRaftIndex(t *testing.T) { `SELECT * FROM foo`, }, proto.ConsistencyLevel_STRONG, false, false) - _, strongReadIndex, err := s.Request(strongReadReq) + _, _, strongReadIndex, err := s.Request(strongReadReq) if err != nil { t.Fatalf("failed to execute strong read request: %s", err.Error()) } @@ -3299,6 +3299,86 @@ func Test_StoreRequestRaftIndex(t *testing.T) { } } +// Test_StoreRequestRWCount tests that Store.Request returns the correct number of RW statements +func Test_StoreRequestRWCount(t *testing.T) { + s, ln := mustNewStore(t) + defer ln.Close() + + if err := s.Open(); err != nil { + t.Fatalf("failed to open single-node store: %s", err.Error()) + } + defer s.Close(true) + if err := s.Bootstrap(NewServer(s.ID(), s.Addr(), true)); err != nil { + t.Fatalf("failed to bootstrap single-node store: %s", err.Error()) + } + if _, err := s.WaitForLeader(10 * time.Second); err != nil { + t.Fatalf("Error waiting for leader: %s", err) + } + + // Create a table first + createReq := executeQueryRequestFromString(`CREATE TABLE foo (id INTEGER NOT NULL PRIMARY KEY, name TEXT)`, proto.ConsistencyLevel_STRONG, false, false) + _, nRWCreate, _, err := s.Request(createReq) + if err != nil { + t.Fatalf("failed to create table: %s", err.Error()) + } + if nRWCreate != 1 { + t.Fatalf("expected nRW=1 for CREATE TABLE statement, got %d", nRWCreate) + } + + // Test 1: Single write statement should return nRW=1 + writeReq := executeQueryRequestFromStrings([]string{ + `INSERT INTO foo(id, name) VALUES(1, "fiona")`, + }, proto.ConsistencyLevel_STRONG, false, false) + _, nRW, _, err := s.Request(writeReq) + if err != nil { + t.Fatalf("failed to execute write request: %s", err.Error()) + } + if nRW != 1 { + t.Fatalf("expected nRW=1 for single write statement, got %d", nRW) + } + + // Test 2: Multiple write statements should return correct nRW count + multiWriteReq := executeQueryRequestFromStrings([]string{ + `INSERT INTO foo(id, name) VALUES(2, "alice")`, + `INSERT INTO foo(id, name) VALUES(3, "bob")`, + `UPDATE foo SET name="charlie" WHERE id=1`, + }, proto.ConsistencyLevel_STRONG, false, false) + _, nRW, _, err = s.Request(multiWriteReq) + if err != nil { + t.Fatalf("failed to execute multi-write request: %s", err.Error()) + } + if nRW != 3 { + t.Fatalf("expected nRW=3 for three write statements, got %d", nRW) + } + + // Test 3: Read-only statement should return nRW=0 + readReq := executeQueryRequestFromStrings([]string{ + `SELECT * FROM foo`, + }, proto.ConsistencyLevel_NONE, false, false) + _, nRW, _, err = s.Request(readReq) + if err != nil { + t.Fatalf("failed to execute read request: %s", err.Error()) + } + if nRW != 0 { + t.Fatalf("expected nRW=0 for read-only statement, got %d", nRW) + } + + // Test 4: Mixed read-write statements should return correct nRW count + mixedReq := executeQueryRequestFromStrings([]string{ + `SELECT * FROM foo`, + `INSERT INTO foo(id, name) VALUES(4, "diana")`, + `SELECT * FROM foo WHERE id=4`, + `DELETE FROM foo WHERE id=1`, + }, proto.ConsistencyLevel_STRONG, false, false) + _, nRW, _, err = s.Request(mixedReq) + if err != nil { + t.Fatalf("failed to execute mixed request: %s", err.Error()) + } + if nRW != 2 { + t.Fatalf("expected nRW=2 for two write statements in mixed request, got %d", nRW) + } +} + // Test_StoreQueryRaftIndex tests that Store.Query returns the correct Raft index func Test_StoreQueryRaftIndex(t *testing.T) { s, ln := mustNewStore(t) diff --git a/system_test/request_forwarding_test.go b/system_test/request_forwarding_test.go index 392980aa..3801707f 100644 --- a/system_test/request_forwarding_test.go +++ b/system_test/request_forwarding_test.go @@ -80,7 +80,7 @@ func Test_StoreClientSideBySide(t *testing.T) { if exp, got := `[{"columns":["id","name"],"types":["integer","text"],"values":[[1,"fiona"]]}]`, asJSON(rows); exp != got { t.Fatalf("unexpected results, exp %s, got %s", exp, got) } - results, _, err := node.Store.Request(executeQueryRequestFromString(`SELECT * FROM foo`)) + results, _, _, err := node.Store.Request(executeQueryRequestFromString(`SELECT * FROM foo`)) if err != nil { t.Fatalf("failed to request on local: %s", err.Error()) } @@ -94,7 +94,7 @@ func Test_StoreClientSideBySide(t *testing.T) { if exp, got := `[{"columns":["id","name"],"types":["integer","text"],"values":[[1,"fiona"]]}]`, asJSON(rows); exp != got { t.Fatalf("unexpected results, exp %s, got %s", exp, got) } - results, idx, err = client.Request(executeQueryRequestFromString(`SELECT * FROM foo`), leaderAddr, NO_CREDS, shortWait, 0) + results, _, idx, err = client.Request(executeQueryRequestFromString(`SELECT * FROM foo`), leaderAddr, NO_CREDS, shortWait, 0) if err != nil { t.Fatalf("failed to query via remote: %s", err.Error()) } @@ -113,7 +113,7 @@ func Test_StoreClientSideBySide(t *testing.T) { if exp, got := `[{"columns":["id","name"],"types":["integer","text"],"values":[[1,"fiona"]]}]`, asJSON(rows); exp != got { t.Fatalf("unexpected results, exp %s, got %s", exp, got) } - results, _, err = node.Store.Request(executeQueryRequestFromString(`SELECT * FROM bar`)) + results, _, _, err = node.Store.Request(executeQueryRequestFromString(`SELECT * FROM bar`)) if err != nil { t.Fatalf("failed to request on local: %s", err.Error()) } @@ -127,7 +127,7 @@ func Test_StoreClientSideBySide(t *testing.T) { if exp, got := `[{"columns":["id","name"],"types":["integer","text"],"values":[[1,"fiona"]]}]`, asJSON(rows); exp != got { t.Fatalf("unexpected results, exp %s, got %s", exp, got) } - results, idx, err = client.Request(executeQueryRequestFromString(`SELECT * FROM bar`), leaderAddr, NO_CREDS, shortWait, noRetries) + results, _, idx, err = client.Request(executeQueryRequestFromString(`SELECT * FROM bar`), leaderAddr, NO_CREDS, shortWait, noRetries) if err != nil { t.Fatalf("failed to query via remote: %s", err.Error()) } @@ -146,7 +146,7 @@ func Test_StoreClientSideBySide(t *testing.T) { if exp, got := `[{"error":"no such table: qux"}]`, asJSON(rows); exp != got { t.Fatalf("unexpected results, exp %s, got %s", exp, got) } - results, _, err = node.Store.Request(executeQueryRequestFromString(`SELECT * FROM qux`)) + results, _, _, err = node.Store.Request(executeQueryRequestFromString(`SELECT * FROM qux`)) if err != nil { t.Fatalf("failed to request on local: %s", err.Error()) } @@ -163,7 +163,7 @@ func Test_StoreClientSideBySide(t *testing.T) { if exp, got := `[{"error":"no such table: qux"}]`, asJSON(rows); exp != got { t.Fatalf("unexpected results, exp %s, got %s", exp, got) } - results, _, err = client.Request(executeQueryRequestFromString(`SELECT * FROM qux`), leaderAddr, NO_CREDS, shortWait, noRetries) + results, _, _, err = client.Request(executeQueryRequestFromString(`SELECT * FROM qux`), leaderAddr, NO_CREDS, shortWait, noRetries) if err != nil { t.Fatalf("failed to query via remote: %s", err.Error()) }