def list(self, request, *args, **kwargs):
    """
    List objects, post-processing the serialized data with ``get_count_by_project``.

    Same flow as the parent ``list``: the filtered queryset is paginated when
    pagination is active, and either the current page or the whole queryset
    is serialized, passed through ``get_count_by_project`` and returned.
    """
    queryset = self.filter_queryset(self.get_queryset())
    page = self.paginate_queryset(queryset)
    paginated = page is not None
    serializer = self.get_serializer(page if paginated else queryset, many=True)
    rows = get_count_by_project(serializer.data)
    if paginated:
        return self.get_paginated_response(rows)
    return Response(rows)
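上面这个 `list` 方法其实就是把父类中的 `list` 方法(DRF 的 `ListModelMixin.list`)拷贝了一份,只是在返回前额外调用了 `get_count_by_project`。可以直接使用 `super().list(request, *args, **kwargs)` 调用父类的 `list` 方法,再对返回的 `Response` 对象中的 `data` 进行处理(注意:启用分页时 `response.data` 是包含 `results` 键的字典,未启用分页时它是一个列表),所以优化为: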
def list(self, request, *args, **kwargs):
    """
    Return the parent ``list`` response with its items run through
    ``get_count_by_project``.

    With pagination enabled, ``response.data`` is a dict whose ``results``
    key holds the items of the current page; without pagination it is the
    serialized list itself. Both shapes are handled instead of always
    indexing ``['results']``, which raises ``TypeError`` on a list.
    """
    response = super().list(request, *args, **kwargs)
    data = response.data
    if isinstance(data, dict) and 'results' in data:
        data['results'] = get_count_by_project(data['results'])
    else:
        # Unpaginated: response.data is the list of serialized items.
        response.data = get_count_by_project(data)
    return response
`names` 接口中使用的序列化器是 `serializers.ProjectNameSerializer`。为了让它也能直接通过 `self.get_serializer` 方法获取,需要重写 `get_serializer_class` 方法。下面先给出 DRF 中 `GenericAPIView.get_serializer_class` 的源码,再给出重写后的版本:
def get_serializer_class(self):
    """
    Return the serializer class this view should use.

    The default is whatever ``self.serializer_class`` holds. Override this
    method to choose a serializer per request, for example a full
    serialization for admins and a reduced one for other users.
    """
    chosen = self.serializer_class
    assert chosen is not None, (
        "'%s' should either include a `serializer_class` attribute, "
        "or override the `get_serializer_class()` method."
        % self.__class__.__name__
    )
    return chosen
def get_serializer_class(self):
    """
    Choose the serializer class for the current action.

    The ``names`` action uses ``serializers.ProjectNameSerializer`` so it can
    go through ``self.get_serializer``. Every other action falls back to the
    parent implementation, which also asserts that ``serializer_class`` is set
    instead of silently returning ``None``.
    """
    if self.action == 'names':
        return serializers.ProjectNameSerializer
    return super().get_serializer_class()
from datetime import datetime
from rest_framework import serializers
from .models import Reports
class ReportsSerializer(serializers.ModelSerializer):
    """
    Serializer for test reports.

    ``html`` is write-only, so the raw report markup is accepted on input but
    never included in API responses. ``create_time`` is read-only.
    """

    class Meta:
        model = Reports
        # Timestamp / soft-delete columns are not exposed through the API.
        exclude = ('update_time', 'is_delete')
        extra_kwargs = {
            'html': {
                'write_only': True
            },
            'create_time': {
                'read_only': True
            }
        }

    def create(self, validated_data):
        """
        Create a report whose name carries a ``_YYYYmmddHHMMSS`` suffix.

        ``Reports.name`` is unique, so the current time is appended to the
        client-supplied name. The supplied name was only validated against the
        field's ``max_length``, so it is shortened when needed to keep the
        final name (base + suffix) within that limit.
        """
        suffix = '_' + datetime.now().strftime('%Y%m%d%H%M%S')
        base_name = validated_data['name']
        max_length = Reports._meta.get_field('name').max_length
        if max_length is not None:
            base_name = base_name[:max(max_length - len(suffix), 0)]
        validated_data['name'] = f"{base_name}{suffix}"
        # NOTE(review): two reports with the same name created within the same
        # second still collide on the unique constraint.
        return Reports.objects.create(**validated_data)
从数据库中可以看出,`html` 字段存放的是一长串 HTML 源码字符串,需要以 HTML 文件的形式打开才能正常展示,所以接口返回的内容中不应该包含它,因此将它设置为只写模式(`write_only`)。另外还需要重写 `create` 方法,原因见下面 models 文件中 `name` 字段的定义:
# Report-name field from the models file: unique=True, so two reports cannot share a name.
name = models.CharField('报告名称', max_length=200, unique=True, help_text='报告名称')
查看 models 文件可以看到,`name` 字段是唯一的(`unique=True`),所以在创建报告时需要在名称后面加上当前的时间信息,以避免重名。
接下来定义视图类 `ReportsViewSet`,它同样继承 `ModelViewSet`,其他部分和之前的视图类似,需要特别注意的是其中的 `download` 接口。
import os
import re
from datetime import datetime
from urllib.parse import quote

from django.conf import settings
from django.http import StreamingHttpResponse
from rest_framework import permissions
from rest_framework.decorators import action
from rest_framework.viewsets import ModelViewSet

from reports.utils import format_output, get_file_contents
from .models import Reports
from .serializers import ReportsSerializer
class ReportsViewSet(ModelViewSet):
    """
    list:
    Return a list of test reports.

    create:
    Create a test report.

    retrieve:
    Return the details of a single test report.

    update:
    Fully update a test report.

    partial_update:
    Partially update a test report.

    destroy:
    Delete a test report (soft delete: the row is kept with is_delete=True).

    download:
    Download a test report as an HTML file.
    """
    queryset = Reports.objects.filter(is_delete=False)
    serializer_class = ReportsSerializer
    permission_classes = (permissions.IsAuthenticated,)
    ordering_fields = ('id', 'name')

    def perform_destroy(self, instance):
        """Soft-delete: flag the report as deleted instead of removing the row."""
        instance.is_delete = True
        instance.save()

    def list(self, request, *args, **kwargs):
        """
        Return the parent ``list`` response with its items passed through
        ``format_output``.

        Paginated responses carry the items under ``results``; unpaginated
        responses are the item list itself, so both shapes are handled.
        """
        response = super().list(request, *args, **kwargs)
        data = response.data
        if isinstance(data, dict) and 'results' in data:
            data['results'] = format_output(data['results'])
        else:
            response.data = format_output(data)
        return response

    @action(detail=True)
    def download(self, request, pk=None):
        """
        Save the report's HTML under ``<BASE_DIR>/reports`` and stream it back
        as an attachment named ``<prefix>_<timestamp>.html``.
        """
        instance = self.get_object()
        # Only the last path component of the (user-supplied) name is used, so
        # a name such as "../../x_1" cannot place the file outside report_dir.
        base_name = os.path.basename(instance.name)
        match = re.match(r'(.*_)\d+', base_name)
        # Reuse the "<prefix>_" part matched before a run of digits (normally
        # the creation timestamp); names without that shape are used whole
        # instead of crashing on a None match.
        prefix = match.group(1) if match else f'{base_name}_'
        filename = prefix + datetime.now().strftime('%Y%m%d%H%M%S') + '.html'
        report_dir = os.path.join(settings.BASE_DIR, 'reports')
        os.makedirs(report_dir, exist_ok=True)
        report_path = os.path.join(report_dir, filename)
        # Write as utf8, the same encoding get_file_contents reads the file with.
        with open(report_path, 'w', encoding='utf8') as f:
            f.write(instance.html)
        response = StreamingHttpResponse(get_file_contents(report_path))
        response['Content-Type'] = "application/octet-stream"
        # RFC 5987/6266: the filename* value must be percent-encoded UTF-8.
        response['Content-Disposition'] = "attachment; filename*=UTF-8''{}".format(quote(filename, safe=''))
        return response
每次下载时,都会先把报告的 HTML 内容写入本地文件保存一份,然后以数据流的方式返回该 HTML 报告:
# Stream the saved report file back chunk by chunk via get_file_contents.
response = StreamingHttpResponse(get_file_contents(report_path))
def get_file_contents(filename, chunk_size=512):
    """
    Yield the text of a UTF-8 file piece by piece.

    :param filename: path of the file to read.
    :param chunk_size: maximum number of characters per yielded piece; the
        file is opened in text mode, so this counts characters, not bytes.
    """
    with open(filename, encoding='utf8') as handle:
        # iter(callable, sentinel) keeps calling read() until it returns ''
        # at end of file.
        yield from iter(lambda: handle.read(chunk_size), '')
这里采用了分块读取的方式:文件以文本模式(UTF-8)打开,每次读取并返回 512 个字符(注意是字符而不是字节),直到文件全部读完为止。