/** * Created by IntelliJ IDEA. * User: Ar.M * Date: 2020-05-10 * Time: 22:13 */ package main import ( "context" "crypto/sha256" "encoding/hex" "fmt" "google.golang.org/grpc" "io" "log" "net" "os" "path/filepath" transportpb "yunlian-file-v2/proto" ) var filesDir string func init() { str, _ := os.Getwd() filesDir = filepath.Join(str, "/files") } type server struct{} func (s *server) GetFileInfo(ctx context.Context, req *transportpb.FileInfoRequest) (*transportpb.FileInfoResponse, error) { fileName := req.GetFileName() path := filepath.Join(filesDir, fileName) file, err := os.Open(path) defer file.Close() if err != nil { return nil, err } hash := sha256.New() if _, err := io.Copy(hash, file); err != nil { fmt.Println(err) } fileSha256 := hex.EncodeToString(hash.Sum(nil)) fileInfo, err := file.Stat() if err != nil { return nil, err } resp := &transportpb.FileInfoResponse{ FileName: fileName, FileSize: fileInfo.Size(), FileSha256: fileSha256, } return resp, nil } func (s *server) Download(req *transportpb.FileRequest, stream transportpb.FileService_DownloadServer) error { fileName := req.GetFileName() fileRangeStart := req.GetFileRangeStart() fileRangeEnd := req.GetFileRangeEnd() path := filepath.Join(filesDir, fileName) fileInfo, err := os.Stat(path) if err != nil { return err } fileSize := fileInfo.Size() f, err := os.Open(path) if err != nil { return err } defer f.Close() _, err = f.Seek(fileRangeStart, 0) if err != nil { return err } log.Printf("开始读取区块: %d-%d\n", fileRangeStart, fileRangeEnd) var totalBytesStreamed int64 var packageLen int64 = 1024 for totalBytesStreamed < fileRangeEnd-fileRangeStart { block := make([]byte, packageLen) if totalBytesStreamed+packageLen > fileRangeEnd-fileRangeStart { block = make([]byte, fileRangeEnd-fileRangeStart-totalBytesStreamed) } bytesRead, err := f.Read(block) if err == io.EOF { break } if err != nil { return err } if err := stream.Send(&transportpb.FileResponse{ Block: block, FileSize: fileSize, }); err != nil { return err } totalBytesStreamed += int64(bytesRead) } return nil } func main() { lis, err := net.Listen("tcp", "0.0.0.0:10086") if err != nil { log.Fatalf("Failed to listen on 10086 : %v\n", err) } s := grpc.NewServer() transportpb.RegisterFileServiceServer(s, &server{}) fmt.Println("Starting server on 10086") if err := s.Serve(lis); err != nil { log.Fatalf("failed to start server : %v\n", err) } }