Source file
src/net/rpc/server.go
Documentation: net/rpc
1
2
3
4
5
127 package rpc
128
129 import (
130 "bufio"
131 "encoding/gob"
132 "errors"
133 "go/token"
134 "io"
135 "log"
136 "net"
137 "net/http"
138 "reflect"
139 "strings"
140 "sync"
141 )
142
// DefaultRPCPath is the HTTP path HandleHTTP uses for the RPC (CONNECT) endpoint.
const DefaultRPCPath = "/_goRPC_"

// DefaultDebugPath is the HTTP path HandleHTTP uses for the debugging handler.
const DefaultDebugPath = "/debug/rpc"

// typeOfError is the reflect.Type of the built-in error interface; a method
// is only RPC-callable if its single result has exactly this type.
var typeOfError = reflect.TypeOf((*error)(nil)).Elem()
152
// methodType describes one RPC-callable method of a registered receiver.
type methodType struct {
	sync.Mutex // guards numCalls

	method              reflect.Method // the method itself; method.Func takes the receiver first
	ArgType, ReplyType  reflect.Type   // declared types of the argument and reply parameters
	numCalls            uint           // number of invocations so far
}

// service is a registered receiver together with the methods of it that
// qualified for RPC.
type service struct {
	name   string                 // name the service was registered under
	rcvr   reflect.Value          // receiver the methods are invoked on
	typ    reflect.Type           // dynamic type of rcvr
	method map[string]*methodType // qualifying methods, keyed by method name
}
167
168
169
170
// Request is the header the server reads before each call's argument body.
// ServiceMethod has the form "Service.Method" and Seq is echoed back in the
// matching Response.
type Request struct {
	ServiceMethod string   // format: "Service.Method"
	Seq           uint64   // sequence number chosen by the client
	next          *Request // link in the server's free list
}

// Response is the header the server writes before each reply body. A
// non-empty Error means the call failed and the body carries no reply.
type Response struct {
	ServiceMethod string    // echoes Request.ServiceMethod
	Seq           uint64    // echoes Request.Seq
	Error         string    // error text, if any
	next          *Response // link in the server's free list
}

// Server represents an RPC server. Its zero value is ready to use.
type Server struct {
	serviceMap sync.Map // service name -> *service

	reqLock sync.Mutex // guards freeReq
	freeReq *Request   // head of the list of reusable Request headers

	respLock sync.Mutex // guards freeResp
	freeResp *Response  // head of the list of reusable Response headers
}
195
196
197 func NewServer() *Server {
198 return &Server{}
199 }
200
201
202 var DefaultServer = NewServer()
203
204
// isExportedOrBuiltinType reports whether t, after stripping every level of
// pointer indirection, is either an exported named type or a type with an
// empty package path (a predeclared type or an unnamed composite type).
func isExportedOrBuiltinType(t reflect.Type) bool {
	base := t
	for base.Kind() == reflect.Pointer {
		base = base.Elem()
	}
	if token.IsExported(base.Name()) {
		return true
	}
	// PkgPath is empty for predeclared and unnamed types.
	return base.PkgPath() == ""
}
213
214
215
216
217
218
219
220
221
222
223
224
225 func (server *Server) Register(rcvr any) error {
226 return server.register(rcvr, "", false)
227 }
228
229
230
231 func (server *Server) RegisterName(name string, rcvr any) error {
232 return server.register(rcvr, name, true)
233 }
234
235
236
237 const logRegisterError = false
238
239 func (server *Server) register(rcvr any, name string, useName bool) error {
240 s := new(service)
241 s.typ = reflect.TypeOf(rcvr)
242 s.rcvr = reflect.ValueOf(rcvr)
243 sname := name
244 if !useName {
245 sname = reflect.Indirect(s.rcvr).Type().Name()
246 }
247 if sname == "" {
248 s := "rpc.Register: no service name for type " + s.typ.String()
249 log.Print(s)
250 return errors.New(s)
251 }
252 if !useName && !token.IsExported(sname) {
253 s := "rpc.Register: type " + sname + " is not exported"
254 log.Print(s)
255 return errors.New(s)
256 }
257 s.name = sname
258
259
260 s.method = suitableMethods(s.typ, logRegisterError)
261
262 if len(s.method) == 0 {
263 str := ""
264
265
266 method := suitableMethods(reflect.PointerTo(s.typ), false)
267 if len(method) != 0 {
268 str = "rpc.Register: type " + sname + " has no exported methods of suitable type (hint: pass a pointer to value of that type)"
269 } else {
270 str = "rpc.Register: type " + sname + " has no exported methods of suitable type"
271 }
272 log.Print(str)
273 return errors.New(str)
274 }
275
276 if _, dup := server.serviceMap.LoadOrStore(sname, s); dup {
277 return errors.New("rpc: service already defined: " + sname)
278 }
279 return nil
280 }
281
282
283
284 func suitableMethods(typ reflect.Type, logErr bool) map[string]*methodType {
285 methods := make(map[string]*methodType)
286 for m := 0; m < typ.NumMethod(); m++ {
287 method := typ.Method(m)
288 mtype := method.Type
289 mname := method.Name
290
291 if !method.IsExported() {
292 continue
293 }
294
295 if mtype.NumIn() != 3 {
296 if logErr {
297 log.Printf("rpc.Register: method %q has %d input parameters; needs exactly three\n", mname, mtype.NumIn())
298 }
299 continue
300 }
301
302 argType := mtype.In(1)
303 if !isExportedOrBuiltinType(argType) {
304 if logErr {
305 log.Printf("rpc.Register: argument type of method %q is not exported: %q\n", mname, argType)
306 }
307 continue
308 }
309
310 replyType := mtype.In(2)
311 if replyType.Kind() != reflect.Pointer {
312 if logErr {
313 log.Printf("rpc.Register: reply type of method %q is not a pointer: %q\n", mname, replyType)
314 }
315 continue
316 }
317
318 if !isExportedOrBuiltinType(replyType) {
319 if logErr {
320 log.Printf("rpc.Register: reply type of method %q is not exported: %q\n", mname, replyType)
321 }
322 continue
323 }
324
325 if mtype.NumOut() != 1 {
326 if logErr {
327 log.Printf("rpc.Register: method %q has %d output parameters; needs exactly one\n", mname, mtype.NumOut())
328 }
329 continue
330 }
331
332 if returnType := mtype.Out(0); returnType != typeOfError {
333 if logErr {
334 log.Printf("rpc.Register: return type of method %q is %q, must be error\n", mname, returnType)
335 }
336 continue
337 }
338 methods[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType}
339 }
340 return methods
341 }
342
343
344
345
346 var invalidRequest = struct{}{}
347
348 func (server *Server) sendResponse(sending *sync.Mutex, req *Request, reply any, codec ServerCodec, errmsg string) {
349 resp := server.getResponse()
350
351 resp.ServiceMethod = req.ServiceMethod
352 if errmsg != "" {
353 resp.Error = errmsg
354 reply = invalidRequest
355 }
356 resp.Seq = req.Seq
357 sending.Lock()
358 err := codec.WriteResponse(resp, reply)
359 if debugLog && err != nil {
360 log.Println("rpc: writing response:", err)
361 }
362 sending.Unlock()
363 server.freeResponse(resp)
364 }
365
366 func (m *methodType) NumCalls() (n uint) {
367 m.Lock()
368 n = m.numCalls
369 m.Unlock()
370 return n
371 }
372
373 func (s *service) call(server *Server, sending *sync.Mutex, wg *sync.WaitGroup, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
374 if wg != nil {
375 defer wg.Done()
376 }
377 mtype.Lock()
378 mtype.numCalls++
379 mtype.Unlock()
380 function := mtype.method.Func
381
382 returnValues := function.Call([]reflect.Value{s.rcvr, argv, replyv})
383
384 errInter := returnValues[0].Interface()
385 errmsg := ""
386 if errInter != nil {
387 errmsg = errInter.(error).Error()
388 }
389 server.sendResponse(sending, req, replyv.Interface(), codec, errmsg)
390 server.freeRequest(req)
391 }
392
393 type gobServerCodec struct {
394 rwc io.ReadWriteCloser
395 dec *gob.Decoder
396 enc *gob.Encoder
397 encBuf *bufio.Writer
398 closed bool
399 }
400
401 func (c *gobServerCodec) ReadRequestHeader(r *Request) error {
402 return c.dec.Decode(r)
403 }
404
405 func (c *gobServerCodec) ReadRequestBody(body any) error {
406 return c.dec.Decode(body)
407 }
408
409 func (c *gobServerCodec) WriteResponse(r *Response, body any) (err error) {
410 if err = c.enc.Encode(r); err != nil {
411 if c.encBuf.Flush() == nil {
412
413
414 log.Println("rpc: gob error encoding response:", err)
415 c.Close()
416 }
417 return
418 }
419 if err = c.enc.Encode(body); err != nil {
420 if c.encBuf.Flush() == nil {
421
422
423 log.Println("rpc: gob error encoding body:", err)
424 c.Close()
425 }
426 return
427 }
428 return c.encBuf.Flush()
429 }
430
431 func (c *gobServerCodec) Close() error {
432 if c.closed {
433
434 return nil
435 }
436 c.closed = true
437 return c.rwc.Close()
438 }
439
440
441
442
443
444
445
446 func (server *Server) ServeConn(conn io.ReadWriteCloser) {
447 buf := bufio.NewWriter(conn)
448 srv := &gobServerCodec{
449 rwc: conn,
450 dec: gob.NewDecoder(conn),
451 enc: gob.NewEncoder(buf),
452 encBuf: buf,
453 }
454 server.ServeCodec(srv)
455 }
456
457
458
459 func (server *Server) ServeCodec(codec ServerCodec) {
460 sending := new(sync.Mutex)
461 wg := new(sync.WaitGroup)
462 for {
463 service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
464 if err != nil {
465 if debugLog && err != io.EOF {
466 log.Println("rpc:", err)
467 }
468 if !keepReading {
469 break
470 }
471
472 if req != nil {
473 server.sendResponse(sending, req, invalidRequest, codec, err.Error())
474 server.freeRequest(req)
475 }
476 continue
477 }
478 wg.Add(1)
479 go service.call(server, sending, wg, mtype, req, argv, replyv, codec)
480 }
481
482
483 wg.Wait()
484 codec.Close()
485 }
486
487
488
489 func (server *Server) ServeRequest(codec ServerCodec) error {
490 sending := new(sync.Mutex)
491 service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
492 if err != nil {
493 if !keepReading {
494 return err
495 }
496
497 if req != nil {
498 server.sendResponse(sending, req, invalidRequest, codec, err.Error())
499 server.freeRequest(req)
500 }
501 return err
502 }
503 service.call(server, sending, nil, mtype, req, argv, replyv, codec)
504 return nil
505 }
506
507 func (server *Server) getRequest() *Request {
508 server.reqLock.Lock()
509 req := server.freeReq
510 if req == nil {
511 req = new(Request)
512 } else {
513 server.freeReq = req.next
514 *req = Request{}
515 }
516 server.reqLock.Unlock()
517 return req
518 }
519
520 func (server *Server) freeRequest(req *Request) {
521 server.reqLock.Lock()
522 req.next = server.freeReq
523 server.freeReq = req
524 server.reqLock.Unlock()
525 }
526
527 func (server *Server) getResponse() *Response {
528 server.respLock.Lock()
529 resp := server.freeResp
530 if resp == nil {
531 resp = new(Response)
532 } else {
533 server.freeResp = resp.next
534 *resp = Response{}
535 }
536 server.respLock.Unlock()
537 return resp
538 }
539
540 func (server *Server) freeResponse(resp *Response) {
541 server.respLock.Lock()
542 resp.next = server.freeResp
543 server.freeResp = resp
544 server.respLock.Unlock()
545 }
546
547 func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *methodType, req *Request, argv, replyv reflect.Value, keepReading bool, err error) {
548 service, mtype, req, keepReading, err = server.readRequestHeader(codec)
549 if err != nil {
550 if !keepReading {
551 return
552 }
553
554 codec.ReadRequestBody(nil)
555 return
556 }
557
558
559 argIsValue := false
560 if mtype.ArgType.Kind() == reflect.Pointer {
561 argv = reflect.New(mtype.ArgType.Elem())
562 } else {
563 argv = reflect.New(mtype.ArgType)
564 argIsValue = true
565 }
566
567 if err = codec.ReadRequestBody(argv.Interface()); err != nil {
568 return
569 }
570 if argIsValue {
571 argv = argv.Elem()
572 }
573
574 replyv = reflect.New(mtype.ReplyType.Elem())
575
576 switch mtype.ReplyType.Elem().Kind() {
577 case reflect.Map:
578 replyv.Elem().Set(reflect.MakeMap(mtype.ReplyType.Elem()))
579 case reflect.Slice:
580 replyv.Elem().Set(reflect.MakeSlice(mtype.ReplyType.Elem(), 0, 0))
581 }
582 return
583 }
584
585 func (server *Server) readRequestHeader(codec ServerCodec) (svc *service, mtype *methodType, req *Request, keepReading bool, err error) {
586
587 req = server.getRequest()
588 err = codec.ReadRequestHeader(req)
589 if err != nil {
590 req = nil
591 if err == io.EOF || err == io.ErrUnexpectedEOF {
592 return
593 }
594 err = errors.New("rpc: server cannot decode request: " + err.Error())
595 return
596 }
597
598
599
600 keepReading = true
601
602 dot := strings.LastIndex(req.ServiceMethod, ".")
603 if dot < 0 {
604 err = errors.New("rpc: service/method request ill-formed: " + req.ServiceMethod)
605 return
606 }
607 serviceName := req.ServiceMethod[:dot]
608 methodName := req.ServiceMethod[dot+1:]
609
610
611 svci, ok := server.serviceMap.Load(serviceName)
612 if !ok {
613 err = errors.New("rpc: can't find service " + req.ServiceMethod)
614 return
615 }
616 svc = svci.(*service)
617 mtype = svc.method[methodName]
618 if mtype == nil {
619 err = errors.New("rpc: can't find method " + req.ServiceMethod)
620 }
621 return
622 }
623
624
625
626
627
628 func (server *Server) Accept(lis net.Listener) {
629 for {
630 conn, err := lis.Accept()
631 if err != nil {
632 log.Print("rpc.Serve: accept:", err.Error())
633 return
634 }
635 go server.ServeConn(conn)
636 }
637 }
638
639
640 func Register(rcvr any) error { return DefaultServer.Register(rcvr) }
641
642
643
644 func RegisterName(name string, rcvr any) error {
645 return DefaultServer.RegisterName(name, rcvr)
646 }
647
648
649
650
651
652
653
654
655
656 type ServerCodec interface {
657 ReadRequestHeader(*Request) error
658 ReadRequestBody(any) error
659 WriteResponse(*Response, any) error
660
661
662 Close() error
663 }
664
665
666
667
668
669
670
671 func ServeConn(conn io.ReadWriteCloser) {
672 DefaultServer.ServeConn(conn)
673 }
674
675
676
677 func ServeCodec(codec ServerCodec) {
678 DefaultServer.ServeCodec(codec)
679 }
680
681
682
683 func ServeRequest(codec ServerCodec) error {
684 return DefaultServer.ServeRequest(codec)
685 }
686
687
688
689
690 func Accept(lis net.Listener) { DefaultServer.Accept(lis) }
691
692
693 var connected = "200 Connected to Go RPC"
694
695
696 func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
697 if req.Method != "CONNECT" {
698 w.Header().Set("Content-Type", "text/plain; charset=utf-8")
699 w.WriteHeader(http.StatusMethodNotAllowed)
700 io.WriteString(w, "405 must CONNECT\n")
701 return
702 }
703 conn, _, err := w.(http.Hijacker).Hijack()
704 if err != nil {
705 log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.Error())
706 return
707 }
708 io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
709 server.ServeConn(conn)
710 }
711
712
713
714
715 func (server *Server) HandleHTTP(rpcPath, debugPath string) {
716 http.Handle(rpcPath, server)
717 http.Handle(debugPath, debugHTTP{server})
718 }
719
720
721
722
723 func HandleHTTP() {
724 DefaultServer.HandleHTTP(DefaultRPCPath, DefaultDebugPath)
725 }
726
View as plain text